In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
        super().__init__()

        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k*n_heads)
        self.query = nn.Linear(d_model, d_k*n_heads)
        self.value = nn.Linear(d_model, d_k*n_heads)

        self.fc = nn.Linear(d_k*n_heads, d_model)

        self.causal = causal

        if causal:
            cm = torch.tril(torch.ones(max_len, max_len))
            self.register_buffer(
                "causal_mask",
                cm.view(1, 1, max_len, max_len)
            )


    def forward(self, q, k, v, pad_mask=None):
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        N = q.shape[0]
        T_output = q.shape[1]
        T_input = k.shape[1]

        q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)

        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if pad_mask is not None:
            attn_scores = attn_scores.masked_fill(
                pad_mask[:, None, None, :] == 0, float('-inf')
            )

        if self.causal:
            attn_scores = attn_scores.masked_fill(
                self.causal_mask[:, :, :T_output, :T_output] == 0, float('-inf')
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        A = attn_weights @ v

        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T_output, self.d_k * self.n_heads)

        return self.fc(A)

In [3]:
class EncoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.GELU(),
            nn.Linear(d_model*4, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, pad_mask=None):
        x = self.ln1(x + self.mha(x, x, x, pad_mask))
        x = self.ln2(x + self.ann(x))
        x = self.dropout(x)
        return x

In [4]:
class DecoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.mha_1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=True)
        self.mha_2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.GELU(),
            nn.Linear(d_model*4, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.ln1(dec_input + self.mha_1(dec_input, dec_input, dec_input, dec_mask))
        x = self.ln2(x + self.mha_2(x, enc_output, enc_output, enc_mask))
        x = self.ln3(x + self.ann(x))
        x = self.dropout(x)
        return x

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term*(-math.log(10000.0)/ d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [6]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        transformer_blocks = [EncoderBlock(d_k, d_model, n_heads, dropout) for _ in range(n_layers)]
        self.transformer_blocks = nn.Sequential(*transformer_blocks)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask)

        x = self.ln(x)
        return x

In [7]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        transformer_blocks = [DecoderBlock(d_k, d_model, n_heads, max_len, dropout_prob) for _ in range(n_layers)]
        self.transformer_blocks = nn.Sequential(*transformer_blocks)
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.embedding(dec_input)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(enc_output, x, enc_mask, dec_mask)
        x = self.ln(x)
        x = self.fc(x)
        return x

In [8]:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_input, dec_input, enc_mask, dec_mask):
        enc_output = self.encoder(enc_input, enc_mask)
        dec_output = self.decoder(enc_output, dec_input, enc_mask, dec_mask)
        return dec_output

In [9]:
encoder = Encoder(vocab_size=20_000, max_len=512, d_k=16, d_model=64, n_heads=4, n_layers=2, dropout=0.1)
decoder = Decoder(vocab_size=10_000, max_len=512, d_k=16, d_model=64, n_heads=4, n_layers=2, dropout_prob=0.1)
transformer = Transformer(encoder, decoder)

In [10]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(10000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha_1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha_2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)

In [11]:
xe = np.random.randint(0, 20_000, size=(8, 512))
xe_t = torch.tensor(xe).to(device)

xd = np.random.randint(0, 10_000, size=(8, 256))
xd_t = torch.tensor(xd).to(device)

maske = np.ones((8, 512))
maske[:, 256:] = 0
maske_t = torch.tensor(maske).to(device)

maskd = np.ones((8, 256))
maskd[:, 128:] = 0
maskd_t = torch.tensor(maskd).to(device)

out = transformer(xe_t, xd_t, maske_t, maskd_t)
out.shape

torch.Size([8, 256, 10000])

In [12]:
out

tensor([[[ 0.5668, -0.3361,  0.1731,  ..., -0.0596,  0.0472,  0.9232],
         [-0.8454, -0.1654,  0.6588,  ...,  0.2035,  0.9799, -0.4447],
         [-0.1413, -0.4566,  0.3479,  ...,  0.2634, -0.3624,  0.2161],
         ...,
         [ 0.4689,  0.2202, -0.0609,  ..., -0.0480,  0.5865,  0.2795],
         [-1.1109,  0.3770,  0.0130,  ...,  0.5176,  0.6324, -0.2239],
         [-0.2892,  1.1359,  0.8920,  ..., -0.0225,  0.9042, -0.2479]],

        [[-0.0527, -0.1882, -1.0876,  ..., -0.4523,  0.2982, -0.6925],
         [-0.2568, -0.3868,  0.0182,  ..., -1.4608,  0.0201, -0.5967],
         [ 0.0343, -0.9393, -0.0073,  ..., -0.4455,  0.0993, -0.3178],
         ...,
         [-0.5268,  0.7419, -0.0461,  ..., -0.2047,  0.4140,  0.0680],
         [-0.1349, -1.0301, -0.0811,  ...,  0.3478,  0.0983,  0.8592],
         [ 0.2877,  0.0407, -0.4841,  ..., -0.2165,  0.2372,  0.5614]],

        [[-0.0577, -0.5636, -0.3619,  ..., -0.0161, -0.4010,  0.3100],
         [-0.3609, -0.3165, -0.8564,  ..., -0

In [13]:
!wget -nc https://lazyprogrammer.me/course_files/nlp3/spa.txt

--2024-11-13 17:52:00--  https://lazyprogrammer.me/course_files/nlp3/spa.txt
Resolving lazyprogrammer.me (lazyprogrammer.me)... 172.67.213.166, 104.21.23.210, 2606:4700:3030::ac43:d5a6, ...
Connecting to lazyprogrammer.me (lazyprogrammer.me)|172.67.213.166|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/plain]
Saving to: ‘spa.txt’

spa.txt                 [    <=>             ]   7.45M  10.3MB/s    in 0.7s    

2024-11-13 17:52:01 (10.3 MB/s) - ‘spa.txt’ saved [7817148]



In [14]:
import pandas as pd

df = pd.read_csv('spa.txt', sep='\t', header=None)
df.head()

Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Hi.,Hola.
4,Run!,¡Corre!


In [15]:
df.shape

(115245, 2)

In [16]:
df = df.iloc[:30_000]

In [17]:
df.columns = ['en', 'es']
df.to_csv('spa.csv', index=None)

In [19]:
!pip3 install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [20]:
from datasets import load_dataset

In [21]:
raw_dataset = load_dataset('csv', data_files='spa.csv')

Generating train split: 0 examples [00:00, ? examples/s]

In [22]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 30000
    })
})

In [23]:
split = raw_dataset['train'].train_test_split(test_size=0.3, seed=42)
split

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['en', 'es'],
        num_rows: 9000
    })
})

In [24]:
from transformers import AutoTokenizer

model_checkpoint = 'Helsinki-NLP/opus-mt-en-es'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/826k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]



In [25]:
en_sentence = split['train'][0]['en']
es_sentence = split['train'][0]['es']

inputs = tokenizer(en_sentence)
targets = tokenizer(text_target=es_sentence)

tokenizer.convert_ids_to_tokens(targets['input_ids'])

['▁Yo', '▁puedo', '▁arreglarlo', '.', '</s>']

In [26]:
es_sentence

'Yo puedo arreglarlo.'

In [27]:
max_input_length = 128
max_target_length = 128

def process_function(batch):
    model_inputs = tokenizer(batch['en'], max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=batch['es'], max_length=max_target_length, truncation=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [28]:
tokenized_datasets = split.map(
    process_function,
    batched=True,
    remove_columns=split['train'].column_names
)

Map:   0%|          | 0/21000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

In [29]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9000
    })
})

In [30]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer)
batch = data_collator([tokenized_datasets['train'][i] for i in range(0, 5)])
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [31]:
batch['input_ids']

tensor([[   33,    88,  9222,    48,     3,     0, 65000, 65000],
        [  552, 11490,     9,   310,   255,     3,     0, 65000],
        [  143,    31,   125,  1208,     3,     0, 65000, 65000],
        [ 1093,   220,  1890,    23,    48,     3,     0, 65000],
        [  124,    20,   100, 18422,    48,   141,     3,     0]])

In [32]:
tokenizer.all_special_ids

[0, 1, 65000]

In [33]:
tokenizer.all_special_tokens

['</s>', '<unk>', '<pad>']

In [34]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)

valid_loader = DataLoader(
    tokenized_datasets["test"],
    batch_size=32,
    collate_fn=data_collator
)


In [35]:
tokenizer.add_special_tokens({'cls_token':'<s>'})

1

In [67]:
encoder = Encoder(
    vocab_size=tokenizer.vocab_size + 1,
    max_len=512,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1
)

decoder = Decoder(
    vocab_size=tokenizer.vocab_size + 1,
    max_len=512,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1
)

transformer = Transformer(encoder, decoder)

In [68]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(65002, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha_1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha_2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)

In [69]:
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(transformer.parameters())

In [70]:
from datetime import datetime

# A function to encapsulate the training loop
def train(model, criterion, optimizer, train_loader, valid_loader, epochs):
    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)

    for it in range(epochs):
        model.train()
        t0 = datetime.now()
        train_loss = []

        for batch in train_loader:
            # move data to GPU (enc_input, enc_mask, translation)
            batch = {k: v.to(device) for k, v in batch.items()}

            # zero the parameter gradients
            optimizer.zero_grad()

            enc_input = batch['input_ids']
            enc_mask = batch['attention_mask']
            targets = batch['labels']
            # shift targets forwards to get decoder_input
            dec_input = targets.clone().detach()
            dec_input = torch.roll(dec_input, shifts=1, dims=1)
            dec_input[:, 0] = 65_001

            # also convert all -100 to pad token id
            dec_input = dec_input.masked_fill(
                dec_input == -100, tokenizer.pad_token_id
            )

            # make decoder input mask
            dec_mask = torch.ones_like(dec_input)
            dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)

            # Forward pass
            outputs = model(enc_input, dec_input, enc_mask, dec_mask)
            loss = criterion(outputs.transpose(2, 1), targets)

            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Calculate mean train loss for the epoch
        mean_train_loss = np.mean(train_loss)
        train_losses[it] = mean_train_loss

        # Validation phase
        model.eval()
        test_loss = []
        for batch in valid_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            enc_input = batch['input_ids']
            enc_mask = batch['attention_mask']
            targets = batch['labels']

            # shift targets forwards to get decoder_input
            dec_input = targets.clone().detach()
            dec_input = torch.roll(dec_input, shifts=1, dims=1)
            dec_input[:, 0] = 65_001

            # change -100s to regular padding
            dec_input = dec_input.masked_fill(
                dec_input == -100, tokenizer.pad_token_id
            )

            # make decoder input mask
            dec_mask = torch.ones_like(dec_input)
            dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)

            outputs = model(enc_input, dec_input, enc_mask, dec_mask)
            loss = criterion(outputs.transpose(2, 1), targets)
            test_loss.append(loss.item())

        # Calculate mean test loss for the epoch
        mean_test_loss = np.mean(test_loss)
        test_losses[it] = mean_test_loss

        dt = datetime.now() - t0
        print(f'Epoch {it+1}/{epochs}, Train loss: {mean_train_loss:.4f}, Test loss: {mean_test_loss:.4f}, Duration: {dt}')

    return train_losses, test_losses

# Call the train function
train_losses, test_losses = train(transformer, criterion, optimizer, train_loader, valid_loader, epochs=30)


Epoch 1/30, Train loss: 4.8922, Test loss: 3.9284, Duration: 0:00:17.215966
Epoch 2/30, Train loss: 3.5691, Test loss: 3.3859, Duration: 0:00:20.253297
Epoch 3/30, Train loss: 3.0890, Test loss: 3.1017, Duration: 0:00:17.012245
Epoch 4/30, Train loss: 2.7431, Test loss: 2.8691, Duration: 0:00:18.461730
Epoch 5/30, Train loss: 2.4597, Test loss: 2.7211, Duration: 0:00:17.167320
Epoch 6/30, Train loss: 2.2226, Test loss: 2.5932, Duration: 0:00:17.482307
Epoch 7/30, Train loss: 2.0296, Test loss: 2.5069, Duration: 0:00:17.762385
Epoch 8/30, Train loss: 1.8621, Test loss: 2.4480, Duration: 0:00:17.078671
Epoch 9/30, Train loss: 1.7230, Test loss: 2.4124, Duration: 0:00:17.068806
Epoch 10/30, Train loss: 1.6075, Test loss: 2.3949, Duration: 0:00:17.694889
Epoch 11/30, Train loss: 1.5112, Test loss: 2.3883, Duration: 0:00:17.036084
Epoch 12/30, Train loss: 1.4293, Test loss: 2.3799, Duration: 0:00:17.592312
Epoch 13/30, Train loss: 1.3608, Test loss: 2.3825, Duration: 0:00:17.630148
Epoch 14

In [71]:
input_sentence = split['test'][10]['en']
input_sentence

'Can I take a day off?'

In [72]:
enc_input = tokenizer(input_sentence, return_tensors='pt')
enc_input

{'input_ids': tensor([[1283,   33,  273,    8,  502,  843,   21,    0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [73]:
dec_input_str = '<s>'
dec_input = tokenizer(text_target=dec_input_str, return_tensors='pt')
dec_input

{'input_ids': tensor([[65001,     0]]), 'attention_mask': tensor([[1, 1]])}

In [74]:
enc_input.to(device)
dec_input.to(device)
output = transformer(
    enc_input['input_ids'],
    dec_input['input_ids'][:, :-1],
    enc_input['attention_mask'],
    dec_input['attention_mask'][:, :-1]
)

In [75]:
output

tensor([[[-0.1380, -6.9155,  2.0359,  ..., -7.6014, -6.6189, -6.0984]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [76]:
output.shape

torch.Size([1, 1, 65002])

In [77]:
enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])
enc_output.shape

torch.Size([1, 8, 64])

In [78]:
dec_output = decoder(
    enc_output,
    dec_input['input_ids'][:, :-1],
    enc_input['attention_mask'],
    dec_input['attention_mask'][:, :-1],
)

dec_output.shape

torch.Size([1, 1, 65002])

In [79]:
torch.allclose(output, dec_output)

True

In [80]:
dec_input_ids = dec_input['input_ids'][:, :-1]
dec_attn_mask = dec_input['attention_mask'][:, :-1]

for _ in range(32):
  dec_output = decoder(
      enc_output,
      dec_input_ids,
      enc_input['attention_mask'],
      dec_attn_mask
  )

  prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)
  dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))
  dec_attn_mask = torch.ones_like(dec_input_ids)
  if prediction_id == 0:
    break

In [81]:
tokenizer.decode(dec_input_ids[0])

'<s> ¿Puedo tomar un día libre?</s>'

In [82]:
split['test'][10]['es']

'¿Puedo tomarme un día libre?'

In [83]:
def translate(input_sentence):
    # Get encoder output first
    enc_input = tokenizer(input_sentence, return_tensors='pt').to(device)
    enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])

    # Setup initial decoder input
    dec_input_ids = torch.tensor([[65001]], device=device)  # Fixed the underscore here
    dec_attn_mask = torch.ones_like(dec_input_ids, device=device)

    # Now do the decoder loop
    for _ in range(32):  # Add a variable name for the loop
        dec_output = decoder(enc_output, dec_input_ids, enc_input['attention_mask'], dec_attn_mask)  # Correct function parameters here
        prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)
        dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))
        dec_attn_mask = torch.ones_like (dec_input_ids)

        if prediction_id == 0:
            break

    translation = tokenizer.decode(dec_input_ids[0, 1:])
    print(translation)

In [84]:
translate('how are you')

Los pescado eres.</s>
