In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [2]:
class MultiHeadAttention(nn.Module):
    '''
    d_model: word embedding length
    n_head: number of attention heads
    d_k: word embedding is split across multiple heads. This is their new length
    '''
    def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
        super().__init__()

        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k * n_heads)
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        self.fc = nn.Linear(d_k * n_heads, d_model)

        self.causal = causal
        if self.causal:
            cm = torch.tril(torch.ones(max_len, max_len))
            self.register_buffer('causal_mask', cm.view(1, 1, max_len, max_len))

    def forward(self, q, k, v, pad_mask=None):
        q = self.query(q).to(torch.float32)  # Explicitly cast to float32
        k = self.key(k).to(torch.float32)
        v = self.value(v).to(torch.float32)

        N = q.shape[0]
        T_output = q.shape[1]
        T_input = k.shape[1]

        q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if pad_mask is not None:
            attn_scores = attn_scores.masked_fill(pad_mask[:, None, None, :] == 0, float('-inf'))
        if self.causal:
            attn_scores = attn_scores.masked_fill(self.causal_mask[:, :, :T_output, :T_input] == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        A = torch.matmul(attn_weights, v)

        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T_output, self.d_k * self.n_heads)

        return self.fc(A)

class EncoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout_prob)
        )
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, pad_mask=None):
        x = self.ln1(x + self.mha(x, x, x, pad_mask))
        x = self.ln2(x + self.ffn(x))
        x = self.dropout(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.mha1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=True)
        self.mha2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout_prob)
        )
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.ln1(dec_input + self.mha1(dec_input, dec_input, dec_input, dec_mask))
        x = self.ln2(x + self.mha2(x, enc_output, enc_output, enc_mask))
        x = self.ln3(x + self.ffn(x))
        x = self.dropout(x)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)

        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term * (-math.log(10_000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class Encoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        transformer_blocks = [
            EncoderBlock(d_k, d_model, n_heads, max_len, dropout_prob) for _ in range(n_layers)
        ]
        self.transformer_blocks = nn.Sequential(*transformer_blocks)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask)
        x = self.ln(x)
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        transformer_blocks = [
            DecoderBlock(d_k, d_model, n_heads, max_len, dropout_prob) for _ in range(n_layers)
        ]
        self.transformer_blocks = nn.Sequential(*transformer_blocks)
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.embedding(dec_input)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(enc_output, x, enc_mask, dec_mask)
        x = self.ln(x)
        x = self.fc(x) # many-to-many
        return x

class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_input, dec_input, enc_mask, dec_mask):
        enc_output = self.encoder(enc_input, enc_mask)
        dec_output = self.decoder(enc_output, dec_input, enc_mask, dec_mask)
        return dec_output


In [3]:
encoder = Encoder(vocab_size=20_000,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size=10_000,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
encoder.to(device)
decoder.to(device)
transformer.to(device)


cuda:0


Transformer(
  (encoder): Encoder(
    (embedding): Embedding(20000, 64)
    (pos_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_blocks): Sequential(
      (0): EncoderBlock(
        (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mha): MultiHeadAttention(
          (key): Linear(in_features=64, out_features=64, bias=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (fc): Linear(in_features=64, out_features=64, bias=True)
        )
        (ffn): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
  

In [5]:

import numpy as np
import torch

# Generate input data as integers for embedding, but convert to float32 for further processing
xe = np.random.randint(0, 20_000, size=[8, 512]) # B x Seq Length
xe_t = torch.tensor(xe, dtype=torch.long).to(device)  # Use Long type for embedding indices

xd = np.random.randint(0, 10_000, size=[8, 256]) # B x Seq Length
xd_t = torch.tensor(xd, dtype=torch.long).to(device)  # Use Long type for embedding indices

# Create masks as float32
maske = np.ones((8, 512), dtype=np.float32)
maske[:, 256:] = 0
maske_t = torch.tensor(maske, dtype=torch.float32).to(device)

maskd = np.ones((8, 256), dtype=np.float32)
maskd[:, 128:] = 0
maskd_t = torch.tensor(maskd, dtype=torch.float32).to(device)

# Convert input tensors and masks to float32 (although xe_t and xd_t are already long for embeddings)
maske_t = maske_t.to(torch.float32)
maskd_t = maskd_t.to(torch.float32)

In [6]:
# Assuming transformer is your model
# Forward pass
print("shapeshape is " + str(xd_t.shape))
out = transformer(xe_t, xd_t, maske_t, maskd_t)
print(out.shape)  # Output shape check

shapeshape is torch.Size([8, 256])
torch.Size([8, 256, 10000])


In [7]:
import pandas as pd
df = pd.read_csv('spa.txt', sep="\t", header=None)
df.head()

Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Hi.,Hola.
4,Run!,¡Corre!


In [8]:
df.shape

(115245, 2)

In [9]:
df = df.iloc[:30_000]
df.columns = ['en', 'es']
df.to_csv('spa.csv', index=None)

In [10]:
from datasets import load_dataset
raw_dataset = load_dataset('csv', data_files='spa.csv')

Generating train split: 0 examples [00:00, ? examples/s]

In [11]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 30000
    })
})

In [12]:
split = raw_dataset['train'].train_test_split(test_size=0.3, seed=42)
split

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['en', 'es'],
        num_rows: 9000
    })
})

In [13]:
from transformers import AutoTokenizer

model_checkpoint = "Helsinki-NLP/opus-mt-en-es"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [14]:
en_sentence = split["train"][0]["en"]
es_sentence = split["train"][0]["es"]

inputs = tokenizer(en_sentence)
targets = tokenizer(text_target=es_sentence)

tokenizer.convert_ids_to_tokens(targets['input_ids'])

['▁Yo', '▁puedo', '▁arreglarlo', '.', '</s>']

In [15]:
es_sentence

'Yo puedo arreglarlo.'

In [16]:
max_input_length = 128
max_target_length = 128

def preprocess_function(batch):
    model_inputs = tokenizer(batch['en'], max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=batch['es'], max_length=max_target_length, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [17]:
tokenized_datasets = split.map(preprocess_function, batched=True, remove_columns=split["train"].column_names)

Map:   0%|          | 0/21000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

In [18]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9000
    })
})

In [19]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer)

In [20]:
batch = data_collator([tokenized_datasets["train"][i] for i in range(0, 5)])
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [21]:
batch['input_ids']

tensor([[   33,    88,  9222,    48,     3,     0, 65000, 65000],
        [  552, 11490,     9,   310,   255,     3,     0, 65000],
        [  143,    31,   125,  1208,     3,     0, 65000, 65000],
        [ 1093,   220,  1890,    23,    48,     3,     0, 65000],
        [  124,    20,   100, 18422,    48,   141,     3,     0]])

In [22]:
batch['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])

In [23]:
batch['labels']

tensor([[  711,  1039, 44159,     3,     0,  -100,  -100,  -100],
        [ 2722, 18663,   239,   212,     3,     0,  -100,  -100],
        [  539,    43,   155,   960,     3,     0,  -100,  -100],
        [15165,  1250,   380,  3564,    36,  1016,     3,     0],
        [  350,     8, 19153,    29, 31326,     3,     0,  -100]])

In [24]:
tokenizer.all_special_ids

[0, 1, 65000]

In [25]:
tokenizer.all_special_tokens

['</s>', '<unk>', '<pad>']

In [26]:
tokenizer('<pad>')

{'input_ids': [65000, 0], 'attention_mask': [1, 1]}

In [27]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)
valid_loader = DataLoader(
    tokenized_datasets["test"],
    batch_size=32,
    collate_fn=data_collator
)

In [28]:
for batch in train_loader:
    for k, v in batch.items():
        print("k:", k, "v.shape:", v.shape)
    break

k: input_ids v.shape: torch.Size([32, 9])
k: attention_mask v.shape: torch.Size([32, 9])
k: labels v.shape: torch.Size([32, 9])


In [29]:
tokenizer.vocab_size

65001

In [30]:
tokenizer.decode([60000])

'ѕэр'

In [31]:
tokenizer.add_special_tokens({"cls_token": "<s>"})
tokenizer("<s>")

{'input_ids': [65001, 0], 'attention_mask': [1, 1]}

In [32]:
tokenizer.vocab_size

65001

In [33]:
encoder = Encoder(vocab_size = tokenizer.vocab_size + 1,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size = tokenizer.vocab_size + 1,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)
transformer.to(device)

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(65002, 64)
    (pos_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_blocks): Sequential(
      (0): EncoderBlock(
        (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mha): MultiHeadAttention(
          (key): Linear(in_features=64, out_features=64, bias=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (fc): Linear(in_features=64, out_features=64, bias=True)
        )
        (ffn): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
  

In [34]:
print(device)
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(65002, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [35]:
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(transformer.parameters())

In [36]:
from datetime import datetime
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


def train(model, criterion, optimizer, train_loader, valid_loader, epochs):
    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)
    
    for it in range(epochs):
        model.train()
        t0 = datetime.now()
        train_loss = []
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            
            enc_input = batch['input_ids']
            enc_mask = batch['attention_mask']
            targets = batch['labels']
            
            dec_input = targets.clone().detach()
            dec_input = torch.roll(dec_input, shifts=1, dims=1)
            dec_input[:, 0] = 65_001
            
            dec_input = dec_input.masked_fill(dec_input == -100, tokenizer.pad_token_id)
            
            dec_mask = torch.ones_like(dec_input)
            dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)
            
            outputs = model(enc_input, dec_input, enc_mask, dec_mask)
            loss = criterion(outputs.transpose(2, 1), targets)
            
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            
        train_loss = np.mean(train_loss)
        
        model.eval()
        test_loss = []
        for batch in valid_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            enc_input = batch['input_ids']
            enc_mask = batch['attention_mask']
            targets = batch['labels']
            
            dec_input = targets.clone().detach()
            dec_input = torch.roll(dec_input, shifts=1, dims=1)
            dec_input[:, 0] = 65_001
            
            dec_input = dec_input.masked_fill(dec_input == -100, tokenizer.pad_token_id)
            
            dec_mask = torch.ones_like(dec_input)
            dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)
            
            outputs = model(enc_input, dec_input, enc_mask, dec_mask)
            loss = criterion(outputs.transpose(2, 1), targets)
            test_loss.append(loss.item())
        test_loss = np.mean(test_loss)
        
        train_losses[it] = train_loss
        test_losses[it] = test_loss
        
        dt = datetime.now() - t0
        
        print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Duration: {dt}')
    return train_losses, test_losses

In [54]:
train_losses, test_losses = train(transformer, criterion, optimizer, train_loader, valid_loader, epochs=3)

Epoch 1/3, Train Loss: 3.0902, Test Loss: 3.0908, Duration: 0:00:21.419990
Epoch 2/3, Train Loss: 2.7346, Test Loss: 2.8607, Duration: 0:00:20.816014
Epoch 3/3, Train Loss: 2.4463, Test Loss: 2.6809, Duration: 0:00:20.846570


In [38]:
input_sentence = split['test'][10]['en']
input_sentence

'Can I take a day off?'

In [39]:
enc_input = tokenizer(input_sentence, return_tensors='pt')
enc_input

{'input_ids': tensor([[1283,   33,  273,    8,  502,  843,   21,    0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [40]:
dec_input_str = '<s>'
dec_input = tokenizer(text_target=dec_input_str, return_tensors='pt')
dec_input

{'input_ids': tensor([[65001,     0]]), 'attention_mask': tensor([[1, 1]])}

In [41]:
enc_input.to(device)
dec_input.to(device)
output = transformer(
    enc_input['input_ids'],
    dec_input['input_ids'][:, :-1],
    enc_input['attention_mask'],
    dec_input['attention_mask'][:, :-1]
)
output

tensor([[[ 2.0279, -6.5225,  2.8077,  ..., -5.7599, -6.1446, -4.9369]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [42]:
output.shape

torch.Size([1, 1, 65002])

In [43]:
enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])
enc_output.shape

torch.Size([1, 8, 64])

In [44]:
dec_output = decoder(enc_output, dec_input['input_ids'][:, :-1], enc_input['attention_mask'], dec_input['attention_mask'][:, :-1])
dec_output.shape

torch.Size([1, 1, 65002])

In [45]:
torch.allclose(output, dec_output)

True

In [46]:
dec_input_ids = dec_input['input_ids'][:, :-1]
dec_attn_mask = dec_input['attention_mask'][:, :-1]

for _ in range(32):
    dec_output = decoder(
        enc_output,
        dec_input_ids,
        enc_input['attention_mask'],
        dec_attn_mask
    )
    prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)
    
    
    dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))
    
    dec_attn_mask = torch.ones_like(dec_input_ids)
    
    if prediction_id == 0:
        break

In [47]:
tokenizer.decode(dec_input_ids[0])

'<s> ¿Quién puedo un buen?</s>'

In [48]:
split['test'][10]['es']

'¿Puedo tomarme un día libre?'

In [49]:
def translate(input_sentence):
    enc_input = tokenizer(input_sentence, return_tensors='pt').to(device)
    enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])
    
    dec_input_ids = torch.tensor([[65_001]], device=device)
    dec_attn_mask = torch.ones_like(dec_input_ids, device=device)
    
    for _ in range(32):
        dec_output = decoder(
            enc_output,
            dec_input_ids,
            enc_input['attention_mask'],
            dec_attn_mask
        )
        
        prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)
        
        dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))
        
        dec_attn_mask = torch.ones_like(dec_input_ids, device=device)
        
        if prediction_id == 0:
            break
            
    translation = tokenizer.decode(dec_input_ids[0, 1:])
    print(translation)
        

In [56]:
translate("How are you?")

¿Cómo estás hecho?</s>
