# Install & Imports

In [None]:
%%capture
!pip install torchtext --upgrade
!pip install sentencepiece
!pip install transformers[sentencepiece] datasets

In [None]:
import math
import random
import numpy as np
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchtext
from torchtext.data.metrics import bleu_score

import warnings
from tqdm.notebook import tqdm
warnings.filterwarnings('ignore')

In [None]:
# plot style configuration
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
rcParams['figure.figsize'] = 12, 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())

# Data loading & preprocessing

In [None]:
# English-French sentence pairs from OPUS Books; each row carries a
# "translation" dict with "en" and "fr" entries (read in preprocess_function)
from datasets import load_dataset
raw_datasets = load_dataset("opus_books", lang1="en", lang2="fr")
# keep only the first 500 pairs so the demo trains quickly
raw_datasets["train"] = raw_datasets["train"].select(range(500))
# 90% / 10% split with a fixed seed; the result is indexed as "train" / "test" below
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=20)

# Pretrained T5 tokenizer with '<s>' registered as the BOS token; the BOS id is
# prepended to the labels by hand in preprocess_function.
# NOTE(review): `return_tensors` is normally a per-call tokenization argument,
# not a from_pretrained option, and `add_bos_token` may not be honoured by the
# T5 tokenizer -- both look like no-ops here; confirm before relying on them.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-base",
                                          bos_token='<s>',
                                          add_bos_token=True,
                                          return_tensors="pt")
# every source and target sequence is truncated / padded to this many tokens
max_length = 64
def preprocess_function(examples):
    """Tokenize a batch of en->fr pairs into fixed-length model inputs.

    English sentences are encoded as the source (``input_ids`` and the other
    fields the tokenizer returns); French sentences become ``labels``. Both are
    truncated / padded to ``max_length``. Each label sequence is shifted right
    by one position: ``tokenizer.bos_token_id`` is put in front and the final
    token is dropped, so the length stays ``max_length``.
    """
    pairs = examples["translation"]
    english = [pair["en"] for pair in pairs]
    french = [pair["fr"] for pair in pairs]

    encoded = tokenizer(
        english,
        text_target=french,
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )

    bos = tokenizer.bos_token_id
    encoded["labels"] = [[bos, *label_ids[:-1]] for label_ids in encoded["labels"]]
    return encoded

# tokenize both splits in batches; every original column is dropped so only
# the fields produced by preprocess_function remain
tokenized_datasets = split_datasets.map(
    preprocess_function,
    remove_columns=split_datasets["train"].column_names,
    batched=True,
)

In [None]:
# return PyTorch tensors so batches can be moved straight to the model's device
tokenized_datasets.set_format("torch")

# shuffled batches for training, fixed order for evaluation
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=8, shuffle=True)
test_dataloader = DataLoader(tokenized_datasets["test"], batch_size=8)

In [None]:
# Sanity check: decode a few shuffled training batches to eyeball the source
# text next to its label sequence (which should start with the BOS token '<s>').
from itertools import islice

N_PREVIEW_BATCHES = 5
for batch in islice(train_dataloader, N_PREVIEW_BATCHES):
    # use the real batch length: the last batch of an epoch can hold fewer than batch_size rows
    for i in range(len(batch["input_ids"])):
        print(f"Entry {i}")
        print(tokenizer.decode(batch["input_ids"][i]))
        print(tokenizer.decode(batch["labels"][i]))

# Model definition (single-head Transformer building blocks)



In [None]:
class ScaledDotProductAttention(nn.Module):
    """ Computes scaled dot-product attention: softmax(Q K^T / scale) V.

    Any number of leading batch dimensions is supported, so both single-head
    inputs (batch_size, len, head_dim) and multi-head inputs
    (batch_size, n_heads, len, head_dim) work.
    """

    def __init__(self, scale):
        """ scale: divisor applied to the raw dot products (e.g. sqrt(head_dim)).
        """
        super(ScaledDotProductAttention, self).__init__()
        self.scale = scale


    def forward(self, query, key, value, mask=None):
        """ query: (..., query_len, head_dim)
            key:   (..., key_len, head_dim)
            value: (..., key_len, head_dim)
            mask:  optional, broadcastable to (..., query_len, key_len); entries
                   equal to 0 are excluded from attention, e.g.
                   (batch_size, 1, source_seq_len) for a source padding mask or
                   (batch_size, target_seq_len, target_seq_len) for a target mask

        Returns:
            output:     (..., query_len, head_dim) context vectors
            attn_probs: (..., query_len, key_len) attention weights

        NOTE: a query row whose keys are all masked gets NaN weights (softmax
        over only -inf), so callers must leave at least one key visible.
        """
        # alignment scores; swap the last two dims of key so that any number of
        # leading batch / head dims is supported
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale  # (..., query_len, key_len)

        # mask out invalid positions
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # normalise the scores into attention weights over the keys
        attn_probs = F.softmax(scores, dim=-1)  # (..., query_len, key_len)

        # context vector = attention-weighted sum of the values
        output = torch.matmul(attn_probs, value)  # (..., query_len, head_dim)

        return output, attn_probs

In [None]:
class SingleHeadAttention(nn.Module):
    """ Single-head attention: project query/key/value, attend, project back.

    Every projection maps d_model -> d_model; only the output projection W_o
    carries a bias term.
    """

    def __init__(self, d_model):
        super(SingleHeadAttention, self).__init__()

        self.d_model = d_model
        # the single head spans the full model width
        self.d_k = d_model
        self.d_v = d_model

        # created in the order W_q, W_k, W_v, W_o: this fixes the order in which
        # the global RNG initialises their weights
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(np.sqrt(self.d_k))


    def forward(self, query, key, value, mask=None):
        """ query: (batch_size, query_len, d_model)
            key:   (batch_size, key_len, d_model)
            value: (batch_size, key_len, d_model)
            mask:  broadcastable to (batch_size, query_len, key_len), e.g.
                   (batch_size, 1, source_seq_len) for a source mask or
                   (batch_size, target_seq_len, target_seq_len) for a target mask

        Returns (x, attn): x is (batch_size, query_len, d_model) and attn is
        (batch_size, query_len, key_len).
        """
        context, attn = self.attention(
            self.W_q(query),
            self.W_k(key),
            self.W_v(value),
            mask,
        )
        return self.W_o(context), attn

In [None]:
class EncoderLayer(nn.Module):
    """ One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network. Each
    sub-layer is wrapped in a residual connection and then LayerNorm.
    """

    def __init__(self, d_model, d_ff):
        super(EncoderLayer, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        # self-attention sub-layer
        self.attn_layer = SingleHeadAttention(d_model)
        self.attn_layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # feed-forward sub-layer: d_model -> d_ff -> d_model
        self.ff_w_1 = nn.Linear(d_model, d_ff)
        self.ff_w_2 = nn.Linear(d_ff, d_model)
        self.ff_layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def _feed_forward(self, x):
        """ Position-wise FFN W_2(relu(W_1 x)); keeps the shape of x. """
        return self.ff_w_2(F.relu(self.ff_w_1(x)))

    def forward(self, x, mask):
        """ x: (batch_size, source_seq_len, d_model)
            mask: (batch_size, 1, source_seq_len)

        Returns a tensor with the same shape as x.
        """
        attended, _ = self.attn_layer(x, x, x, mask)
        hidden = self.attn_layer_norm(x + attended)
        return self.ff_layer_norm(hidden + self._feed_forward(hidden))

In [None]:
class DecoderLayer(nn.Module):
    """ One post-norm Transformer decoder layer.

    Masked self-attention, encoder-decoder (cross) attention and a
    position-wise feed-forward network. Each is followed by a residual
    connection and LayerNorm.
    """

    def __init__(self, d_model, d_ff):
        super(DecoderLayer, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        # masked self-attention sub-layer
        self.attn_layer = SingleHeadAttention(d_model)
        self.attn_layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # cross-attention sub-layer over the encoder states
        self.enc_attn_layer = SingleHeadAttention(d_model)
        self.enc_attn_layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # feed-forward sub-layer: d_model -> d_ff -> d_model
        self.ff_w_1 = nn.Linear(d_model, d_ff)
        self.ff_w_2 = nn.Linear(d_ff, d_model)
        self.ff_layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, x, memory, src_mask, tgt_mask):
        """ x: (batch_size, target_seq_len, d_model)
            memory: (batch_size, source_seq_len, d_model) encoder output
            src_mask: (batch_size, 1, source_seq_len)
            tgt_mask: (batch_size, target_seq_len, target_seq_len)

        Returns (out, attn): out is (batch_size, target_seq_len, d_model) and
        attn holds the cross-attention weights,
        (batch_size, target_seq_len, source_seq_len).
        """
        # 1) masked self-attention over the target prefix
        self_attended, _ = self.attn_layer(x, x, x, tgt_mask)
        hidden = self.attn_layer_norm(x + self_attended)

        # 2) cross-attention: each target position queries the encoder states
        cross_attended, attn = self.enc_attn_layer(hidden, memory, memory, src_mask)
        hidden = self.enc_attn_layer_norm(hidden + cross_attended)

        # 3) position-wise feed-forward
        ff = self.ff_w_2(F.relu(self.ff_w_1(hidden)))
        return self.ff_layer_norm(hidden + ff), attn

# Full architecture

In [None]:
class PositionalEncoding(nn.Module):
    """ Implements the sinusoidal positional encoding for batch-first inputs.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    The table is precomputed for positions 0 .. max_len-1 and added to the
    input along its sequence dimension (dim 1).
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_len = max_len

        # compute positional encodings
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # an odd d_model has one fewer cosine column than sine column
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # the buffer keeps its (max_len, 1, d_model) layout so previously saved
        # state dicts still load
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """ x: (batch_size, seq_len, d_model) with seq_len <= max_len

        Returns x + PE[:seq_len] broadcast over the batch,
        (batch_size, seq_len, d_model).
        """
        seq_len = x.size(1)
        # pe[:seq_len] is (seq_len, 1, d_model); swapping to (1, seq_len, d_model)
        # makes it broadcast over the batch axis. Slicing by x.size(0) (the batch
        # size) instead would give every token of a sequence the same vector,
        # i.e. no positional information at all.
        return x + self.pe[:seq_len].transpose(0, 1)

In [None]:
class Encoder(nn.Module):
    """ Encoder block: token embedding plus positional encoding, a stack of
    n_layers identical EncoderLayers, and a final LayerNorm.
    """

    def __init__(self, vocab_size, d_model, n_layers, d_ff, pad_idx, max_len=5000):
        super(Encoder, self).__init__()
        # keep the hyper-parameters around for inspection
        self.vocab_size, self.d_model, self.n_layers = vocab_size, d_model, n_layers
        self.d_ff, self.pad_idx, self.max_len = d_ff, pad_idx, max_len

        self.tok_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_embedding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(EncoderLayer(d_model, d_ff) for _ in range(n_layers))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, x, mask):
        """ x: (batch_size, source_seq_len) token ids
            mask: (batch_size, 1, source_seq_len)

        Returns the encoder states, (batch_size, source_seq_len, d_model).
        """
        hidden = self.pos_embedding(self.tok_embedding(x))
        for layer in self.layers:
            hidden = layer(hidden, mask)
        return self.layer_norm(hidden)

In [None]:
class Decoder(nn.Module):
    """ Decoder block: token embedding plus positional encoding, a stack of
    n_layers identical DecoderLayers, and a final LayerNorm.
    """

    def __init__(self, vocab_size, d_model, n_layers, d_ff, pad_idx, max_len=5000):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.pad_idx = pad_idx
        self.max_len = max_len

        self.tok_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_embedding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, d_ff)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, x, memory, src_mask, tgt_mask):
        """ x: (batch_size, target_seq_len) target token ids
            memory: (batch_size, source_seq_len, d_model) encoder output
            src_mask: (batch_size, 1, source_seq_len)
            tgt_mask: (batch_size, target_seq_len, target_seq_len)

        Returns:
            x: (batch_size, target_seq_len, d_model) decoder states
            attn: cross-attention weights of the last layer,
                  (batch_size, target_seq_len, source_seq_len), or None when
                  the decoder has no layers
        """
        # embed the token ids and add positional encodings
        x = self.tok_embedding(x)  # (batch_size, target_seq_len, d_model)
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        # defined up front so an empty layer stack returns None instead of
        # raising UnboundLocalError
        attn = None
        for layer in self.layers:
            x, attn = layer(x, memory, src_mask, tgt_mask)  # (batch_size, target_seq_len, d_model)

        x = self.layer_norm(x)  # (batch_size, target_seq_len, d_model)
        return x, attn

In [None]:
class Transformer(nn.Module):
    """ Encoder-decoder Transformer.

    Builds the source / target masks, runs the encoder and decoder, and maps
    the decoder states through `generator`.
    """

    def __init__(self, encoder, decoder, generator, pad_idx):
        super(Transformer, self).__init__()
        self.pad_idx = pad_idx

        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator


    def get_pad_mask(self, x, pad_idx):
        """ x: (batch_size, seq_len) token ids.

        Returns a bool mask (batch_size, 1, seq_len) that is True on real
        tokens and False on padding.
        """
        not_pad = x != pad_idx
        return not_pad.unsqueeze(-2)


    def get_causal_mask(self, x):
        """ x: (batch_size, seq_len); only its length and device are used.

        Returns a bool mask (1, seq_len, seq_len) whose entry [0, i, j] is True
        when key position j <= query position i (lower triangle incl. diagonal).
        """
        positions = torch.arange(x.size(1), device=x.device)
        allowed = positions[None, :] <= positions[:, None]  # (seq_len, seq_len)
        return allowed.unsqueeze(0)


    def forward(self, src, tgt):
        """ src: (batch_size, source_seq_len) token ids
            tgt: (batch_size, target_seq_len) token ids

        Returns (output, attn): output is the generator's result for every
        target position, (batch_size, target_seq_len, vocab_size); attn is the
        attention returned by the decoder.
        """
        src_mask = self.get_pad_mask(src, self.pad_idx)  # (batch_size, 1, source_seq_len)
        # a target position may attend only to non-pad tokens at or before itself
        tgt_mask = self.get_pad_mask(tgt, self.pad_idx) & self.get_causal_mask(tgt)  # (batch_size, target_seq_len, target_seq_len)

        memory = self.encoder(src, src_mask)  # (batch_size, source_seq_len, d_model)
        dec_output, attn = self.decoder(tgt, memory, src_mask, tgt_mask)
        return self.generator(dec_output), attn

In [None]:
class Generator(nn.Module):
    """ Output head: linear projection to the vocabulary followed by log-softmax.
    """

    def __init__(self, d_model, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab_size)


    def forward(self, x):
        """ x: (batch_size, target_seq_len, d_model)

        Returns log-probabilities over the vocabulary,
        (batch_size, target_seq_len, vocab_size).
        """
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

# Training


In [None]:
def train_fn(model, iterator, optimizer, critertion, clip=1.0):
    """ Train `model` for one epoch with teacher forcing.

    Each batch's labels are split into the decoder input (labels[:, :-1]) and
    the next-token targets (labels[:, 1:]).

    Args:
        model: called as model(source, decoder_input) -> (output, attn), where
            output is (batch_size, target_seq_len - 1, vocab_size).
        iterator: DataLoader yielding dicts with "input_ids" and "labels" tensors.
        optimizer: optimizer over the model's parameters.
        critertion: loss applied to the flattened outputs and targets. The
            parameter keeps its original spelling so keyword callers still work.
        clip: maximum gradient norm passed to clip_grad_norm_.

    Returns:
        (predictions, perplexity): argmax token ids of the last batch,
        (batch_size, target_seq_len - 1), or None for an empty iterator; and
        exp(mean batch loss), or nan for an empty iterator.
    """
    # always use the loss passed in, never a notebook-global `criterion`
    loss_fn = critertion
    model.train()
    total_loss = 0.0
    steps = 0
    predictions = None

    tk0 = tqdm(iterator, total=len(iterator), position=0, leave=True)

    for batch in tk0:
        source = batch["input_ids"].to(device)  # (batch_size, source_seq_len)
        target = batch["labels"].to(device)  # (batch_size, target_seq_len)

        # forward pass: the last target token is never fed to the decoder
        optimizer.zero_grad()
        output, _ = model(source, target[:, :-1])  # (batch_size, target_seq_len - 1, vocab_size)

        # every position predicts the next token (targets shifted left by one)
        loss = loss_fn(
            output.view(-1, output.size(-1)),  # (batch_size * (target_seq_len - 1), vocab_size)
            target[:, 1:].contiguous().view(-1)  # (batch_size * (target_seq_len - 1))
        )
        total_loss += loss.item()
        steps += 1

        predictions = output.argmax(dim=-1)  # (batch_size, target_seq_len - 1)

        # backward pass, gradient clipping against exploding gradients, update
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        tk0.set_postfix(loss=total_loss / steps)

    tk0.close()

    perplexity = np.exp(total_loss / steps) if steps else float('nan')
    return predictions, perplexity

In [None]:
def eval_fn(model, iterator, criterion):
    """ Evaluate `model` for one pass over `iterator` without gradient tracking.

    Uses the same teacher-forcing split as training: decoder input is
    labels[:, :-1], targets are labels[:, 1:].

    Args:
        model: called as model(source, decoder_input) -> (output, attn).
        iterator: DataLoader yielding dicts with "input_ids" and "labels" tensors.
        criterion: loss applied to the flattened outputs and targets.

    Returns:
        (output, perplexity): the model output of the last batch,
        (batch_size, target_seq_len - 1, vocab_size), or None for an empty
        iterator; and exp(mean batch loss), or nan for an empty iterator.
    """
    model.eval()
    total_loss = 0.0
    steps = 0
    output = None

    tk0 = tqdm(iterator, total=len(iterator), position=0, leave=True)

    with torch.no_grad():
        for batch in tk0:
            source = batch["input_ids"].to(device)  # (batch_size, source_seq_len)
            target = batch["labels"].to(device)  # (batch_size, target_seq_len)

            # forward pass: the last target token is never fed to the decoder
            output, _ = model(source, target[:, :-1])  # (batch_size, target_seq_len - 1, vocab_size)

            # every position predicts the next token (targets shifted left by one)
            loss = criterion(
                output.view(-1, output.size(-1)),  # (batch_size * (target_seq_len - 1), vocab_size)
                target[:, 1:].contiguous().view(-1)  # (batch_size * (target_seq_len - 1))
            )
            total_loss += loss.item()
            steps += 1

            tk0.set_postfix(loss=total_loss / steps)

    tk0.close()

    perplexity = np.exp(total_loss / steps) if steps else float('nan')
    return output, perplexity

In [None]:
# hyperparameters
VOCAB_SIZE = len(tokenizer)
HIDDEN_SIZE = 256
N_LAYERS = 8
FF_SIZE = 1024
N_EPOCHS = 30
CLIP = 1.0

In [None]:
# the padding id is shared by the embeddings (padding_idx) and the masks
pad_id = tokenizer.pad_token_id

# built in this order (encoder, decoder, generator) so weight init follows the same RNG sequence
encoder = Encoder(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, FF_SIZE, pad_id)
decoder = Decoder(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, FF_SIZE, pad_id)
generator = Generator(HIDDEN_SIZE, VOCAB_SIZE)

model = Transformer(encoder, decoder, generator, pad_id).to(device)
print(model)
print(f'# of trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')
print(f'# of non-trainable params: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}')

In [None]:
# Generator already returns log-probabilities (log_softmax), so NLLLoss is the
# matching loss; CrossEntropyLoss would apply log_softmax a second time. Because
# log_softmax is idempotent the loss value is the same, but NLLLoss states the
# intent and skips the redundant normalisation. Padding targets are ignored.
criterion = nn.NLLLoss(ignore_index=tokenizer.pad_token_id)

In [None]:
# Adam over every model parameter with a fixed learning rate
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# train for exactly N_EPOCHS epochs, numbered 1..N_EPOCHS
for epoch in range(1, N_EPOCHS + 1):
    # one epoch training
    _, train_perplexity = train_fn(model, train_dataloader, optimizer, criterion, CLIP)

    # one epoch validation
    _, valid_perplexity = eval_fn(model, test_dataloader, criterion)

    print(f'Epoch: {epoch}, Train perplexity: {train_perplexity:.4f}, Valid perplexity: {valid_perplexity:.4f}')


## Inference


In [None]:
def greedy_decode(model, token_ids, max_len=100):
    """ Greedily translate one tokenized source sentence.

    Args:
        model: trained Transformer; its encoder, decoder, generator and mask
            helpers are used directly.
        token_ids: source token ids, a list or 1-D tensor (e.g. a dataset row's
            "input_ids"); padding is allowed because it is masked out.
        max_len: maximum number of tokens to generate (the BOS start token is
            not counted).

    Returns:
        (predicted_ids, attn): the generated ids without the leading BOS (ending
        with EOS if one was produced), and the last decoder layer's
        cross-attention weights from the final decoding step as a numpy array of
        shape (len(predicted_ids), source_seq_len). Row i belongs to the step
        that produced predicted_ids[i]. attn is None when max_len <= 0.
    """
    model.eval()

    # as_tensor also accepts an existing tensor without the copy-construct warning
    source = torch.as_tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, source_seq_len)
    source_mask = model.get_pad_mask(source, tokenizer.pad_token_id)  # (1, 1, source_seq_len)

    target_ids = [tokenizer.bos_token_id]  # decoding starts from BOS
    attn = None

    with torch.no_grad():
        # the source is encoded once and reused at every step
        enc_output = model.encoder(source, source_mask)  # (1, source_seq_len, d_model)

        for _ in range(max_len):
            target = torch.tensor(target_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, target_seq_len)
            target_mask = model.get_pad_mask(target, tokenizer.pad_token_id) & model.get_causal_mask(target)

            dec_output, attn = model.decoder(target, enc_output, source_mask, target_mask)
            # dec_output: (1, target_seq_len, d_model)
            # attn: (1, target_seq_len, source_seq_len)
            output = model.generator(dec_output)  # (1, target_seq_len, vocab_size)

            # greedy choice for the newest position only
            next_id = output[0, -1].argmax().item()
            target_ids.append(next_id)

            # stop as soon as EOS is produced; range(max_len) bounds the length
            if next_id == tokenizer.eos_token_id:
                break

    if attn is not None:
        attn = attn.squeeze(0).cpu().numpy()  # (target_seq_len, source_seq_len)

    return target_ids[1:], attn

In [None]:
def plot_attention_scores(source, target, attention):
    """ Draw a heatmap of decoder-to-encoder attention weights.

    Args:
        source: labels for the x axis, one per attention column (source tokens).
        target: labels for the y axis, one per attention row (predicted tokens).
        attention: 2-D array of weights, expected shape (len(target), len(source));
            the colour scale is fixed to [0, 1].
    """
    _, ax = plt.subplots(figsize=(24, 24))
    sns.heatmap(
        attention, xticklabels=source, yticklabels=target, square=True,
        vmin=0.0, vmax=1.0, cbar=False, cmap="Blues", ax=ax,
    )
    ax.set(xlabel="source tokens", ylabel="predicted tokens", title="Cross-attention weights")

In [None]:
# translate one training example and compare it with its reference
example_idx = 0
example = tokenized_datasets['train'][example_idx]
source, target = example["input_ids"], example["labels"]

predicted, attention_scores = greedy_decode(model, source)

print(f'source: {tokenizer.decode(source)}\n')
print(f'target: {tokenizer.decode(target)}\n')
print(f'predicted: {tokenizer.decode(predicted)}\n')

# rows = predicted tokens, columns = source tokens
source_tokens = tokenizer.convert_ids_to_tokens(source)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted)
plot_attention_scores(source_tokens, predicted_tokens, attention_scores)