[PyTorch Tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import unicodedata
import re
import random

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
!cd ../datasets/ && { curl -O https://www.manythings.org/anki/deu-eng.zip ; cd -; }

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9376k  100 9376k    0     0   188k      0  0:00:49  0:00:49 --:--:--  318k0  173k
/home/petruschka/repos/World4AI/website/src/notebooks/sequence_modelling


In [4]:
!rm -rf ../datasets/deu_eng/
!unzip ../datasets/deu-eng.zip -d ../datasets/deu_eng

Archive:  ../datasets/deu-eng.zip
  inflating: ../datasets/deu_eng/deu.txt  
  inflating: ../datasets/deu_eng/_about.txt  


In [5]:
!ls ../datasets/deu_eng

_about.txt  deu.txt


In [6]:
!head ../datasets/deu_eng/deu.txt

Go.	Geh.	CC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8597805 (Roujin)
Hi.	Hallo!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #380701 (cburgmer)
Hi.	Grüß Gott!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #659813 (Esperantostern)
Run!	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #941078 (Fingerhut)
Run.	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #4008918 (JSakuragi) & #941078 (Fingerhut)
Wow!	Potzdonner!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122382 (Pfirsichbaeumchen)
Wow!	Donnerwetter!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122391 (Pfirsichbaeumchen)
Duck!	Kopf runter!	CC-BY 2.0 (France) Attribution: tatoeba.org #280158 (CM) & #9968521 (wolfgangth)
Fire!	Feuer!	CC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #1958697 (Tamy)
Help!	Hilfe!	CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #575889 (MUIRIEL)


In [7]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    return s.split(' ')

In [8]:
def read_pairs(max_len=30):
    print("Reading lines...")
    en_seq = []
    de_seq = []
    with open('../datasets/deu_eng/deu.txt', 'r', encoding='utf-8') as file:
        print(f"Tokenizing and removing sentences larger than {max_len}")
        for line in file:
            pairs = line.split('\t')
            if len(pairs[0]) <= max_len or len(pairs[1]) <= max_len:
                en_seq.append(tokenizer(pairs[0]))
                de_seq.append(tokenizer(pairs[1]))
        print(f"The dataset has {len(en_seq)} pairs")
        return en_seq, de_seq

In [9]:
en_seq, de_seq = read_pairs()

Reading lines...
Tokenizing and removing sentences larger than 30
The dataset has 146276 pairs


In [10]:
from sklearn.model_selection import train_test_split

In [11]:
#separate into train test split

In [12]:
# train_frac = 0.8
# val_frac = 0.1
# test_frac = 0.1
train_en, test_val_en, train_de, test_val_de = train_test_split(en_seq, de_seq, test_size=0.2)
val_en, test_en, val_de, test_de = train_test_split(test_val_en, test_val_de, test_size=0.5)

In [13]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]

In [14]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

In [15]:
from collections import Counter, OrderedDict

In [16]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [17]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [18]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)

In [19]:
en_vocab(["<eos>"])

[3]

In [20]:
print(en_vocab(["what", "are", "you", "doing", "?"]))
print(de_vocab(["was", "machst", "du", "?"]))

[27, 23, 6, 125, 8]
[23, 342, 12, 7]


In [21]:
def collate(batch):
    en, de, seq_len = [], [], []
    for en_token, de_token in batch:
        # add <sos> at start and <eos> at end
        for lang in [en_token, de_token]:
            lang.append('<eos>')
            lang.insert(0, '<sos>')
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
        seq_len.append(len(en_token))

    return nn.utils.rnn.pad_sequence(en, batch_first=True), nn.utils.rnn.pad_sequence(de, batch_first=True), torch.tensor(seq_len)

In [22]:
BATCH_SIZE=32
train_dataloader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2,
                              drop_last=True,
                              collate_fn=collate)
val_dataloader = DataLoader(dataset=val_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)
test_dataloader = DataLoader(dataset=test_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)

In [23]:
class Encoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim, hidden_size=128, lstm_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True)
        
        
    def forward(self, x):
        x = self.embedding(x)
        _, (h_n, c_n) = self.lstm(x)
        return h_n, c_n

In [24]:
class Decoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim, hidden_size=128, lstm_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm_cell_list = nn.ModuleList([nn.LSTMCell(input_size=embedding_dim, hidden_size=hidden_size) for i in range(lstm_layers)])
        self.fc = nn.Linear(hidden_size, num_embeddings)
    
    def forward(self, x, h, c):
        x = self.embedding(x)
        h_n, c_n = torch.zeros_like(h, device=DEVICE), torch.zeros_like(c, device=DEVICE)
        for i, lstm_cell in enumerate(self.lstm_cell_list):
            (h_n[i], c_n[i]) = lstm_cell(x, (h[i], c[i]))
            x = h_n[i]
        logits = self.fc(x)
        return logits, h_n, c_n

In [66]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.teacher_forcing_ratio = teacher_forcing_ratio
    
    def forward(self, en_sequence, de_sequence):
        batch_size, sequence_len, num_de_embeddings = de_sequence.size()[0], de_sequence.size()[1], self.decoder.embedding.num_embeddings
        
        # minus 1 due to fewer predictions as inputs, we don't predict <sos>
        outputs = torch.zeros(batch_size, sequence_len-1, num_de_embeddings, device=DEVICE)

        h_n, c_n = self.encoder(en_sequence)
        inp = de_batch[:, 0]
        for i in range(1, sequence_len):
            logits, h_n, c_n = decoder(de_batch[:, i-1], h_n, c_n)
            outputs[:, i-1] = logits
            
            force = random.random() < self.teacher_forcing_ratio
            if force:
                inp = de_batch[:, i]
            else:
                inp = logits.argmax(dim=1)
        
        return outputs

In [67]:
encoder = Encoder(num_embeddings=len(en_vocab), embedding_dim=128)
decoder = Decoder(num_embeddings=len(de_vocab), embedding_dim=128)
seq2seq = EncoderDecoder(encoder, decoder).to(DEVICE)

In [68]:
# # for debugging purposes
# en_batch, de_batch, seq_len = next(iter(train_dataloader))
# en_batch = en_batch.to(DEVICE)
# de_batch = de_batch.to(DEVICE)
# seq2seq(en_batch, de_batch)

In [69]:
optimizer = optim.Adam(seq2seq.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [95]:
def train(num_epochs, dataloader, model, optimizer, criterion):
    model.train()
    for epoch, (en_sequence, de_sequence, _) in enumerate(dataloader):
        optimizer.zero_grad()
        en_sequence = en_batch.to(DEVICE)
        de_sequence = de_batch.to(DEVICE)
        
        logits = model(en_sequence, de_sequence)
        print(logits.shape)
        # we don't actually predict the <sos> token
        labels = de_sequence[:, 1:]
        
        loss = criterion(logits, labels)
        
        break

In [96]:
train(10, train_dataloader, seq2seq, optimizer, criterion)

torch.Size([32, 10, 5804])
tensor([[  13,    8,   56,   58,  339,    4,    3,    0,    0,    0],
        [ 460,  290,   24,   13,  416,    4,    3,    0,    0,    0],
        [   6,  629,  109,   38,    5,   20, 4909,    8,    4,    3],
        [ 189,  110,    6,   10,   78,   15,   68,    4,    3,    0],
        [  18,    8,   20,  228, 2417,    4,    3,    0,    0,    0],
        [   6,   21,   31, 1828,   56,   44,    4,    3,    0,    0],
        [ 137,    8,    9, 4894,    7,    3,    0,    0,    0,    0],
        [  23,  429,  225,   38,   43,    1,    7,    3,    0,    0],
        [  12,  181,   22, 2493,   32,   11,  532,    4,    3,    0],
        [1126,    5,   40,    7,    3,    0,    0,    0,    0,    0],
        [  76,    5,   88,   48,    7,    3,    0,    0,    0,    0],
        [   6,   28, 1306,  127, 2571,    4,    3,    0,    0,    0],
        [   5,   14,  172,   10, 1469,    4,    3,    0,    0,    0],
        [   5,   14,   13,    4,    3,    0,    0,    0,    0, 

RuntimeError: Expected target size [32, 5804], got [32, 10]