[PyTorch Tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import unicodedata
import re
import random

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
!cd ../datasets/ && { curl -O https://www.manythings.org/anki/deu-eng.zip ; cd -; }

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9376k  100 9376k    0     0   188k      0  0:00:49  0:00:49 --:--:--  318k0  173k
/home/petruschka/repos/World4AI/website/src/notebooks/sequence_modelling


In [4]:
!rm -rf ../datasets/deu_eng/
!unzip ../datasets/deu-eng.zip -d ../datasets/deu_eng

Archive:  ../datasets/deu-eng.zip
  inflating: ../datasets/deu_eng/deu.txt  
  inflating: ../datasets/deu_eng/_about.txt  


In [5]:
!ls ../datasets/deu_eng

_about.txt  deu.txt


In [6]:
!head ../datasets/deu_eng/deu.txt

Go.	Geh.	CC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8597805 (Roujin)
Hi.	Hallo!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #380701 (cburgmer)
Hi.	Grüß Gott!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #659813 (Esperantostern)
Run!	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #941078 (Fingerhut)
Run.	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #4008918 (JSakuragi) & #941078 (Fingerhut)
Wow!	Potzdonner!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122382 (Pfirsichbaeumchen)
Wow!	Donnerwetter!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122391 (Pfirsichbaeumchen)
Duck!	Kopf runter!	CC-BY 2.0 (France) Attribution: tatoeba.org #280158 (CM) & #9968521 (wolfgangth)
Fire!	Feuer!	CC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #1958697 (Tamy)
Help!	Hilfe!	CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #575889 (MUIRIEL)


In [3]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    return s.split(' ')

In [4]:
def read_pairs(max_len=30):
    print("Reading lines...")
    en_seq = []
    de_seq = []
    with open('../datasets/deu_eng/deu.txt', 'r', encoding='utf-8') as file:
        print(f"Tokenizing and removing sentences larger than {max_len}")
        for line in file:
            pairs = line.split('\t')
            if len(pairs[0]) <= max_len or len(pairs[1]) <= max_len:
                en_seq.append(tokenizer(pairs[0]))
                de_seq.append(tokenizer(pairs[1]))
        print(f"The dataset has {len(en_seq)} pairs")
        return en_seq, de_seq

In [5]:
en_seq, de_seq = read_pairs()

Reading lines...
Tokenizing and removing sentences larger than 30
The dataset has 146276 pairs


In [6]:
from sklearn.model_selection import train_test_split

In [7]:
#separate into train test split

In [8]:
# train_frac = 0.8
# val_frac = 0.1
# test_frac = 0.1
train_en, test_val_en, train_de, test_val_de = train_test_split(en_seq, de_seq, test_size=0.2)
val_en, test_en, val_de, test_de = train_test_split(test_val_en, test_val_de, test_size=0.5)

In [9]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]

In [10]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [11]:
from collections import Counter, OrderedDict

In [12]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [13]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [14]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [15]:
en_vocab(["<eos>"])

[3]

In [16]:
print(en_vocab(["what", "are", "you", "doing", "?"]))
print(de_vocab(["was", "machst", "du", "?"]))

[27, 23, 6, 124, 8]
[22, 345, 12, 7]


In [17]:
def collate(batch):
    en, de, seq_len = [], [], []
    for en_token, de_token in batch:
        # add <sos> at start and <eos> at end
        for lang in [en_token, de_token]:
            lang.append('<eos>')
            lang.insert(0, '<sos>')
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
        seq_len.append(len(en_token))

    return nn.utils.rnn.pad_sequence(en, batch_first=True), nn.utils.rnn.pad_sequence(de, batch_first=True), torch.tensor(seq_len)

In [343]:
BATCH_SIZE=128
train_dataloader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2,
                              drop_last=True,
                              collate_fn=collate)
val_dataloader = DataLoader(dataset=val_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)
test_dataloader = DataLoader(dataset=test_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)

In [344]:
class Encoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim, hidden_size=128, lstm_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True)
        
        
    def forward(self, x):
        x = self.embedding(x)
        _, (h_n, c_n) = self.lstm(x)
        return h_n, c_n

In [345]:
class Decoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim, hidden_size=128, lstm_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm_cell_list = nn.ModuleList([nn.LSTMCell(input_size=embedding_dim, hidden_size=hidden_size) for i in range(lstm_layers)])
        self.fc = nn.Linear(hidden_size, num_embeddings)
    
    def forward(self, x, h, c):
        x = self.embedding(x)
        h_n, c_n = torch.zeros_like(h, device=DEVICE), torch.zeros_like(c, device=DEVICE)
        for i, lstm_cell in enumerate(self.lstm_cell_list):
            (h_n[i], c_n[i]) = lstm_cell(x, (h[i], c[i]))
            x = h_n[i].clone()
        logits = self.fc(x)
        return logits, h_n, c_n

In [346]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.teacher_forcing_ratio = teacher_forcing_ratio
    
    def forward(self, en_sequence, de_sequence):
        batch_size, sequence_len, num_de_embeddings = de_sequence.size()[0], de_sequence.size()[1], self.decoder.embedding.num_embeddings
        
        # minus 1 due to fewer predictions as inputs, we don't predict <sos>
        outputs = torch.zeros(batch_size, sequence_len-1, num_de_embeddings, device=DEVICE)

        h_n, c_n = self.encoder(en_sequence)
        inp = de_sequence[:, 0]
        for i in range(1, sequence_len):
            logits, h_n, c_n = decoder(inp, h_n, c_n)
            outputs[:, i-1] = logits
            
            force = random.random() < self.teacher_forcing_ratio
            if force:
                inp = de_sequence[:, i]
            else:
                inp = logits.argmax(dim=1)
        
        return outputs
        

In [358]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    loss_sum = 0
    num_iterations = 0

    # no need to calculate gradients
    with torch.inference_mode():
        for en_sequence, de_sequence, _ in dataloader:
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence)
            
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_iterations+=1

    # we return the average loss and the accuracy
    return loss_sum/num_iterations


In [417]:
def train(num_epochs, train_dataloader, val_dataloader, model, optimizer, criterion, scheduler=None):
    min_loss = float("inf")
    for epoch in range(num_epochs):
        loss_sum = 0
        num_iterations = 0
        for en_sequence, de_sequence, _ in train_dataloader:
            model.train()

            optimizer.zero_grad()
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence)
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]

            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            loss_sum += loss.cpu().item()
            num_iterations += 1
        train_loss=loss_sum/num_iterations
        val_loss = track_performance(val_dataloader, model, criterion)
        if scheduler:
            scheduler.step(val_loss)
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}')
        
        if val_loss < min_loss:
            print("Saving Weights!")
            min_loss = val_loss
            torch.save({'encoder_weights': encoder.state_dict(), 'decoder_weights': decoder.state_dict()}, f='../temp/encoder_decoder.pt')

In [418]:
encoder = Encoder(num_embeddings=len(en_vocab), embedding_dim=128)
decoder = Decoder(num_embeddings=len(de_vocab), embedding_dim=128)
seq2seq = EncoderDecoder(encoder, decoder).to(DEVICE)

In [419]:
optimizer = optim.Adam(seq2seq.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='min',
                                                       patience=2,
                                                       verbose=True)

num_epochs=25

In [420]:
train(num_epochs, train_dataloader, val_dataloader, seq2seq, optimizer, criterion, scheduler)

Epoch:  1/25 | Train Loss: 4.32860 | Val Loss: 3.66370
Saving Weights!
Epoch:  2/25 | Train Loss: 3.37974 | Val Loss: 3.08622
Saving Weights!
Epoch:  3/25 | Train Loss: 2.90776 | Val Loss: 2.74637
Saving Weights!
Epoch:  4/25 | Train Loss: 2.56833 | Val Loss: 2.46569
Saving Weights!
Epoch:  5/25 | Train Loss: 2.29745 | Val Loss: 2.24828
Saving Weights!
Epoch:  6/25 | Train Loss: 2.08052 | Val Loss: 2.11049
Saving Weights!
Epoch:  7/25 | Train Loss: 1.90842 | Val Loss: 1.98390
Saving Weights!
Epoch:  8/25 | Train Loss: 1.76101 | Val Loss: 1.88213
Saving Weights!
Epoch:  9/25 | Train Loss: 1.65104 | Val Loss: 1.81226
Saving Weights!
Epoch: 10/25 | Train Loss: 1.54629 | Val Loss: 1.76212
Saving Weights!
Epoch: 11/25 | Train Loss: 1.47077 | Val Loss: 1.72599
Saving Weights!
Epoch: 12/25 | Train Loss: 1.38351 | Val Loss: 1.66042
Saving Weights!
Epoch: 13/25 | Train Loss: 1.32567 | Val Loss: 1.65936
Saving Weights!
Epoch: 14/25 | Train Loss: 1.26437 | Val Loss: 1.61570
Saving Weights!
Epoch:

In [426]:
weights = torch.load('../temp/encoder_decoder.pt')
encoder_weights = weights['encoder_weights']
decoder_weights = weights['decoder_weights']

In [428]:
encoder.load_state_dict(encoder_weights)
decoder.load_state_dict(decoder_weights)

<All keys matched successfully>

In [429]:
def translate_sentence(sentence, vocab, encoder, decoder):
    with torch.inference_mode():
        outputs = []
        
        start_token = ["<sos>"]
        end_token = ["<eos>"]
        start_idx = vocab(start_token)[0]
        end_idx = vocab(end_token)[0]
                
        h_n, c_n = encoder(sentence)
        inp = torch.tensor([start_idx], device=DEVICE)
        while True:
            logits, h_n, c_n = decoder(inp, h_n, c_n)
            inp = logits.argmax(dim=1)
            outputs.append(inp.cpu().item())
            if inp.item() == end_idx:
                break
        return outputs

In [430]:
en_sequence, de_sequence, _ = next(iter(test_dataloader))
en_sequence = en_sequence.to(DEVICE)

In [448]:
for i in range(10):
    en_sentence = en_sequence[i].unsqueeze(0)
    de_sentence = de_sequence[i].unsqueeze(0)
    translation = translate_sentence(en_sentence, en_vocab, encoder, decoder)
    print('-'*130)
    print(f'English Sentence: {en_vocab.lookup_tokens(en_sentence[0].cpu().tolist())}')
    print(f'German Translation: {de_vocab.lookup_tokens(de_sentence[0].cpu().tolist())}')
    print(f'Model Translation: {de_vocab.lookup_tokens(translation)}')
    

----------------------------------------------------------------------------------------------------------------------------------
English Sentence: ['<sos>', 'control', 'yourself', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
German Translation: ['<sos>', '<unk>', 'dich', 'am', '<unk>', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
Model Translation: ['<unk>', 'dich', '!', '<eos>']
----------------------------------------------------------------------------------------------------------------------------------
English Sentence: ['<sos>', 'my', 'son', 'is', 'playing', 'in', 'the', 'rain', '.', '<eos>', '<pad>']
German Translation: ['<sos>', 'mein', 'sohn', 'spielt', 'im', 'regen', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>']
Model Translation: ['mein', 'sohn', 'ist', 'in', 'der', '<unk>', '.', '<eos>']
----------------------------------------------------------------------------------------------------------------------------------
English Senten