In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [2]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
    super().__init__()

    # Assume d_v = d_k

    self.d_k = d_k
    self.n_heads = n_heads

    self.key = nn.Linear(d_model, d_k * n_heads)
    self.query = nn.Linear(d_model, d_k*n_heads)
    self.value = nn.Linear(d_model, d_k*n_heads)

    #final layer
    self.fc = nn.Linear(d_k*n_heads, d_model)

    #causal mask
    self.causal = causal
    if causal:

      cm = torch.tril(torch.ones(max_len, max_len))
      self.register_buffer(
          "causal_mask",
          cm.view(1,1, max_len, max_len)
      )

  def forward(self, q, k, v, pad_mask=None):
    q = self.query(q) # N x T x (hd_k)
    k = self.key(k) # N x T x (hd_k)
    v = self.value(v) # N x T x (hd_v)

    N = q.shape[0]
    T_output = q.shape[1]
    T_input = k.shape[1]


    #change the sahpe to:
    # (N, T, h, d_k) -> (N, h, T, d_k)
    q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1,2)
    k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1,2)
    v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1,2)

    #compute attention weights
    # (N, h, T, d_k) X (N, h, d_k, T) --> (N, h, T, T)

    attn_scores = q @ k.transpose(-2,-1) / math.sqrt(self.d_k)
    if pad_mask is not None:
      attn_scores = attn_scores.masked_fill(
          pad_mask[:, None, None, :] == 0, float("-inf")
      )
    if self.causal:
      attn_scores = attn_scores.masked_fill(
          self.causal_mask[:, :, :T_output, :T_input] == 0, float("-inf")
      )
    attn_weights = F.softmax(attn_scores, dim=-1)


    A = attn_weights @ v

    A = A.transpose(1,2)
    A = A.contiguous().view(N, T_output, self.d_k * self.n_heads)

    return self.fc(A)

In [3]:
class EncoderBlock(nn.Module):
  def  __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
    super().__init__()


    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
    self.ann = nn.Sequential(
        nn.Linear(d_model, d_model * 4),
        nn.GELU(),
        nn.Linear(d_model *4, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, pad_mask=None):
    x = self.ln1(x + self.mha(x, x, x, pad_mask))
    x = self.ln2(x+self.ann(x))
    x = self.dropout(x)
    return x

In [4]:
# 2 multi head attentions and 1 forward network
class DecoderBlock(nn.Module):
  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.ln3 = nn.LayerNorm(d_model)

    #causal masking allows

    #first multihead attention in decoder with causal masking
    self.mha1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=True)
    #second multihead attention in decoder with no causal masking
    self.mha2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
    self.ann = nn.Sequential(
        nn.Linear(d_model, d_model * 4),
        nn.GELU(),
        nn.Linear(d_model * 4, d_model),
        nn. Dropout(dropout_prob),
    )
    self.dropout =nn.Dropout(p=dropout_prob)

  def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
    #selt-attention on decoder input
    x = self.ln1(
        dec_input + self.mha1(dec_input, dec_input, dec_input, dec_mask))
    # multi-head attention including encoder output
    x = self.ln2(x + self.mha2(x, enc_output, enc_output, enc_mask))

    x = self.ln3(x + self.ann(x))
    x = self.dropout(x)
    return x

In [5]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)
    exp_term = torch.arange(0, d_model, 2)
    div_term = torch.exp(exp_term * (-math.log(1000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position *div_term)

    self.register_buffer('pe', pe)

  def forward(self, x):
    # x.shape N x T x D
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)


In [6]:
class Encoder(nn.Module):
  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    transformer_blocks = [
        EncoderBlock(
            d_k,
            d_model,
            n_heads,
            max_len,
            dropout_prob) for _ in range(n_layers)]
    self.transformer_blocks = nn.Sequential(*transformer_blocks)
    self.ln = nn.LayerNorm(d_model)

  def forward(self, x, pad_mask=None):
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for block in self.transformer_blocks:
      x = block(x, pad_mask)

    # many-to-one (x has the shape N x T x D)
    # x = x[:, 0, :]

    x = self.ln(x)
    # x = self.fc(x)

    return x

In [7]:
class Decoder(nn.Module):
  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    transformer_blocks = [
        DecoderBlock(
            d_k,
            d_model,
            n_heads,
            max_len,
            dropout_prob) for _ in range(n_layers)
    ]
    self.transformer_blocks = nn.Sequential(*transformer_blocks)
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)


  def forward(self, enc_output, dec_input, enc_mask= None, dec_mask=None):
    x = self.embedding(dec_input)
    x = self.pos_encoding(x)
    for block in self.transformer_blocks:
      x = block(enc_output, x, enc_mask, dec_mask)
    x = self.ln(x)
    x = self.fc(x) #many-to-many
    return x

In [8]:
class Transformer(nn.Module):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_input, dec_input, enc_mask, dec_mask):
    enc_output = self.encoder(enc_input, enc_mask)
    dec_output = self.decoder(enc_output, dec_input, enc_mask, dec_mask)
    return dec_output

In [9]:
# test it
encoder = Encoder(vocab_size = 20_0000,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers = 2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size=10_000,
                  max_len = 512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print(device)
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(10000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [11]:
xe = np.random.randint(0, 20_000, size=(8, 512))
xe_t = torch.tensor(xe).to(device)

In [12]:
xd = np.random.randint(0, 10_000, size=(8,256))
xd_t = torch.tensor(xd).to(device)

maske = np.ones((8, 512))
maske[:, 256:] = 0
maske_t = torch.tensor(maske).to(device)

maskd = np.ones((8, 256))
maskd[:, 128:] = 0
maskd_t = torch.tensor(maskd).to(device)

out = transformer(xe_t, xd_t, maske_t, maskd_t)
out.shape

torch.Size([8, 256, 10000])

In [13]:
out

tensor([[[ 1.0296e+00,  2.5310e-01, -2.9792e-01,  ..., -1.8506e-01,
          -1.0349e-01,  8.8766e-01],
         [ 3.4757e-01, -8.2178e-02, -4.6838e-01,  ..., -1.2514e+00,
          -5.0817e-01,  4.0420e-01],
         [ 7.6369e-01,  4.0040e-01,  1.1665e-01,  ..., -1.1132e+00,
          -1.5190e+00, -2.4238e-01],
         ...,
         [ 1.7020e-01,  5.2279e-01,  1.6845e-01,  ..., -1.9008e-01,
           5.2193e-01,  7.4153e-01],
         [-7.1098e-01,  3.9611e-01,  1.1081e+00,  ..., -6.7240e-01,
           1.4406e-01,  6.6201e-01],
         [-3.0525e-01,  1.7331e-01,  8.4595e-02,  ...,  3.5961e-01,
          -2.7954e-01, -2.5132e-01]],

        [[ 1.2669e+00,  1.8146e-01,  1.9970e-01,  ..., -8.6964e-01,
           1.6903e-03,  7.0043e-01],
         [-2.6934e-02,  2.7283e-01,  3.0136e-01,  ..., -1.4598e-01,
          -2.9583e-01,  3.2296e-01],
         [ 5.4447e-01,  1.8635e-01,  5.4154e-01,  ..., -7.7369e-01,
          -9.9691e-01,  2.5795e-01],
         ...,
         [-7.8496e-01,  2

In [14]:
!wget -nc https://lazyprogrammer.me/course_files/nlp3/spa.txt

--2023-11-23 12:22:22--  https://lazyprogrammer.me/course_files/nlp3/spa.txt
Resolving lazyprogrammer.me (lazyprogrammer.me)... 104.21.23.210, 172.67.213.166, 2606:4700:3030::ac43:d5a6, ...
Connecting to lazyprogrammer.me (lazyprogrammer.me)|104.21.23.210|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/plain]
Saving to: ‘spa.txt’

spa.txt                 [  <=>               ]   7.45M  19.1MB/s    in 0.4s    

2023-11-23 12:22:23 (19.1 MB/s) - ‘spa.txt’ saved [7817148]



In [15]:
!head spa.txt

Go.	Ve.
Go.	Vete.
Go.	Vaya.
Hi.	Hola.
Run!	¡Corre!
Who?	¿Quién?
Wow!	¡Órale!
Fire!	¡Fuego!
Fire!	¡Incendio!
Fire!	¡Disparad!


In [16]:
import pandas as pd
df = pd.read_csv('spa.txt', sep='\t', header=None)
df.head()

Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Hi.,Hola.
4,Run!,¡Corre!


In [17]:
df.shape

(115245, 2)

In [18]:
df = df.iloc[:30_000] #takes too long

In [19]:
df.columns = ["en", "es"]
df.to_csv("spa.csv", index=None)

In [20]:
!head spa.csv

en,es
Go.,Ve.
Go.,Vete.
Go.,Vaya.
Hi.,Hola.
Run!,¡Corre!
Who?,¿Quién?
Wow!,¡Órale!
Fire!,¡Fuego!
Fire!,¡Incendio!


In [21]:
!pip install transformers datasets sentencepiece sacremoses

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiproc

In [22]:
from datasets import load_dataset
raw_dataset = load_dataset("csv", data_files="spa.csv")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [23]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 30000
    })
})

In [24]:
split = raw_dataset["train"].train_test_split(test_size=0.3, seed=42)
split

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['en', 'es'],
        num_rows: 9000
    })
})

In [25]:
from transformers import AutoTokenizer

In [26]:
model_checkpoint = 'Helsinki-NLP/opus-mt-en-es'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/826k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

In [27]:
en_sentence = split["train"][0]["en"]
es_sentence = split["train"][0]["es"]

In [28]:
inputs = tokenizer(en_sentence)
targets = tokenizer(text_target=es_sentence)

tokenizer.convert_ids_to_tokens(targets["input_ids"])

['▁Yo', '▁puedo', '▁arreglarlo', '.', '</s>']

In [29]:
es_sentence

'Yo puedo arreglarlo.'

In [30]:
#sequence length
max_input_length = 128
max_target_length = 128

def preprocess_function(batch):
  model_inputs = tokenizer(
      batch["en"], max_length=max_input_length, truncation=True
  )

  # Set up the tokenizer for targets
  labels = tokenizer(
      text_target=batch["es"], max_length = max_target_length, truncation=True
  )

  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [31]:
tokenized_datasets = split.map(
    preprocess_function,
    batched=True,
    remove_columns=split["train"].column_names,
)

Map:   0%|          | 0/21000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

In [32]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9000
    })
})

In [33]:
from transformers import DataCollatorForSeq2Seq

In [34]:
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [35]:
batch = data_collator([tokenized_datasets["train"][i] for i in range(0,5)])
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [36]:
batch["input_ids"]

tensor([[   33,    88,  9222,    48,     3,     0, 65000, 65000],
        [  552, 11490,     9,   310,   255,     3,     0, 65000],
        [  143,    31,   125,  1208,     3,     0, 65000, 65000],
        [ 1093,   220,  1890,    23,    48,     3,     0, 65000],
        [  124,    20,   100, 18422,    48,   141,     3,     0]])

In [37]:
batch["attention_mask"]

tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])

In [38]:
batch["labels"]

tensor([[  711,  1039, 44159,     3,     0,  -100,  -100,  -100],
        [ 2722, 18663,   239,   212,     3,     0,  -100,  -100],
        [  539,    43,   155,   960,     3,     0,  -100,  -100],
        [15165,  1250,   380,  3564,    36,  1016,     3,     0],
        [  350,     8, 19153,    29, 31326,     3,     0,  -100]])

In [39]:
#what each tokens does
tokenizer.all_special_ids

[0, 1, 65000]

In [40]:
#what are these special tokens
tokenizer.all_special_tokens

['</s>', '<unk>', '<pad>']

In [41]:
tokenizer("<pad>")

{'input_ids': [65000, 0], 'attention_mask': [1, 1]}

In [42]:
from torch.utils.data import DataLoader

In [43]:
train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)

valid_loader = DataLoader(
    tokenized_datasets["test"],
    batch_size = 32,
    collate_fn = data_collator
)

In [44]:
for batch in train_loader:
  for k, v in batch.items():
    print("k:", k, "v.shape", v.shape)
  break

k: input_ids v.shape torch.Size([32, 9])
k: attention_mask v.shape torch.Size([32, 9])
k: labels v.shape torch.Size([32, 10])


In [45]:
tokenizer.vocab_size

65001

In [46]:
tokenizer.decode([60020])

'دونم'

In [47]:
tokenizer.add_special_tokens({"cls_token": "<s>"})

1

In [48]:
tokenizer("<s>")

{'input_ids': [65001, 0], 'attention_mask': [1, 1]}

In [49]:
tokenizer.vocab_size

65001

In [50]:
encoder = Encoder(vocab_size = tokenizer.vocab_size + 1,
                  max_len = 512,
                  d_k=16,
                  d_model=64,
                  n_heads = 4,
                  n_layers=2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size=tokenizer.vocab_size + 1,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)

In [51]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(65002, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [52]:
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(transformer.parameters())

In [53]:
from datetime import datetime
# A function to encapsulate th training loop
def train(model, criterion, optimizer, train_loader, valid_loader, epochs):
  train_losses = np.zeros(epochs)
  test_losses = np.zeros(epochs)

  for it in range(epochs):
    model.train()
    t0 = datetime.now()
    train_loss = []
    for batch in train_loader:
      #move data to GPU (enc_input, enc_mask, translation)
      batch = {k: v.to(device) for k,v in batch.items()}

      #zero the parameter gradients
      optimizer.zero_grad()

      #encoder inputs and masking
      enc_input = batch["input_ids"]
      enc_mask = batch["attention_mask"]

      #decoder target
      targets = batch["labels"]

      #shift targets forwards to det decoder_input
      #this is the opposite of before since we are getting the input not target
      dec_input = targets.clone().detach()
      dec_input = torch.roll(dec_input, shifts=1, dims=1)
      dec_input[:, 0] = 65_001

      #also convert all -100 to pad token id
      dec_input = dec_input.masked_fill(
          dec_input == -100, tokenizer.pad_token_id
      )

      #make decoder input mask
      dec_mask = torch.ones_like(dec_input)
      dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)

      #forward pass
      outputs = model(enc_input, dec_input, enc_mask, dec_mask) #transformer model
      loss = criterion(outputs.transpose(2,1), targets)

      #backward and optimize
      loss.backward()
      optimizer.step()
      train_loss.append(loss.item())

    # Get  train loss and test loss
    train_loss = np.mean(train_loss)

    model.eval()
    test_loss = []
    for batch in valid_loader:
      batch = {k: v.to(device) for k,v in batch.items()}

      enc_input = batch["input_ids"]
      enc_mask = batch["attention_mask"]
      targets = batch["labels"]

      #shift targets forwards to get decoder_input
      dec_input = targets.clone().detach()
      dec_input = torch.roll(dec_input, shifts=1, dims=1)
      dec_input[:, 0] = 65_001

      #also convert all -100 to pad token id
      dec_input = dec_input.masked_fill(
          dec_input == -100, tokenizer.pad_token_id
      )

      dec_mask = torch.ones_like(dec_input)
      dec_mask = dec_mask.masked_fill(dec_input == tokenizer.pad_token_id, 0)

      outputs = model(enc_input, dec_input, enc_mask, dec_mask)
      loss = criterion(outputs.transpose(2,1), targets)
      test_loss.append(loss.item())
    test_loss = np.mean(test_loss)

    #save losses
    train_losses[it] = train_loss
    test_losses[it]= test_loss

    dt = datetime.now() - t0
    print(f"Epoch {it+1}/{epochs}, Train Loss: {train_loss: .4f}, \
    Test Loss: {test_loss:.4f}, Duration {dt}")
  return train_losses, test_losses

In [71]:
train_losses, test_losses = train(
    transformer, criterion, optimizer, train_loader, valid_loader, epochs=50
)

Epoch 1/50, Train Loss:  1.0272,     Test Loss: 2.4160, Duration 0:00:19.807527
Epoch 2/50, Train Loss:  1.0095,     Test Loss: 2.4347, Duration 0:00:17.546426
Epoch 3/50, Train Loss:  0.9823,     Test Loss: 2.4360, Duration 0:00:18.635150
Epoch 4/50, Train Loss:  0.9610,     Test Loss: 2.4480, Duration 0:00:17.746363
Epoch 5/50, Train Loss:  0.9360,     Test Loss: 2.4376, Duration 0:00:17.891666
Epoch 6/50, Train Loss:  0.9278,     Test Loss: 2.4349, Duration 0:00:18.392787
Epoch 7/50, Train Loss:  0.9071,     Test Loss: 2.4456, Duration 0:00:18.119882
Epoch 8/50, Train Loss:  0.8853,     Test Loss: 2.4616, Duration 0:00:18.721745
Epoch 9/50, Train Loss:  0.8641,     Test Loss: 2.4617, Duration 0:00:17.940765
Epoch 10/50, Train Loss:  0.8596,     Test Loss: 2.4574, Duration 0:00:18.557909
Epoch 11/50, Train Loss:  0.8425,     Test Loss: 2.4748, Duration 0:00:17.770093
Epoch 12/50, Train Loss:  0.8214,     Test Loss: 2.4812, Duration 0:00:17.696917
Epoch 13/50, Train Loss:  0.8139,    

In [82]:
# try it out

input_sentence = split["test"][10]["en"]
input_sentence

'Can I take a day off?'

In [83]:
enc_input = tokenizer(input_sentence, return_tensors='pt')
enc_input

{'input_ids': tensor([[1283,   33,  273,    8,  502,  843,   21,    0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [84]:
dec_input_str = "<s>"
dec_input = tokenizer(text_target=dec_input_str, return_tensors='pt')
dec_input

{'input_ids': tensor([[65001,     0]]), 'attention_mask': tensor([[1, 1]])}

In [85]:
enc_input.to(device)
dec_input.to(device)
output = transformer(
    enc_input["input_ids"],
    dec_input["input_ids"][:, :-1],
    enc_input["attention_mask"],
    dec_input["attention_mask"][:,:-1]
)

output

tensor([[[  6.2134, -15.4073,   1.1439,  ..., -16.0116, -14.1494, -13.0643]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [86]:
dec_input["input_ids"][:, :-1]

tensor([[65001]], device='cuda:0')

In [87]:
output.shape # N x T x V

torch.Size([1, 1, 65002])

In [88]:
enc_output = encoder(enc_input["input_ids"], enc_input["attention_mask"])
enc_output.shape

torch.Size([1, 8, 64])

In [89]:
dec_output = decoder(
    enc_output,
    dec_input["input_ids"][:, :-1],
    enc_input["attention_mask"],
    dec_input["attention_mask"][:, :-1]
)
dec_output.shape

torch.Size([1, 1, 65002])

In [90]:
torch.allclose(output, dec_output)

True

In [91]:
dec_input_ids = dec_input["input_ids"][:, :-1]
dec_attn_mask = dec_input["attention_mask"][:,:-1] #we remove the end sentence tokenizer

for _ in range(32):
  dec_output = decoder(
      enc_output,
      dec_input_ids,
      enc_input["attention_mask"],
      dec_attn_mask,
  )

  #choose the best value
  prediction_id = torch.argmax(dec_output[:, - 1, :], axis=-1)

  # append to decoder input
  dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1,1)))

  #recreate mask
  dec_attn_mask = torch.ones_like(dec_input_ids)

  # exit when reach </s>
  if prediction_id ==0:
    break

In [92]:
tokenizer.decode(dec_input_ids[0])

'<s> ¿Puedo hacer un día libre?</s>'

In [93]:
split["test"][10]['es']

'¿Puedo tomarme un día libre?'

In [94]:
def translate(input_sentence):
  # get encoder output first
  enc_input = tokenizer(input_sentence, return_tensors='pt').to(device)
  enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])

  # setup initial decoder input
  dec_input_ids = torch.tensor([[65_001]], device=device)
  dec_attn_mask = torch.ones_like(dec_input_ids, device=device)

  # now do the decoder loop
  for _ in range(32):
    dec_output = decoder(
        enc_output,
        dec_input_ids,
        enc_input['attention_mask'],
        dec_attn_mask,
    )

    # choose the best value (or sample)
    prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)

    # append to decoder input
    dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))

    # recreate mask
    dec_attn_mask = torch.ones_like(dec_input_ids)

    # exit when reach </s>
    if prediction_id == 0:
      break

  translation = tokenizer.decode(dec_input_ids[0, 1:])
  print(translation + "\n")

In [118]:
for _ in range(3):
  sentence = input(" ")
  translate(sentence)

 do you like sports?
¿Te gusta el deporte?</s>

 come over here.
Ven acá.</s>

 go to your room now!
¡Vete a tu cuarto ahora mismo!</s>

