## Imports

In [21]:
from src.tokenizer import Tokenizer, normalize_text
import random
import numpy as np
import pytest

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import evaluate

import matplotlib.pyplot as plt
import jupyter_black

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

jupyter_black.load()
%matplotlib inline

## Dataset

In [42]:
src_tok: Tokenizer = torch.load("data/src_tok.pt")
tgt_tok: Tokenizer = torch.load("data/tgt_tok.pt")
vocab_size = src_tok.vocab_size
print(f"Vocab size: {vocab_size}")

Vocab size: 512


In [11]:
train_dl: DataLoader = torch.load("data/train_dl.pt")
val_dl: DataLoader = torch.load("data/val_dl.pt")
test_dl: DataLoader = torch.load("data/test_dl.pt")
tiny_train_dl: DataLoader = torch.load("data/tiny_train_dl.pt")

## Model Architecture

### Encoder

#### Architecture

In [15]:
class s2sEncoder(nn.Module):
    def __init__(
        self,
        input_size: int,
        emb_size: int,
        hidden_size: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.emb = nn.Embedding(input_size, emb_size)  # (D, V)
        self.rnn = nn.GRU(
            emb_size, hidden_size, num_layers=num_layers, batch_first=True
        )

    def forward(self, X):
        """
        :param X: (N, T) where N is the batch size and T is the sequence length
        :return: (N, T, H) full hidden state tensor where H is the hidden size,
        and (L, N, H) final hidden state tensors where L is the number of layers
        """
        emb = self.emb(X)  # (N, T, D)
        out, hidden = self.rnn(emb)  # (N, T, H), (L, N, H)
        return out, hidden

#### Testing the Encoder

In [24]:
D = H = 10  # Embedding and hidden state dimensions
L = 1  # Number of GRU layers
X, _ = next(iter(tiny_train_dl))  # Unpack just the first source batch (N, T)

"""Testing the Encoder"""

enc = s2sEncoder(
    input_size=vocab_size,
    emb_size=D,
    hidden_size=H,
    num_layers=L,
).to(device)
out, hidden = enc(X)

print(out.shape, out.shape == (*X.shape, H))  # (N, T, H)
print(hidden.shape, hidden.shape == (L, X.shape[0], H))  # (L, N, H)

torch.Size([5, 12, 10]) True
torch.Size([1, 5, 10]) True


### Decoder

#### Arhitecture

In [30]:
class s2sDecoder(nn.Module):
    def __init__(
        self,
        output_size: int,
        emb_size: int,
        hidden_size: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.output_size = output_size
        self.emb = nn.Embedding(output_size, emb_size)
        self.rnn = nn.GRU(
            emb_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.lin = nn.Linear(hidden_size, output_size)

    def forward(self, X, hidden):
        """
        :param X: (N, T) where N is the batch size and T is the sequence length
        :param hidden: final encoder hidden state (L, N, H) where L is the number of layers,
        N is the batch size, and H is the hidden size
        :return: (N, T, V) where V is the output size
        """
        emb = self.emb(X)  # (N, T, D)
        out, hidden = self.rnn(emb, hidden)  # (N, T, H), (L, N, H)
        out = self.lin(out)  # (N, T, V)
        return out, hidden

#### Testing the Decoder

In [31]:
D = H = 10  # Embedding and hidden state dimensions
L = 1  # Number of GRU layers
X, Y = next(iter(tiny_train_dl))  # Unpack just the first source batch (N, T)

"""Testing the Encoder"""

enc = s2sEncoder(
    input_size=vocab_size,
    emb_size=D,
    hidden_size=H,
    num_layers=L,
).to(device)
out, hidden = enc(X)

print(out.shape, out.shape == (*X.shape, H))  # (N, T, H)
print(hidden.shape, hidden.shape == (L, X.shape[0], H))  # (L, N, H)

"""Testing the Decoder"""

dec = s2sDecoder(
    output_size=vocab_size,
    emb_size=D,
    hidden_size=H,
    num_layers=L,
).to(device)
out, hidden = dec(Y, hidden)

print(out.shape, out.shape == (*Y.shape, vocab_size))  # (N, T, V)
print(hidden.shape, hidden.shape == (L, Y.shape[0], H))  # (L, N, H)

torch.Size([5, 12, 10]) True
torch.Size([1, 5, 10]) True
torch.Size([5, 12, 512]) True
torch.Size([1, 5, 10]) True


### Seq2Seq Network

#### Architecture

In [142]:
class Seq2Seq(nn.Module):
    def __init__(
        self,
        encoder: s2sEncoder,
        decoder: s2sDecoder,
    ):
        super().__init__()
        self.enc = encoder
        self.dec = decoder
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        # Initialize uniform weights between -0.08 and 0.08
        # for the model
        for _, param in module.named_parameters():
            nn.init.uniform_(param.data, -0.08, 0.08)

    def forward(
        self,
        source: torch.tensor,
        target: torch.tensor,
        teacher_force_ratio: float = 0.0,
    ):
        """
        :param source: (N, T) input tensor where N is the batch size and T is the sequence length
        :param target: (N, T) target tensor
        :param teacher_force_ratio: float ratio of teacher forcing (0.0 to 1.0)
        :return: (N, T, V) where V is the output size
        """
        N, T = source.shape
        V = self.dec.output_size
        # Encoder step
        _, hidden = self.enc.forward(source)  # (N, T, H), (L, N, H)

        # Decoder step
        outputs = torch.zeros(N, T, V)  # (N, T, V)
        target_t = target[:, :1]  # (N, 1) initial decoder input token

        # We loop here as to let the function decide which input to use in each proceeding
        # RNN cell
        for t in range(1, T):
            dec_out, hidden = self.dec.forward(target_t, hidden)  # (N, 1, V), (L, N, H)
            # Set decoder output into total outputs
            outputs[:, t : t + 1] = dec_out  # (N, 1, V) -> (N, T, V)

            # Set up next input to decoder
            # If teacher_force_ratio is 0.0, then we use the decoder output as the next input
            # If teacher_force_ratio is 1.0, then we use the target as the next input
            teacher_force = random.random() < teacher_force_ratio
            target_t = target[:, t : t + 1] if teacher_force else dec_out.argmax(-1)

        return outputs

#### Testing the Seq2Seq Network

In [143]:
D = H = 10  # Embedding and hidden state dimensions
L = 1  # Number of GRU layers
X, Y = next(iter(tiny_train_dl))  # Unpack just the first source batch (N, T)

"""Testing the Seq2Seq model"""

enc = s2sEncoder(
    input_size=vocab_size,
    emb_size=D,
    hidden_size=H,
    num_layers=L,
)

dec = s2sDecoder(
    output_size=vocab_size,
    emb_size=D,
    hidden_size=H,
    num_layers=L,
)

model = Seq2Seq(enc, dec).to(device)
out = model(X, Y)
print(out.shape, out.shape == (*Y.shape, vocab_size))  # (N, T, V)

torch.Size([5, 12, 512]) True


## Training the Seq2Seq Network

In [144]:
def model_forward(
    model,
    source,
    target,
    loss_fn,
    teacher_force_ratio,
):
    # Forward pass - grab the logits that we'll map
    # to probabilities in the loss calculation
    logits = model.forward(
        source=source,
        target=target,
        teacher_force_ratio=teacher_force_ratio,
    )  # (N, T, V)
    _, _, V = logits.shape
    # Fit the logits into 2 dimensions
    logits = logits[:, 1:].reshape(-1, V).to(device)  # (N*(T-1), V)
    target = target[:, 1:].reshape(-1)  # (N*(T-1),)

    # Loss calculation
    loss = loss_fn(logits, target)
    return loss

In [145]:
def train_epoch(
    model,
    data_loader,
    loss_fn,
    optim,
    teacher_force_ratio: float = 0.5,
):
    # Iterate through one epoch-worth of data
    model.train()
    epoch_loss = 0
    # Iterate through the data loader
    for it, batch in enumerate(data_loader):
        optim.zero_grad()

        # Unpack the data loader
        # into source and target sequences
        xb, yb = batch  # (N, T), (N, T)

        # Forward pass - grab the logits that we'll map
        # to probabilities in the loss calculation
        loss = model_forward(model, xb, yb, loss_fn, teacher_force_ratio)
        epoch_loss += loss.item()

        # Backward pass
        loss.backward()

        # Optimization step
        optim.step()

    return epoch_loss / len(data_loader)

In [146]:
@torch.no_grad()
def evaluate_epoch(
    model,
    data_loader,
    loss_fn,
):
    model.eval()
    epoch_loss = 0
    # Iterate through all data in the data loader
    for batch in data_loader:
        # Unpack the data loader
        xb, yb = batch

        # Forward pass
        loss = model_forward(model, xb, yb, loss_fn, teacher_force_ratio=0.0)
        epoch_loss += loss.item()

    return epoch_loss / len(data_loader)

In [147]:
D, H, L = 256, 512, 2
lr = 1e-3

enc = s2sEncoder(vocab_size, D, H, L)
dec = s2sDecoder(vocab_size, D, H, L)
model = Seq2Seq(enc, dec).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_tok.wtoi[tgt_tok.pad_token])
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
teacher_force_ratio = 0.5

epochs = 20
best_valid_loss = float("inf")
for epoch in tqdm(range(epochs), desc="Epochs"):
    train_loss = train_epoch(
        model=model,
        data_loader=train_dl,
        optim=optimizer,
        loss_fn=loss_fn,
        teacher_force_ratio=teacher_force_ratio,
    )
    val_loss = evaluate_epoch(
        model=model,
        data_loader=val_dl,
        loss_fn=loss_fn,
    )
    if val_loss < best_valid_loss:
        torch.save(model.state_dict(), "best-model.pt")
        best_valid_loss = val_loss
    print(
        f"({epoch+1}/{epochs})\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}",
        end="",
    )
    print(f"\tValid Loss: {val_loss:7.3f} | Valid PPL: {np.exp(val_loss):7.3f}")

Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

(1/20)	Train Loss:   3.446 | Train PPL:  31.382	Valid Loss:   2.996 | Valid PPL:  19.999
(2/20)	Train Loss:   2.716 | Train PPL:  15.116	Valid Loss:   2.674 | Valid PPL:  14.505
(3/20)	Train Loss:   2.439 | Train PPL:  11.463	Valid Loss:   2.667 | Valid PPL:  14.390
(4/20)	Train Loss:   2.225 | Train PPL:   9.256	Valid Loss:   2.469 | Valid PPL:  11.811
(5/20)	Train Loss:   2.024 | Train PPL:   7.572	Valid Loss:   2.407 | Valid PPL:  11.100
(6/20)	Train Loss:   1.853 | Train PPL:   6.381	Valid Loss:   2.275 | Valid PPL:   9.726
(7/20)	Train Loss:   1.660 | Train PPL:   5.261	Valid Loss:   2.207 | Valid PPL:   9.091
(8/20)	Train Loss:   1.439 | Train PPL:   4.218	Valid Loss:   2.230 | Valid PPL:   9.298
(9/20)	Train Loss:   1.273 | Train PPL:   3.572	Valid Loss:   2.233 | Valid PPL:   9.328
(10/20)	Train Loss:   1.082 | Train PPL:   2.951	Valid Loss:   2.217 | Valid PPL:   9.180
(11/20)	Train Loss:   1.003 | Train PPL:   2.726	Valid Loss:   2.178 | Valid PPL:   8.827
(12/20)	Train Loss:

In [148]:
model.load_state_dict(torch.load("best-model.pt"))
test_loss = evaluate_epoch(model, test_dl, loss_fn)
print(f"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |")

| Test Loss: 2.240 | Test PPL:   9.393 |


In [149]:
@torch.no_grad()
def translate_sentence(
    sentence,
    model,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    sos_token: str = "<SOS>",
    eos_token: str = "<EOS>",
    max_output_length: int = 25,
):
    """
    sentence: (T,)
    """
    model.eval()
    sentence = sentence.unsqueeze(0).to(device)  # (1, T)
    _, hidden = model.enc(sentence)  # (1, T, D)

    X = torch.tensor(
        [tgt_tokenizer.wtoi[sos_token]], dtype=torch.long, device=device
    ).reshape(1, -1)

    for i in range(max_output_length):
        dec_out, hidden = model.dec(X, hidden)  # (N, T, V)
        logits = dec_out[:, -1]  # (N, V)

        pred_token = logits.argmax(-1).reshape(1, -1)
        X = torch.cat((X, pred_token), dim=1)  # (N, T+1)

        if pred_token.item() == tgt_tokenizer.wtoi[eos_token]:
            break

    tokens = tgt_tok.untokenize(X.squeeze(0).tolist())

    return tokens

In [150]:
X, Y = next(iter(train_dl))

In [151]:
X.shape

torch.Size([25, 12])

In [152]:
ix = 21
src, tgt = X[ix], Y[ix]
print(src_tok.untokenize(src.tolist()))
print(tgt_tok.untokenize(tgt.tolist()))

<SOS> ils sont trs <UNK> <EOS>
<SOS> they are very kind <EOS>


In [153]:
translation = translate_sentence(
    src.to(device),
    model,
    src_tok,
    tgt_tok,
    device,
    max_output_length=tgt_tok.max_length,
)
translation

'<SOS> they very <UNK> <EOS>'

In [154]:
Xte = torch.tensor([]).type(torch.long).to(device)
Yte = torch.tensor([]).to(device)
for data in test_dl:
    Xte = torch.cat((Xte, data[0]), dim=0)
    Yte = torch.cat((Yte, data[1]), dim=0)
Xte.shape, Yte.shape

(torch.Size([332, 12]), torch.Size([332, 12]))

In [155]:
translations = [
    translate_sentence(
        src,
        model,
        src_tok,
        tgt_tok,
        device,
        max_output_length=tgt_tok.max_length,
    )
    for src in tqdm(Xte)
]

  0%|          | 0/332 [00:00<?, ?it/s]

In [156]:
preds = [" ".join(t.split()[1:-1]) for t in translations]
targets = [[" ".join(tgt_tok.untokenize(t.tolist()).split()[1:-1])] for t in Yte]

In [157]:
ix = -1
preds[ix], targets[ix]

('he the <UNK> <UNK>', ['he is very <UNK> about his <UNK>'])

In [135]:
bleu = evaluate.load("bleu")

In [158]:
results = bleu.compute(
    predictions=preds, references=targets, tokenizer=lambda x: x.split()
)
results

{'bleu': 0.03194757779682639,
 'precisions': [0.626362735381566,
  0.08862629246676514,
  0.025936599423631124,
  0.026785714285714284],
 'brevity_penalty': 0.4054028021991379,
 'length_ratio': 0.5255208333333333,
 'translation_length': 1009,
 'reference_length': 1920}