<a href="https://colab.research.google.com/github/Cazamere/Roformer/blob/main/Roformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%matplotlib inline

Data Sourcing and Processing
============================


In [2]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

# The original Multi30k download links are broken, so point the training and
# validation splits at a mirror of the archives instead.
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Translation direction: these two codes are passed to Multi30k as the
# (source, target) language pair and are used as keys in the tables below.
SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'de'

# Per-language lookup tables, keyed by language code and filled in by later cells:
#   token_transform[lang] -> tokenizer callable
#   vocab_transform[lang] -> vocabulary built from the training split
token_transform = {}
vocab_transform = {}

In [3]:
# Install the dependencies. Once this cell has finished, restart the kernel (runtime) so the newly installed packages are loaded.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m38.0 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation succ

In [4]:
!pip install portalocker==2.8.2



In [5]:
# Tokenizers keyed by language code. Each language must be tokenized with the
# spaCy model of that same language; looking the model up by language code keeps
# the pairing correct whichever language is configured as source or target.
_SPACY_MODELS = {'en': 'en_core_web_sm', 'de': 'de_core_news_sm'}
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    token_transform[ln] = get_tokenizer('spacy', language=_SPACY_MODELS[ln])

# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sample in ``data_iter``.

    Args:
        data_iter: iterable of samples; element 0 of a sample is read for
            ``SRC_LANGUAGE`` and element 1 for ``TGT_LANGUAGE``.
        language: language code selecting both the sample element and the
            tokenizer in ``token_transform``.

    Yields:
        The token list produced by ``token_transform[language]`` for each sample.

    Raises:
        KeyError: if ``language`` is neither ``SRC_LANGUAGE`` nor ``TGT_LANGUAGE``
            (raised when the generator is first advanced).
    """
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    # Both lookups are loop-invariant, so resolve them once.
    position = language_index[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[position])

# Reserved vocabulary indices for the special symbols
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Listed in index order so that inserting them first gives each symbol its reserved index
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    # A fresh iterator over the training split is created for each language
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    token_stream = yield_tokens(train_iter, ln)

    # Build torchtext's Vocab object with the special symbols placed first
    vocab = build_vocab_from_iterator(token_stream,
                                      min_freq=1,
                                      specials=special_symbols,
                                      special_first=True)

    # Use ``UNK_IDX`` as the default index, i.e. the index returned for a token that is
    # not in the vocabulary. If not set, it throws ``RuntimeError`` when the queried
    # token is not found in the Vocabulary.
    vocab.set_default_index(UNK_IDX)
    vocab_transform[ln] = vocab

Transformer
=================================

In [37]:
# Transformer model

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to token embeddings, then applies dropout.

    The table is stored as the non-trainable buffer ``pos_embedding`` with shape
    ``(maxlen, 1, emb_size)``: sines fill the even feature indices and cosines the
    odd ones.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        # One frequency per (even, odd) feature pair.
        frequencies = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        angles = positions * frequencies

        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # The singleton middle axis lets the table broadcast over dimension 1 of the input.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        """Add the first ``token_embedding.size(0)`` table rows to the input and apply dropout."""
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Cast ``tokens`` to long, look them up and scale the resulting vectors."""
        scale = math.sqrt(self.emb_size)
        token_ids = tokens.long()
        return self.embedding(token_ids) * scale

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder model built on ``torch.nn.Transformer``.

    Source and target tokens are embedded with ``TokenEmbedding``, combined with
    ``PositionalEncoding``, passed through the transformer, and projected to
    target-vocabulary logits by ``generator``.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # Sub-modules are registered in a fixed order; parameter order follows it.
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def _embed_src(self, src: Tensor):
        # Source token embeddings plus positional encoding.
        return self.positional_encoding(self.src_tok_emb(src))

    def _embed_tgt(self, tgt: Tensor):
        # Target token embeddings plus positional encoding.
        return self.positional_encoding(self.tgt_tok_emb(tgt))

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary logits."""
        hidden = self.transformer(
            self._embed_src(src),
            self._embed_tgt(trg),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Run only the encoder on the embedded source sequence."""
        return self.transformer.encoder(self._embed_src(src), mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Run only the decoder on the embedded target sequence and the encoder ``memory``."""
        return self.transformer.decoder(self._embed_tgt(tgt), memory, tgt_mask=tgt_mask)

# Roformer

In [12]:
# Roformer

class RotaryEmbedding(nn.Module):
    """Produces the sine/cosine tables used by rotary position embeddings.

    ``dim`` is the width of the vectors that will be rotated; one frequency
    ``theta_i = 10000 ** (-2i / dim)`` is kept per pair of features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # theta
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos): # pos -> p
        """Return a ``(seq_len, 2 * len(inv_freq))`` table for positions ``pos`` of shape ``(1, seq_len)``.

        The first half of the last axis holds the sines, the second half the cosines.
        """
        sinusoid_inp = torch.outer(pos.float().squeeze(0), self.inv_freq) # m * theta
        return torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) # combine sin and cosine frequencies


def apply_rotary_pos_emb(x, sincos):
    """Rotate the feature pairs of ``x`` by the position-dependent angles in ``sincos``.

    Args:
        x: tensor whose first axis is the sequence position and whose last axis is
            the feature axis; feature ``i`` is paired with feature ``i + width // 2``.
        sincos: table from ``RotaryEmbedding.forward`` (``[sin | cos]`` on axis 1).

    Returns:
        A tensor shaped like ``x`` with every pair rotated.
    """
    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=1) # split sin and cosine frequencies
    # Insert a singleton axis so the tables broadcast over axis 1 of ``x``.
    sin = sin.unsqueeze(1)
    cos = cos.unsqueeze(1)
    x1, x2 = x[..., :x.size(-1) // 2], x[..., x.size(-1) // 2:]
    return torch.cat([x1 * cos + x2 * sin, x2 * cos - x1 * sin], dim=-1) # rotate input vectors


# Rotary attention built on the parameters of nn.MultiheadAttention
class RotaryAttention(nn.Module):
    """Multi-head attention with rotary position embeddings on queries and keys.

    The rotation is applied to the *projected* queries and keys, separately for
    every head, so that the attention logit between two positions depends on their
    offset. Rotating the inputs before the learned projections (as a plain wrapper
    around ``nn.MultiheadAttention`` would) loses that property, because the
    projection matrices sit between the two rotations.

    The parameters still live in ``self.multihead_attn`` (packed input projection and
    output projection); only the attention computation is done here.
    Inputs are sequence-first: ``(seq_len, batch, emb_size)``.
    """

    def __init__(self, emb_size, nhead, dropout=0.0):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(emb_size, nhead, dropout=dropout)
        head_dim = emb_size // nhead
        if head_dim % 2 != 0:
            raise ValueError("emb_size // nhead must be even to form rotary feature pairs")
        # Rotation acts inside each head, so the table is sized for one head.
        self.rotary_emb = RotaryEmbedding(head_dim)

    @staticmethod
    def _apply_mask(scores, mask):
        # Boolean masks mark positions to exclude; any other mask is added to the logits.
        if mask.dtype == torch.bool:
            return scores.masked_fill(mask, float('-inf'))
        return scores + mask

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Compute rotary multi-head attention.

        Args:
            query: ``(query_len, batch, emb_size)``.
            key, value: ``(key_len, batch, emb_size)``.
            attn_mask: optional mask broadcastable to ``(batch * heads, query_len, key_len)``;
                boolean (True = not allowed) or additive float.
            key_padding_mask: optional ``(batch, key_len)`` mask; boolean (True = ignore
                that key) or additive float.

        Returns:
            ``(output, weights)`` where ``output`` is ``(query_len, batch, emb_size)`` and
            ``weights`` is the head-averaged attention, ``(batch, query_len, key_len)``.
        """
        mha = self.multihead_attn
        num_heads, head_dim = mha.num_heads, mha.head_dim
        query_seq_len, key_seq_len, batch_size = query.size(0), key.size(0), query.size(1)

        # Input projections, using the packed weights of the wrapped module.
        w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
        b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
        q = nn.functional.linear(query, w_q, b_q)
        k = nn.functional.linear(key, w_k, b_k)
        v = nn.functional.linear(value, w_v, b_v)

        # Split heads: (seq, batch, emb) -> (seq, batch * heads, head_dim)
        q = q.reshape(query_seq_len, batch_size * num_heads, head_dim)
        k = k.reshape(key_seq_len, batch_size * num_heads, head_dim)
        v = v.reshape(key_seq_len, batch_size * num_heads, head_dim)

        # Rotate projected queries and keys by their positions 0..len-1; values are not rotated.
        query_pos = torch.arange(query_seq_len, device=query.device).unsqueeze(0)
        key_pos = torch.arange(key_seq_len, device=key.device).unsqueeze(0)
        q = apply_rotary_pos_emb(q, self.rotary_emb(query_pos))
        k = apply_rotary_pos_emb(k, self.rotary_emb(key_pos))

        # (batch * heads, seq, head_dim) for batched matrix products
        q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)

        if attn_mask is not None:
            scores = self._apply_mask(scores, attn_mask)
        if key_padding_mask is not None:
            scores = scores.reshape(batch_size, num_heads, query_seq_len, key_seq_len)
            scores = self._apply_mask(
                scores, key_padding_mask.reshape(batch_size, 1, 1, key_seq_len))
            scores = scores.reshape(batch_size * num_heads, query_seq_len, key_seq_len)

        attn_weights = torch.softmax(scores, dim=-1)
        attn_probs = nn.functional.dropout(attn_weights, p=mha.dropout, training=self.training)

        # Merge heads back: (batch * heads, query_len, head_dim) -> (query_len, batch, emb)
        attn_output = torch.bmm(attn_probs, v)
        attn_output = attn_output.transpose(0, 1).reshape(
            query_seq_len, batch_size, num_heads * head_dim)
        attn_output = mha.out_proj(attn_output)

        avg_weights = attn_weights.reshape(
            batch_size, num_heads, query_seq_len, key_seq_len).mean(dim=1)
        return attn_output, avg_weights

class TokenEmbedding(nn.Module):
    """Token lookup table; looked-up vectors are scaled by the square root of ``emb_size``.

    Note: this cell redefines the ``TokenEmbedding`` class declared in the
    Transformer cell with the same behaviour.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Return the scaled embedding vectors for ``tokens`` (cast to long first)."""
        looked_up = self.embedding(tokens.long())
        return looked_up * math.sqrt(self.emb_size)

class RoformerEncoderLayer(nn.Module):
    """One encoder layer: rotary self-attention followed by a feed-forward network.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``.
    """

    def __init__(self, emb_size: int, nhead: int, dim_feedforward: int, dropout: float):
        super().__init__()
        # Attention sub-layer
        self.self_attn = RotaryAttention(emb_size, nhead, dropout=dropout)
        # Feed-forward sub-layer: linear1 -> ReLU -> dropout -> linear2
        self.linear1 = nn.Linear(emb_size, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, emb_size)

        # One norm / dropout pair per residual connection
        self.norm1 = nn.LayerNorm(emb_size)
        self.norm2 = nn.LayerNorm(emb_size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, src: Tensor, src_mask: Tensor, src_key_padding_mask: Tensor):
        """Apply self-attention and the feed-forward network to ``src``."""
        attended = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        hidden = self.norm1(src + self.dropout1(attended))

        expanded = self.dropout(self.activation(self.linear1(hidden)))
        projected = self.linear2(expanded)
        return self.norm2(hidden + self.dropout2(projected))

class RoformerDecoderLayer(nn.Module):
    """One decoder layer: rotary self-attention, rotary attention over the encoder
    memory, then a feed-forward network.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``.
    """

    def __init__(self, emb_size: int, nhead: int, dim_feedforward: int, dropout: float):
        super().__init__()
        # Attention over the target sequence and over the encoder memory
        self.self_attn = RotaryAttention(emb_size, nhead, dropout=dropout)
        self.multihead_attn = RotaryAttention(emb_size, nhead, dropout=dropout)
        # Feed-forward sub-layer: linear1 -> ReLU -> dropout -> linear2
        self.linear1 = nn.Linear(emb_size, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, emb_size)

        # One norm / dropout pair per residual connection
        self.norm1 = nn.LayerNorm(emb_size)
        self.norm2 = nn.LayerNorm(emb_size)
        self.norm3 = nn.LayerNorm(emb_size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor, memory_mask: Tensor,
                tgt_key_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """Apply the three sub-layers to ``tgt`` using ``memory`` as keys/values for the second."""
        self_attended = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        hidden = self.norm1(tgt + self.dropout1(self_attended))

        cross_attended = self.multihead_attn(
            hidden, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        hidden = self.norm2(hidden + self.dropout2(cross_attended))

        expanded = self.dropout(self.activation(self.linear1(hidden)))
        projected = self.linear2(expanded)
        return self.norm3(hidden + self.dropout3(projected))

class RoformerEncoder(nn.Module):
    """Stack of ``num_layers`` ``RoformerEncoderLayer`` modules followed by a final LayerNorm."""

    def __init__(self, emb_size, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(RoformerEncoderLayer(emb_size, nhead, dim_feedforward, dropout))
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(emb_size)

    def forward(self, src, mask, src_key_padding_mask):
        """Feed ``src`` through every layer in order, then normalise the result."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask)
        return self.norm(hidden)

class RoformerDecoder(nn.Module):
    """Stack of ``num_layers`` ``RoformerDecoderLayer`` modules followed by a final LayerNorm."""

    def __init__(self, emb_size, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(RoformerDecoderLayer(emb_size, nhead, dim_feedforward, dropout))
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(emb_size)

    def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask):
        """Feed ``tgt`` through every layer in order (each attending to ``memory``), then normalise."""
        hidden = tgt
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
        return self.norm(hidden)

class Roformer(nn.Module):
    """Encoder-decoder translation model using the Roformer encoder/decoder stacks.

    No positional encoding is added to the token embeddings here; the only
    position-dependent component in this class hierarchy is the attention module
    used inside the layers. ``generator`` maps decoder states to target-vocabulary logits.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward=512, dropout=0.1):
        super().__init__()
        # Embeddings (registered first so parameter order stays stable)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.dropout = nn.Dropout(dropout)

        # Encoder / decoder stacks
        self.encoder = RoformerEncoder(emb_size, nhead, num_encoder_layers, dim_feedforward, dropout)
        self.decoder = RoformerDecoder(emb_size, nhead, num_decoder_layers, dim_feedforward, dropout)

        # Projection to target-vocabulary logits
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def _embed_src(self, src):
        # Source embeddings with dropout applied.
        return self.dropout(self.src_tok_emb(src))

    def _embed_tgt(self, tgt):
        # Target embeddings with dropout applied.
        return self.dropout(self.tgt_tok_emb(tgt))

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        """Encode ``src``, decode ``tgt`` against the encoder output and return logits.

        No attention mask is applied between target and memory positions
        (``memory_mask`` is passed as ``None``).
        """
        src_emb = self._embed_src(src)
        tgt_emb = self._embed_tgt(tgt)

        memory = self.encoder(src_emb, src_mask, src_padding_mask)
        decoded = self.decoder(tgt_emb, memory, tgt_mask, None,
                               tgt_padding_mask, memory_key_padding_mask)
        return self.generator(decoded)

    def encode(self, src, src_mask):
        """Run only the encoder; no key-padding mask is used."""
        return self.encoder(self._embed_src(src), src_mask, None)

    def decode(self, tgt, memory, tgt_mask):
        """Run only the decoder; only ``tgt_mask`` is applied, all other masks are ``None``."""
        return self.decoder(self._embed_tgt(tgt), memory, tgt_mask, None, None, None)

# Attention Masks

In [13]:
def generate_square_subsequent_mask(sz):
    """Build an ``(sz, sz)`` float mask on ``DEVICE``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float, device=DEVICE)
    # triu with diagonal=1 keeps the -inf entries strictly above the main diagonal
    # and zeroes everything on or below it.
    return torch.triu(blocked, diagonal=1)


def create_mask(src, tgt):
    """Build the attention and padding masks for a batch.

    ``src`` and ``tgt`` are indexed with the sequence on axis 0.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False square boolean mask, ``tgt_mask`` comes from
        ``generate_square_subsequent_mask`` and the padding masks are True wherever
        the token equals ``PAD_IDX`` (transposed so axis 0 and 1 are swapped).
    """
    src_seq_len, tgt_seq_len = src.shape[0], tgt.shape[0]

    # Source positions may all attend to each other.
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)

    src_is_pad = src == PAD_IDX
    tgt_is_pad = tgt == PAD_IDX
    return src_mask, tgt_mask, src_is_pad.transpose(0, 1), tgt_is_pad.transpose(0, 1)

# Training Setup

In [140]:
from torch.optim.lr_scheduler import _LRScheduler
import sys

# Fix the RNG seed so parameter initialisation below is repeatable.
torch.manual_seed(0)

# Name of the architecture to build: 'transformer' or 'roformer'.
model = 'roformer'

# Hyperparameters. Vocabulary sizes come from the vocabularies built earlier.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 4
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4

# Instantiate the selected architecture; both classes take the same arguments.
if model == 'transformer':
  transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
elif model == 'roformer':
  transformer = Roformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
else:
  sys.exit('Model not supported')

# Re-initialise every parameter with more than one dimension (weight matrices,
# embedding tables) with Xavier-uniform; 1-D parameters keep their default init.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# Targets equal to PAD_IDX do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)

# The optimizer is created after the model has been moved to DEVICE.
optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

Collation
=========




In [141]:
from torch.nn.utils.rnn import pad_sequence

# helper that chains several transforms into a single callable
def sequential_transforms(*transforms):
    """Return a function that applies ``transforms`` left to right.

    The returned callable feeds its argument to the first transform, the result to
    the second, and so on, returning the final value. With no transforms it
    returns its argument unchanged.
    """
    def func(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap ``token_ids`` in ``BOS_IDX`` / ``EOS_IDX`` and return a 1-D long tensor.

    The tensor is built in one step with an explicit integer dtype. Concatenating
    separately created tensors would turn the whole sequence into floats for an
    empty ``token_ids`` (``torch.tensor([])`` is float32), which is what an empty
    input line produces.
    """
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

# Per-language pipelines that turn a raw string into a tensor of token indices:
# tokenization -> numericalization -> add BOS/EOS and create tensor
text_transform = {
    language: sequential_transforms(token_transform[language],
                                    vocab_transform[language],
                                    tensor_transform)
    for language in (SRC_LANGUAGE, TGT_LANGUAGE)
}


# collate a list of (source, target) string pairs into two padded batch tensors
def collate_fn(batch):
    """Convert raw sentence pairs to index tensors and pad them to a common length.

    Trailing newlines are stripped from each sentence before it goes through the
    language's ``text_transform``. Both returned tensors are padded with ``PAD_IDX``.
    """
    samples = list(batch)
    to_src = text_transform[SRC_LANGUAGE]
    to_tgt = text_transform[TGT_LANGUAGE]

    src_tensors = [to_src(src_sample.rstrip("\n")) for src_sample, _ in samples]
    tgt_tensors = [to_tgt(tgt_sample.rstrip("\n")) for _, tgt_sample in samples]

    return (pad_sequence(src_tensors, padding_value=PAD_IDX),
            pad_sequence(tgt_tensors, padding_value=PAD_IDX))

# Training & Evaluation Loops


In [142]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    """Train ``model`` for one pass over the Multi30k training split.

    Uses teacher forcing: the decoder input is the target without its last
    position and the loss is computed against the target without its first.

    Returns:
        The mean per-batch loss.

    Raises:
        ZeroDivisionError: if the training split yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        optimizer.zero_grad()

        # The source padding mask is also used as the memory key-padding mask.
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches here instead of iterating the whole dataloader a second
        # time (which would re-tokenize the dataset) just to obtain its length.
        num_batches += 1

    return total_loss / num_batches

def evaluate_epoch(model):
    """Compute the mean per-batch loss of ``model`` on the Multi30k validation split.

    The model is put in eval mode and gradients are disabled, since no backward
    pass is performed here.

    Raises:
        ZeroDivisionError: if the validation split yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():
        for src, tgt in val_dataloader:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()
            # Counted in-loop to avoid a second full pass over the dataloader.
            num_batches += 1

    return total_loss / num_batches

# Train

In [143]:
from timeit import default_timer as timer
from tqdm import tqdm

def train(transformer, num_epochs=20, patience_max=5):
    """Train ``transformer`` with early stopping on the validation loss.

    After every epoch the validation loss is computed. When it improves, the whole
    model object is saved to ``model.pt`` and a greedy-decoding BLEU score is
    printed. Training stops once ``patience_max`` consecutive epochs have passed
    without improvement, or after ``num_epochs`` epochs.

    Args:
        transformer: model to train; the module-level ``optimizer`` is used for updates.
        num_epochs: maximum number of epochs.
        patience_max: number of consecutive non-improving epochs tolerated.
    """
    best_val_loss = float('inf')
    patience = 0

    for epoch in range(1, num_epochs + 1):
        start_time = timer()
        train_loss = train_epoch(transformer, optimizer)
        end_time = timer()
        val_loss = evaluate_epoch(transformer)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(transformer, 'model.pt')
            patience = 0
            bleu = evaluate_test(transformer, 'greedy')
            print(f'New best! Bleu: {bleu:.3f}')
        else:
            # Only epochs that fail to improve count against the patience budget.
            patience += 1

        # Report the epoch before deciding whether to stop, so the last epoch is logged too.
        print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

        if patience >= patience_max:
            break

# Decode

In [144]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a target sequence by picking the highest-scoring token at every step.

    Args:
        model: object providing ``encode``, ``decode`` and ``generator``.
        src: source token tensor (a single sentence; the next token is read with ``.item()``).
        src_mask: source attention mask passed to ``model.encode``.
        max_len: maximum length of the returned sequence, including the start symbol.
        start_symbol: index the output sequence starts with.

    Returns:
        Long tensor of shape ``(length, 1)`` starting with ``start_symbol``;
        generation stops after ``EOS_IDX`` has been appended.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Decoding never needs gradients; disabling them avoids building autograd graphs.
    with torch.no_grad():
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Scores for the token following the last generated position.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys

# function to generate output sequence using beam search algorithm
def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, alpha):
    """Generate a target sequence with beam search.

    Every hypothesis is decoded from its *own* token prefix and scored by the sum
    of token log-probabilities. Hypotheses are ranked by
    ``score / len(tokens) ** alpha``. A hypothesis that ends in ``EOS_IDX`` is set
    aside as finished; the search stops when no unfinished hypothesis is left, when
    ``beam_size`` hypotheses have finished, or after ``max_len`` steps.

    Args:
        model: object providing ``encode``, ``decode`` and ``generator``.
        src: source token tensor for a single sentence.
        src_mask: source attention mask passed to ``model.encode``.
        max_len: maximum number of decoding steps.
        start_symbol: index every hypothesis starts with.
        beam_size: number of hypotheses kept after each step.
        alpha: exponent of the length normalisation.

    Returns:
        List of token indices of the best hypothesis (finished ones are preferred
        when any exist), starting with ``start_symbol``.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    def normalised_score(hypothesis):
        score, token_ids = hypothesis
        return score / (len(token_ids) ** alpha)

    active_beams = [(0.0, [start_symbol])]  # Each beam is a tuple(score, token_ids)
    finished_beams = []

    with torch.no_grad():
        memory = model.encode(src, src_mask).to(DEVICE)

        for _ in range(max_len):
            candidates = []
            for score, token_ids in active_beams:
                # Decode from this hypothesis' own prefix, shaped (length, 1).
                ys = torch.tensor(token_ids, dtype=torch.long, device=DEVICE).unsqueeze(1)
                tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                            .type(torch.bool)).to(DEVICE)
                out = model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                # Normalise the logits so that scores are comparable log-probabilities.
                log_probs = torch.log_softmax(model.generator(out[:, -1]), dim=-1).squeeze(0)

                num_expansions = min(beam_size, log_probs.size(-1))
                top_log_probs, top_tokens = torch.topk(log_probs, num_expansions)
                for log_prob, token in zip(top_log_probs.tolist(), top_tokens.tolist()):
                    candidates.append((score + log_prob, token_ids + [token]))

            # Keep the top beam_size hypotheses; finished ones leave the active set.
            candidates.sort(key=normalised_score, reverse=True)
            active_beams = []
            for hypothesis in candidates[:beam_size]:
                if hypothesis[1][-1] == EOS_IDX:
                    finished_beams.append(hypothesis)
                else:
                    active_beams.append(hypothesis)

            if not active_beams or len(finished_beams) >= beam_size:
                break

    pool = finished_beams if finished_beams else active_beams
    return max(pool, key=normalised_score)[1]

In [145]:
from torchtext.data.metrics import bleu_score

def translate_greedy(model: torch.nn.Module, src):
    """Translate one source sentence with greedy decoding.

    Args:
        model: translation model; it is switched to eval mode.
        src: source token tensor whose axis 0 is the sequence.

    Returns:
        List of target-language token strings with ``<bos>`` / ``<eos>`` removed.
        The markers are filtered out token by token, so the last real token is kept
        even when decoding hit the length limit before producing ``<eos>``.
    """
    model.eval()
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

    tokens = vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt_tokens.tolist())
    return [token for token in tokens if token not in ("<bos>", "<eos>")]

def translate_beam_search(model: torch.nn.Module, src, beam_size=4, alpha=0.6):
    """Translate one source sentence with beam search.

    Args:
        model: translation model; it is switched to eval mode.
        src: source token tensor whose axis 0 is the sequence.
        beam_size: number of hypotheses kept per step.
        alpha: length-normalisation exponent passed to ``beam_search_decode``.

    Returns:
        List of target-language token strings with ``<bos>`` / ``<eos>`` removed.
        The markers are filtered out token by token, so the last real token is kept
        even when the best hypothesis does not end with ``<eos>``.
    """
    model.eval()
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = beam_search_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX, beam_size=beam_size, alpha=alpha)

    tokens = vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt_tokens)
    return [token for token in tokens if token not in ("<bos>", "<eos>")]

def evaluate_test(model, mode):
    """Compute corpus BLEU of ``model`` on the Multi30k validation split.

    Args:
        model: translation model; it is switched to eval mode.
        mode: ``'greedy'`` for greedy decoding or ``'beam'`` for beam search
            (beam size 2, alpha 0.6). Any other value exits via ``sys.exit``.

    Returns:
        BLEU score (up to 4-grams, uniform weights) as returned by ``bleu_score``.
    """
    model.eval()
    candidate_corpus = []
    reference_corpus = []

    test_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE)) # The test split in PyTorch is broken, so evaluating on valid split
    test_dataloader = DataLoader(test_iter, batch_size=1, collate_fn=collate_fn)  # Set batch_size to 1 for simplicity

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        for src, tgt in test_dataloader:
            src = src.to(DEVICE)

            # Translate the source sentence
            if mode == 'greedy': # Greedy decoder
                translated_sentence = translate_greedy(model, src)
            elif mode == 'beam': # Beam search decoder
                translated_sentence = translate_beam_search(model, src, beam_size=2, alpha=0.6) # Doesn't work great, likely due to small training set
            else:
                sys.exit('Mode not supported')

            candidate_corpus.append(translated_sentence)

            # Prepare the reference sentence: drop the <bos>/<eos> markers and keep
            # every other token unchanged, so the reference contains only real tokens.
            reference_tokens = vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt.view(-1).tolist())
            reference = [token for token in reference_tokens if token not in ("<bos>", "<eos>")]
            reference_corpus.append([reference])

    # Calculate BLEU score
    bleu = bleu_score(candidate_corpus, reference_corpus, max_n=4, weights = [0.25]*4)
    return bleu

# Run

In [None]:
# Train the model built above; train() saves the model with the lowest validation loss to model.pt.
train(transformer)



In [None]:
# NOTE: torch.load unpickles the whole model object, so only load checkpoints written by this notebook.
# map_location lets a checkpoint saved on a GPU be restored on whatever DEVICE this session has.
best_transformer = torch.load('model.pt', map_location=DEVICE)

In [None]:
# Score the reloaded best checkpoint with greedy decoding.
bleu = evaluate_test(best_transformer, 'greedy')

In [None]:
# Report BLEU as a percentage. NOTE(review): the label always says "Transformer",
# whichever architecture was selected in the training-setup cell.
print(f'Transformer BLEU: {bleu*100:.2f}')