<a href="https://colab.research.google.com/github/AchrafAsh/awesome-pytorch-notebooks/blob/main/01_machine_translation_with_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Machine Translation with Transformers

<img src="https://pytorch.org/tutorials/_images/transformer_architecture.jpg" width="400px" />

Official PyTorch tutorial: https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html

In [24]:
import io
import os
import pandas as pd
import gensim

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import Counter

from torch.nn.utils.rnn import pad_sequence  # padding of every batch
from torch.utils.data import Dataset, DataLoader

from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
from torchtext.vocab import Vocab

from typing import List

## The Data

### Dataset and Data processing

In [None]:
# Download the spaCy English and French models; the cell below requests
# 'spacy' tokenizers for the languages 'en' and 'fr'.
!python -m spacy download en
!python -m spacy download fr

In [20]:
# One spaCy-backed tokenizer per language, created in (French, English) order.
fr_tokenizer, en_tokenizer = (
    get_tokenizer('spacy', language=language_code)
    for language_code in ('fr', 'en')
)

In [19]:
# Location of the raw Multi30k task-1 files; each split is a (French, English) pair.
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.fr.gz', 'train.en.gz')
val_urls = ('val.fr.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.fr.gz', 'test_2016_flickr.en.gz')

# Download and extract every archive, keeping the first extracted path of each.
# Splits are handled in train -> val -> test order, French file before English.
train_filepaths, val_filepaths, test_filepaths = (
    [extract_archive(download_from_url(url_base + url))[0] for url in split_urls]
    for split_urls in (train_urls, val_urls, test_urls)
)

train.fr.gz: 100%|██████████| 604k/604k [00:00<00:00, 17.4MB/s]
train.en.gz: 100%|██████████| 569k/569k [00:00<00:00, 8.40MB/s]
val.fr.gz: 100%|██████████| 23.0k/23.0k [00:00<00:00, 9.95MB/s]
val.en.gz: 100%|██████████| 21.6k/21.6k [00:00<00:00, 3.58MB/s]
test_2016_flickr.fr.gz: 100%|██████████| 22.3k/22.3k [00:00<00:00, 8.29MB/s]
test_2016_flickr.en.gz: 100%|██████████| 21.1k/21.1k [00:00<00:00, 7.70MB/s]


In [25]:
def build_vocab(filepath, tokenizer):
    """Count the tokens of a UTF-8 text file and wrap the counts in a ``Vocab``.

    Args:
        filepath: path of the text file, read line by line.
        tokenizer: callable that turns one line into an iterable of tokens.

    Returns:
        A ``Vocab`` built from the token counts with the special tokens
        ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``.
    """
    with io.open(filepath, encoding="utf8") as lines:
        # Lines are tokenized as read, including their trailing newline.
        token_counts = Counter(token for line in lines for token in tokenizer(line))
    return Vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

In [26]:
# Vocabularies come from the training split only: index 0 is the French file,
# index 1 the English file.
fr_vocab, en_vocab = (
    build_vocab(train_filepaths[side], tokenizer)
    for side, tokenizer in enumerate((fr_tokenizer, en_tokenizer))
)

In [27]:
def data_process(filepaths):
    """Convert a pair of parallel text files into tensors of vocabulary indices.

    Args:
        filepaths: sequence whose first item is the French file path and whose
            second item is the English file path. Line ``i`` of one file is
            paired with line ``i`` of the other.

    Returns:
        A list of ``(fr_tensor, en_tensor)`` tuples, one per line pair, each a
        1-D ``torch.long`` tensor of token indices looked up in ``fr_vocab`` /
        ``en_vocab``.
    """
    data = []
    # Context managers close both files even if tokenization fails;
    # the handles were previously left open.
    with io.open(filepaths[0], encoding="utf8") as fr_file, \
            io.open(filepaths[1], encoding="utf8") as en_file:
        # zip stops at the shorter file, so unmatched trailing lines are dropped.
        for raw_fr, raw_en in zip(fr_file, en_file):
            fr_tensor_ = torch.tensor([fr_vocab[token] for token in fr_tokenizer(raw_fr)],
                                      dtype=torch.long)
            en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
                                      dtype=torch.long)
            data.append((fr_tensor_, en_tensor_))
    return data

In [None]:
# Numericalize the three splits in train -> val -> test order.
train_data, val_data, test_data = map(
    data_process, (train_filepaths, val_filepaths, test_filepaths)
)

### Data Loader

In [None]:
BATCH_SIZE = 128
# The source vocabulary in this notebook is `fr_vocab`; `de_vocab` was never
# defined and raised a NameError.
# NOTE(review): these indices are also applied to English sequences in the
# collate function; this assumes both vocabularies assign the same indices to
# the special tokens (both are built with the same `specials` list) - confirm.
PAD_IDX = fr_vocab['<pad>']
BOS_IDX = fr_vocab['<bos>']
EOS_IDX = fr_vocab['<eos>']

In [None]:
def generate_batch(data_batch):
    """Collate ``(fr_tensor, en_tensor)`` pairs into two padded batch tensors.

    Every sequence is wrapped as ``[BOS_IDX, *tokens, EOS_IDX]`` and each side
    is then padded with ``PAD_IDX`` by ``pad_sequence``.

    Args:
        data_batch: iterable of ``(fr_item, en_item)`` 1-D index tensors.

    Returns:
        ``(fr_batch, en_batch)``, the padded French and English batches.
    """
    def _with_markers(sequence):
        # Surround one sentence with the begin/end-of-sentence indices.
        return torch.cat([torch.tensor([BOS_IDX]), sequence, torch.tensor([EOS_IDX])], dim=0)

    wrapped_pairs = [(_with_markers(fr_item), _with_markers(en_item))
                     for fr_item, en_item in data_batch]
    fr_sequences = [fr_seq for fr_seq, _ in wrapped_pairs]
    en_sequences = [en_seq for _, en_seq in wrapped_pairs]
    return (pad_sequence(fr_sequences, padding_value=PAD_IDX),
            pad_sequence(en_sequences, padding_value=PAD_IDX))

In [None]:
# Every split uses the same loader settings: shuffled batches of BATCH_SIZE
# pairs, collated (BOS/EOS wrapping + padding) by generate_batch.
train_iter, valid_iter, test_iter = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
    for split in (train_data, val_data, test_data)
)

## The model

- Embeddings → encode one-hot-encoded words as continuous vectors that capture word semantics (a pre-trained word2vec model could be used for that)
- Stacked Transformer blocks

In [None]:
class Translator(nn.Module):
    """Sequence-to-sequence translation model built around ``nn.Transformer``.

    Source and target token indices are embedded separately, passed through the
    transformer, and projected onto the target vocabulary.

    Args:
        src_vocab_size: number of entries in the source vocabulary.
        trgt_vocab_size: number of entries in the target vocabulary.
        hidden_dim: embedding / transformer width (``d_model``). Must be
            divisible by ``nhead``. The default is 128 because the previous
            default of 124 cannot be split across 8 attention heads.
        word_vectors: optional pre-trained source embeddings. Either a 2-D
            array/tensor of shape ``(num_words, hidden_dim)`` or an object
            exposing such an array as ``.vectors``.
        nhead: number of attention heads.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nhead`` or the
            pre-trained vectors do not have width ``hidden_dim``.

    NOTE(review): no positional encoding and no padding masks are applied yet.
    """

    def __init__(self,
                 src_vocab_size: int,
                 trgt_vocab_size: int,
                 hidden_dim: int = 128,
                 word_vectors=None,
                 nhead: int = 8):

        super().__init__()
        if hidden_dim % nhead != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nhead ({nhead})")
        self.hidden_dim = hidden_dim

        if word_vectors is not None:
            self.embedding = self._pretrained_embedding(word_vectors)
        else:
            self.embedding = nn.Embedding(num_embeddings=src_vocab_size,
                                          embedding_dim=hidden_dim)

        # Target ids index a different vocabulary, so they need their own table.
        self.tgt_embedding = nn.Embedding(num_embeddings=trgt_vocab_size,
                                          embedding_dim=hidden_dim)

        self.transformer = nn.Transformer(d_model=hidden_dim, nhead=nhead) # docs: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

        # Maps decoder states to one score per target-vocabulary entry.
        self.generator = nn.Linear(hidden_dim, trgt_vocab_size)

    def _pretrained_embedding(self, word_vectors):
        """Build an embedding layer from pre-trained vectors, checking their width."""
        weights = torch.as_tensor(getattr(word_vectors, "vectors", word_vectors),
                                  dtype=torch.float)
        # The embedding width is the second dimension (rows are words).
        if weights.dim() != 2 or weights.size(1) != self.hidden_dim:
            raise ValueError(
                f"expected word vectors of shape (num_words, {self.hidden_dim}), "
                f"got {tuple(weights.shape)}")
        return nn.Embedding.from_pretrained(weights)

    def load_embeddings(self, keyed_vectors):
        """Replace the source embedding with pre-trained vectors.

        Args:
            keyed_vectors: same accepted forms as ``word_vectors`` in ``__init__``.
        """
        current_device = self.embedding.weight.device
        self.embedding = self._pretrained_embedding(keyed_vectors).to(current_device)

    def forward(self, src, tgt):
        """Return target-vocabulary logits for every target position.

        Args:
            src: source token indices of shape ``(src_len, batch)``.
            tgt: target token indices of shape ``(tgt_len, batch)``.

        Returns:
            Tensor of shape ``(tgt_len, batch, trgt_vocab_size)``.
        """
        tgt_len = tgt.size(0)
        # Causal mask: -inf above the diagonal so position i only attends to
        # positions <= i of the target sequence.
        tgt_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=tgt.device),
            diagonal=1)
        decoded = self.transformer(self.embedding(src),
                                   self.tgt_embedding(tgt),
                                   tgt_mask=tgt_mask)
        return self.generator(decoded)

## The Training

### Optional: pre-trained embeddings

In [None]:
# TODO: download a pre-trained word2vec model (a very small one, to see whether
# pre-trained embeddings yield better results).
# The command stays commented out until a URL is chosen: `!wget` without a URL
# fails, which stops the notebook from running top to bottom.
# !wget <WORD2VEC_URL>

In [None]:
# Load pre-trained word vectors
# NOTE(review): 'path/to/file' is a placeholder; point it at the downloaded
# word2vec file before running this cell.
# NOTE(review): the name `model` is rebound to the Translator instance in a
# later cell, so `weights` is the value to keep from this cell.
model = gensim.models.KeyedVectors.load_word2vec_format('path/to/file')
# Convert the loaded vectors to a float tensor - presumably to be passed to the
# translator as pre-trained embeddings; confirm against the training code.
weights = torch.FloatTensor(model.vectors)

### Utility functions

In [None]:
# Vocabulary sizes. The source language is French, so the source vocabulary is
# `fr_vocab`; `de_vocab` was never defined and raised a NameError.
INPUT_DIM = len(fr_vocab)
OUTPUT_DIM = len(en_vocab)

# Reduced hyper-parameters for quick experiments. The previously commented-out
# full-size configuration was: ENC_EMB_DIM = DEC_EMB_DIM = 256,
# ENC_HID_DIM = DEC_HID_DIM = 512, ATTN_DIM = 64, dropouts = 0.5.
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
ENC_HID_DIM = 64
DEC_HID_DIM = 64
ATTN_DIM = 8
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

In [None]:
# Translator requires both vocabulary sizes; calling it without arguments
# raised a TypeError. hidden_dim is passed explicitly because it must be
# divisible by the 8 attention heads the transformer is configured with.
model = Translator(src_vocab_size=len(fr_vocab),
                   trgt_vocab_size=len(en_vocab),
                   hidden_dim=128)

def count_parameters(model: nn.Module):
    """Return the total number of elements in the model's trainable parameters.

    Parameters with ``requires_grad`` set to ``False`` are not counted.
    """
    trainable_sizes = [param.numel()
                       for param in model.parameters()
                       if param.requires_grad]
    return sum(trainable_sizes)
# Report the trainable-parameter count of `model`, with thousands separators.
print(f"The model has {count_parameters(model):,} parameters 🚀")

In [None]:
def run(model: nn.Module,
        iterator: DataLoader,
        epochs: int,
        lr: float = 0.01,
        weight_decay: float = 0.001):
    """Train ``model`` for ``epochs`` passes over ``iterator`` with Adam.

    Args:
        model: network to optimize.
        iterator: iterable yielding ``(src, tgt)`` batches.
        epochs: number of full passes over ``iterator``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        List with the summed batch loss of each epoch, in order.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    epoch_losses = []
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        # Iterate the `iterator` argument (the loop used the undefined name
        # `dataset`) and unpack each batch directly; `zip(data)` wrapped every
        # tensor in a 1-tuple.
        for src, tgt in iterator:
            # float() keeps only the value, so the graph of each batch is freed.
            total_loss += float(train(model, src, tgt, optimizer))

        print(f"Epoch: [{epoch} / {epochs}] | Loss: {total_loss}")
        epoch_losses.append(total_loss)

    return epoch_losses


def train(model, src, tgt, optimizer):
    """Run one optimization step on a single batch and return its loss.

    The decoder is fed ``tgt[:-1]`` and trained to predict ``tgt[1:]``
    (teacher forcing). Positions holding the ``<pad>`` index of ``en_vocab``
    are ignored by the loss.

    Args:
        model: callable as ``model(src, tgt_input)``; expected to return
            logits of shape ``(tgt_len - 1, batch, vocab_size)``.
        src: source index tensor, passed to the model unchanged.
        tgt: target index tensor with the sequence dimension first.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The batch loss as a detached scalar tensor.
    """
    pad_idx = en_vocab.stoi['<pad>']

    # Use the pad-ignoring criterion; a second, unmasked CrossEntropyLoss was
    # previously used for the actual loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    model.train()
    tgt_input, tgt_expected = tgt[:-1], tgt[1:]
    output = model(src, tgt_input)
    # CrossEntropyLoss wants (N, classes) logits against (N,) class indices.
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_expected.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach so callers that accumulate the loss do not keep the graph alive.
    return loss.detach()
    

def evaluate(model, dataset):
    """Compute and print the summed loss of ``model`` over ``dataset``.

    Mirrors ``train``: the decoder is fed ``tgt[:-1]``, scored against
    ``tgt[1:]``, and positions holding the ``<pad>`` index of ``en_vocab`` are
    ignored. No gradients are tracked and no parameters are updated.

    Args:
        model: callable as ``model(src, tgt_input)``; expected to return
            logits of shape ``(tgt_len - 1, batch, vocab_size)``.
        dataset: iterable yielding ``(src, tgt)`` batches.

    Returns:
        Sum of the batch losses as a float.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=en_vocab.stoi['<pad>'])
    total_loss = 0.0

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        # Unpack each batch directly; `zip(data)` wrapped every tensor in a 1-tuple.
        for src, tgt in dataset:
            output = model(src, tgt[:-1])
            # CrossEntropyLoss wants (N, classes) logits against (N,) class indices.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt[1:].reshape(-1))
            total_loss += loss.item()

    print(f"Total Loss: {total_loss}")
    return total_loss

# Building the Transformer from scratch

- Attention Is All You Need: https://arxiv.org/abs/1706.03762

[image of the architecture]

## Self-Attention

## Transformer Block

## Encoder / Decoder

## Putting everything together