# Chapter 18 - Transformer

This notebook covers how to train and test a Transformer model using the IWSLT2016 dataset. We prepare the dataset, set up the training loop with an appropriate loss function and optimizer, and evaluate the model's performance on unseen data. This hands-on approach provides a practical understanding of training and testing Transformer models for language translation tasks.

# Dataset

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IWSLT2016
from torchtext.vocab import build_vocab_from_iterator

# Tokenizers: one spaCy tokenizer per language (German source, English target).
de_tokenizer = get_tokenizer('spacy', language='de')
en_tokenizer = get_tokenizer('spacy', language='en')

# Load dataset
# When `split` is a single string, IWSLT2016 returns one dataset object rather
# than a (train, valid, test) tuple, so it must not be unpacked into three names.
# The pairs are materialised into a list because they are traversed several
# times below (once per vocabulary, and again when the data is tensorised);
# a one-shot iterator would be empty after the first pass.
train_iter = list(IWSLT2016(split='train', language_pair=('de', 'en')))

# Build vocab
# The special tokens are listed first, so they occupy indices 0-3 in both vocabularies.
de_vocab = build_vocab_from_iterator((de_tokenizer(de) for de, _ in train_iter),
                                     specials=["<unk>", "<pad>", "<bos>", "<eos>"])

en_vocab = build_vocab_from_iterator((en_tokenizer(en) for _, en in train_iter),
                                     specials=["<unk>", "<pad>", "<bos>", "<eos>"])

# Tokens that were never seen while building the vocabulary map to <unk>.
de_vocab.set_default_index(de_vocab["<unk>"])
en_vocab.set_default_index(en_vocab["<unk>"])

In [3]:
def data_process(raw_data_iter):
    """Turn raw (German, English) sentence pairs into pairs of index tensors.

    Args:
        raw_data_iter: iterable yielding ``(raw_de, raw_en)`` string pairs.

    Returns:
        A list of ``(de_tensor, en_tensor)`` tuples, one per input pair, where
        each tensor is a 1-D ``torch.long`` tensor of vocabulary indices.
    """

    def _encode(text, tokenizer, vocab):
        # Tokenise the sentence, then look every token up in the vocabulary.
        ids = [vocab[token] for token in tokenizer(text)]
        return torch.tensor(ids, dtype=torch.long)

    return [
        (_encode(src_text, de_tokenizer, de_vocab),
         _encode(tgt_text, en_tokenizer, en_vocab))
        for src_text, tgt_text in raw_data_iter
    ]

In [4]:
# DataLoader 

def collate_fn(batch):
    """Add <bos>/<eos> markers to every sentence and pad the batch.

    Args:
        batch: list of ``(de_item, en_item)`` pairs of 1-D index tensors.

    Returns:
        ``(de_batch, en_batch)``: the padded source and target tensors.
        ``pad_sequence`` is called without ``batch_first``, so the sequence
        dimension comes first in both results.
    """

    def _wrap(ids, vocab):
        # Surround one sentence with the begin/end-of-sentence indices.
        bos = torch.tensor([vocab["<bos>"]])
        eos = torch.tensor([vocab["<eos>"]])
        return torch.cat([bos, ids, eos], dim=0)

    wrapped = [(_wrap(src, de_vocab), _wrap(tgt, en_vocab)) for src, tgt in batch]
    src_seqs = [pair[0] for pair in wrapped]
    tgt_seqs = [pair[1] for pair in wrapped]

    # Shorter sentences are filled up with the <pad> index of their own vocabulary.
    padded_src = pad_sequence(src_seqs, padding_value=de_vocab["<pad>"])
    padded_tgt = pad_sequence(tgt_seqs, padding_value=en_vocab["<pad>"])
    return padded_src, padded_tgt

In [ ]:
# Convert the raw training pairs into index tensors once, up front.
train_data = data_process(train_iter)

# Shuffled mini-batches of 32 pairs; collate_fn adds <bos>/<eos> and pads each batch.
train_dataloader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# Train Loop

In [ ]:
import torch.optim as optim 
import torch.nn as nn 
from transformer import Transformer

# Model, Loss, and Optimizer
# NOTE(review): only the English vocabulary size is passed to the model. If
# Transformer shares one embedding table between source and target, German
# indices >= len(en_vocab) would be out of range -- confirm against transformer.py.
# NOTE(review): max_length=100 presumably bounds the sequence length the model
# accepts; confirm that no padded batch exceeds it.
model = Transformer(embed_size=512, num_layers=6, heads=8, ff_hidden_size=2048, dropout_rate=0.1, vocab_size=len(en_vocab), max_length=100)
# Padding positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for de_batch, en_batch in train_dataloader:
        optimizer.zero_grad()
        # Teacher forcing: the decoder input is the target without its last
        # position, and the loss target is the same sequence shifted by one.
        output = model(de_batch, en_batch[:-1])
        # Flatten all positions so CrossEntropyLoss sees (N, vocab) against (N,).
        loss = loss_fn(output.reshape(-1, output.size(-1)), en_batch[1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean loss per batch for this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")

# Model Eval

In [ ]:
from nltk.translate.bleu_score import corpus_bleu 

def evaluate(model, data_iter):
    """Compute and print the corpus-level BLEU score of ``model`` on ``data_iter``.

    Args:
        model: the translation model to evaluate; it is switched to eval mode.
        data_iter: iterable of ``(de_batch, en_batch)`` index tensors laid out
            sequence-first, i.e. ``(seq_len, batch)``.

    Returns:
        The BLEU score computed by ``corpus_bleu`` (it is also printed).
    """
    model.eval()
    predictions = []
    references = []

    eos_id = en_vocab["<eos>"]
    skip_ids = {en_vocab["<pad>"], en_vocab["<bos>"]}

    def _to_token_lists(id_matrix):
        # Batches are sequence-first, so transpose to get one row per sentence
        # before converting indices back to tokens.
        sentences = []
        for row in id_matrix.t().tolist():
            # Everything from the first <eos> onwards is not part of the sentence.
            if eos_id in row:
                row = row[:row.index(eos_id)]
            # <pad>/<bos> are bookkeeping tokens and must not count towards BLEU.
            kept = [idx for idx in row if idx not in skip_ids]
            sentences.append(en_vocab.lookup_tokens(kept))
        return sentences

    with torch.no_grad():
        for de_batch, en_batch in data_iter:
            # NOTE(review): the training loop calls model(src, tgt); confirm that
            # Transformer.forward also accepts a source-only call, otherwise this
            # must be replaced by an explicit decoding loop.
            output = model(de_batch)
            # NOTE(review): assumes the output is (seq_len, batch, vocab), matching
            # the sequence-first layout of the input batches -- confirm.
            output = output.argmax(dim=-1)
            predictions.extend(_to_token_lists(output))
            # corpus_bleu expects a list of reference translations per sentence.
            references.extend([[ref] for ref in _to_token_lists(en_batch)])
    bleu_score = corpus_bleu(references, predictions)
    print(f"BLEU Score: {bleu_score}")
    return bleu_score

# Prepare the test split the same way as the training data.
# When `split` is a single string, IWSLT2016 returns one dataset object rather
# than a tuple, so the result is assigned directly instead of being unpacked.
test_iter = IWSLT2016(split='test', language_pair=('de', 'en'))
test_data = data_process(test_iter)
# No shuffling: evaluation order does not matter and stays reproducible.
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [ ]:
# Score the model on the test-split batches; evaluate() prints the BLEU score.
evaluate(model, test_dataloader) 