In [59]:
# %pip install torchtext portalocker spacy

In [60]:
import pandas as pd
import numpy as np
import torch 
from torch import nn
import matplotlib.pyplot as plt
import torchtext
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Pick the compute backend by priority: CUDA first, then Apple MPS, CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [61]:
# Tab-separated sentence pairs; only columns 0 (English) and 1 (French) are kept.
df = pd.read_csv("./.data/datasets/fra.txt", sep='\t', header=None).loc[:, [0, 1]]
print(df.head())
# Convert to a plain array so later cells can iterate over (english, french) rows.
df = np.array(df)

     0           1
0  Go.        Va !
1  Go.     Marche.
2  Go.  En route !
3  Go.     Bouge !
4  Hi.     Salut !


In [62]:
# !python3 -m spacy download en_core_web_sm
# !python3 -m spacy download fr_core_news_sm

In [63]:
# spaCy-backed tokenizers, one per language.
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')

# Token frequency tables, used later to build the vocabularies.
en_counter, fr_counter = Counter(), Counter()

# Replace the raw string pairs in `df` with lower-cased, tokenized pairs.
# NOTE: this rebinds `df`, so the cell must not be re-run without re-running the loading cell first.
_df, df = df, []

for english, french in _df:
    english_tokens = en_tokenizer(english.lower())
    french_tokens = fr_tokenizer(french.lower())

    en_counter.update(english_tokens)
    fr_counter.update(french_tokens)

    df.append([english_tokens, french_tokens])

print(len(en_counter), len(fr_counter))

15959 27135


In [64]:
# Build both vocabularies from the token counts; the special tokens are placed first,
# so they receive the lowest indices.
en_vocab = torchtext.vocab.vocab(en_counter, specials=['<unk>', '<pad>', '<eos>'], special_first=True)
fr_vocab = torchtext.vocab.vocab(fr_counter, specials=['<unk>', '<pad>', '<eos>', '<sos>'], special_first=True)

# Indices of the special tokens used for padding and sequence delimiting.
PAD_EN = en_vocab['<pad>']
PAD_FR = fr_vocab['<pad>']
EOS_EN = en_vocab['<eos>']
EOS_FR = fr_vocab['<eos>']
SOS_FR = fr_vocab['<sos>']

# `set_default_index` is a method: it has to be *called*. Assigning to it (as the
# previous code did) only overwrote the method on the instance and left the vocab
# without a default, so any out-of-vocabulary lookup raised instead of mapping to <unk>.
en_vocab.set_default_index(en_vocab['<unk>'])
fr_vocab.set_default_index(fr_vocab['<unk>'])

# The counters are no longer needed once the vocabularies exist.
del en_counter
del fr_counter

In [65]:
def make_batch(batch):
    """Collate a list of tokenized (english, french) pairs into two padded index tensors.

    Args:
        batch: iterable of ``(english_tokens, french_tokens)`` pairs, each a list of
            token strings.

    Returns:
        ``(en_batch, fr_batch)``: two LongTensors laid out batch-first. Every sequence
        is terminated with the language's ``<eos>`` index and right-padded with the
        language's ``<pad>`` index to the longest sequence in the batch.
    """
    eb = []
    fb = []
    for e, f in batch:
        eb.append(torch.LongTensor([en_vocab[k] for k in e] + [EOS_EN]))
        fb.append(torch.LongTensor([fr_vocab[k] for k in f] + [EOS_FR]))

    # Pad once, after every sequence has been collected. Padding inside the loop
    # rebuilt the whole batch tensor for each sample (quadratic work) only to throw
    # away all but the last result.
    en_batch = pad_sequence(eb, batch_first=True, padding_value=PAD_EN)
    fr_batch = pad_sequence(fb, batch_first=True, padding_value=PAD_FR)

    return en_batch, fr_batch

dl = DataLoader(df, batch_size=128, shuffle=True, collate_fn=make_batch)

In [66]:
# Grab one (english, french) batch for the smoke tests below.
# `next(iter(...))` avoids enumerating just to discard the index, and no longer
# rebinds `_`, which IPython uses for the last cell output.
batch = next(iter(dl))
print(batch[0].shape, batch[1].shape)
print(batch[0][0])

torch.Size([128, 18]) torch.Size([128, 22])
tensor([129, 258, 890, 121, 460, 278, 129,   9,   2,   1,   1,   1,   1,   1,
          1,   1,   1,   1])


In [67]:
class Encoder(nn.Module):
    """Embed source token indices and encode them with a single bidirectional LSTM layer.

    The concatenated forward/backward LSTM features (``2 * hidden_size``) are projected
    back to ``hidden_size`` by a linear layer followed by a ReLU.
    """

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: number of rows of the embedding table (source vocabulary size).
            hidden_size: embedding width, LSTM hidden width and output width.
        """
        super().__init__()

        self.embedding = nn.Embedding(input_size, hidden_size)

        # one bidirectional LSTM layer, batch-first layout
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.linear = nn.Linear(2 * hidden_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Return one ``hidden_size`` feature vector per token of ``input`` (batch, seq_len)."""
        token_vectors = self.embedding(input)
        lstm_features = self.lstm(token_vectors)[0]  # final LSTM states are not used
        projected = self.linear(lstm_features)
        return self.relu(projected)

# Smoke test: run a throwaway encoder over the sampled English batch and check the output shape.
enc = Encoder(input_size=len(en_vocab), hidden_size=64)
o_encoder = enc.forward(input=batch[0])
print(o_encoder.shape)

torch.Size([128, 18, 64])


In [68]:
class PreAttentionDecoder(nn.Module):
    """Embed the target tokens produced so far and run them through a unidirectional LSTM.

    The input is shifted right by prepending a start-of-sequence token, so the output at
    position ``t`` depends only on target tokens before ``t``. The result is used as the
    attention query.

    Two modes are supported by ``forward``:

    * whole sequence (``word_by_word_idx == -1``): process the shifted input in one call;
    * incremental (``word_by_word_idx >= 0``): process only the shifted token at that
      position, carrying the LSTM state between calls. Index 0 resets the cached state,
      so incremental calls must start at 0 and proceed in order.
    """

    def __init__(self, target_vocab, hidden_size, sos_index=None):
        """
        Args:
            target_vocab: number of rows of the embedding table (target vocabulary size).
            hidden_size: embedding width and LSTM hidden width.
            sos_index: index of the start-of-sequence token. When ``None`` (default) the
                module-level ``SOS_FR`` is looked up at call time, as before.
        """
        super().__init__()
        self.embedding = nn.Embedding(target_vocab, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=False)
        self.sos_index = sos_index

    def forward(self, input, word_by_word_idx=-1):
        """
        Args:
            input: LongTensor (batch, seq_len) of target token indices, without SOS.
            word_by_word_idx: ``-1`` for whole-sequence mode, otherwise the position in
                the shifted sequence to process.

        Returns:
            Whole-sequence mode: LSTM outputs for all ``seq_len + 1`` shifted positions.
            Incremental mode: the outputs accumulated so far (``word_by_word_idx + 1`` steps).
        """
        # getattr keeps instances created before `sos_index` existed (e.g. unpickled
        # models) working by falling back to the global.
        sos_index = getattr(self, "sos_index", None)
        if sos_index is None:
            sos_index = SOS_FR

        # shift right with the SOS token; build the column on the input's own device
        # rather than on a global one, so the module works wherever its input lives
        batch_size = input.shape[0]
        shift_right = torch.LongTensor([sos_index] * batch_size).reshape((batch_size, 1)).to(input.device)
        input = torch.cat([shift_right, input], dim=1)

        # full sentence
        if word_by_word_idx == -1:
            embedded = self.embedding(input)
            output, _ = self.lstm(embedded)
            return output

        # a single word
        embedded = self.embedding(input[:, word_by_word_idx : word_by_word_idx + 1])
        if word_by_word_idx == 0:
            # first step: start from a fresh LSTM state and reset the output cache
            self.output, self.prev_state = self.lstm(embedded)
        else:
            output, self.prev_state = self.lstm(embedded, self.prev_state)
            self.output = torch.cat([self.output, output], dim=1)

        return self.output

# Smoke test on the sampled French batch, first in whole-sentence mode ...
pad = PreAttentionDecoder(len(fr_vocab), 64).to(device)
o_preattn = pad.forward(batch[1].to(device))
print(o_preattn.shape)

# ... then incrementally: each call appends one more time step to the cached output.
o_preattn_word = None
for i in range(3):
    o_preattn_word = pad.forward(batch[1].to(device), word_by_word_idx=i)
    print(o_preattn_word.shape)


torch.Size([128, 23, 64])
torch.Size([128, 1, 64])
torch.Size([128, 2, 64])
torch.Size([128, 3, 64])


In [69]:
class ScaledDotProductAttention(nn.Module):
    """Parameter-free scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The encoder output serves as both keys and values; the pre-attention decoder
    output serves as the queries.
    """

    def __init__(self):
        super().__init__()

    def forward(self, encoder_output, pre_attention_decoder_output):
        """
        Args:
            encoder_output: tensor (batch, src_len, d_k) used as keys and values.
            pre_attention_decoder_output: tensor (batch, tgt_len, d_k) used as queries.

        Returns:
            Tensor (batch, tgt_len, d_k): for each query, the attention-weighted
            average of the encoder outputs.
        """
        key = encoder_output  # what is input
        query = pre_attention_decoder_output  # what has been translated so far
        value = encoder_output  # the value we want to translate

        # A plain Python scalar keeps the module independent of any global device
        # (and avoids allocating a tensor just to hold sqrt(d_k)).
        scale = key.shape[-1] ** 0.5

        scores = torch.bmm(query, key.transpose(1, 2)) / scale
        weights = nn.functional.softmax(scores, dim=-1)
        return torch.bmm(weights, value)
    

dpa = ScaledDotProductAttention().to(device)

# Attend over the encoder states using the whole-sentence pre-attention output ...
o_attn = dpa.forward(
    encoder_output=o_encoder.to(device),
    pre_attention_decoder_output=o_preattn.to(device),
)

# ... and using the incrementally built (word-by-word) one.
o_attn_word = dpa.forward(
    encoder_output=o_encoder.to(device),
    pre_attention_decoder_output=o_preattn_word.to(device),
)

In [70]:

# Sanity-check ScaledDotProductAttention on a tiny hand-made example:
# keys/values and queries are the same three one-hot rows.
one_hot_rows = [[[0, 0, 1], [1, 0, 0], [0, 1, 0]]]
t1 = torch.Tensor(one_hot_rows).to(device)
t2 = torch.Tensor(one_hot_rows).to(device)

dpa.forward(t1, t2)


tensor([[[0.2645, 0.2645, 0.4711],
         [0.4711, 0.2645, 0.2645],
         [0.2645, 0.4711, 0.2645]]], device='mps:0')

In [71]:
class Decoder(nn.Module):
    """Turn attention output into log-probabilities over the target vocabulary.

    A 2-layer GRU consumes the attention vectors; the GRU output at the last processed
    time step goes through a small MLP (BatchNorm -> Linear -> BatchNorm -> ReLU ->
    Linear) ending in a LogSoftmax.
    """

    def __init__(self, hidden_size, output_size=None):
        """
        Args:
            hidden_size: width of the attention vectors and of the GRU.
            output_size: number of output classes. When ``None`` (default) the size
                of the module-level ``fr_vocab`` is used, as before.
        """
        super().__init__()

        NUM_LAYERS = 2

        if output_size is None:
            output_size = len(fr_vocab)

        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=NUM_LAYERS)

        self.seq = nn.Sequential(
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, output_size),
            nn.LogSoftmax(dim=-1)
        )

    def forward (self, input, word_idx=-1):
        """
        Args:
            input: tensor (batch, steps, hidden_size) of attention vectors.
            word_idx: ``-1`` to consume all of ``input`` in one call; otherwise the
                index of the single time step to consume. Index 0 starts from a fresh
                GRU state, later indices continue from the state cached in ``self.h``.

        Returns:
            Tensor (batch, output_size) of log-probabilities for the next token.
        """
        if word_idx == -1:
            o, self.h = self.gru(input)
        else:
            # Always feed exactly one time step in incremental mode. Step 0 used to
            # feed the *whole* input, which is only equivalent when the input has a
            # single step and otherwise made later steps re-consume earlier data.
            step = input[:, word_idx : word_idx + 1, :]
            if word_idx == 0:
                o, self.h = self.gru(step)
            else:
                o, self.h = self.gru(step, self.h)

        # take only the last value
        o = o[:, -1, :]
        return self.seq(o)
    
# Smoke test: decode the whole-sentence attention output and peek at the first sample's log-probabilities.
dec = Decoder(hidden_size=64).to(device)
o = dec.forward(input=o_attn)
print(o.shape)
print(o[0].argmax(), o[0].max(), o[0][:10])


torch.Size([128, 27139])
tensor(4677, device='mps:0') tensor(-8.0331, device='mps:0', grad_fn=<MaxBackward1>) tensor([-10.0987,  -9.5480, -11.2692, -10.8985,  -9.9374,  -9.4283, -10.2868,
        -10.6778, -10.5338, -10.1518], device='mps:0',
       grad_fn=<SliceBackward0>)


In [72]:
# training now, will get input and predict the last word. boom
class Translator(nn.Module):
    """English-to-French model: encoder -> pre-attention decoder -> attention -> decoder.

    Given a source sentence and the target tokens produced so far, ``forward`` returns
    log-probabilities for the next target token.
    """

    def __init__(self, hidden_size=64):
        """
        Args:
            hidden_size: shared width of all sub-modules (defaults to the previously
                hard-coded 64).
        """
        super().__init__()

        self.encoder = Encoder(len(en_vocab), hidden_size)
        self.pre_attn_decoder = PreAttentionDecoder(len(fr_vocab), hidden_size)
        self.attention = ScaledDotProductAttention()
        self.decoder = Decoder(hidden_size)
        # encoder output cached at step 0 and reused by later incremental steps
        self.past_enc = None

    def forward(self, sentence, translated_sequence_so_far, word_by_word_idx = -1):
        """
        Args:
            sentence: LongTensor (batch, src_len) of source token indices.
            translated_sequence_so_far: LongTensor (batch, n) of target token indices
                already known/generated (without SOS).
            word_by_word_idx: ``-1`` for a stateless one-shot prediction; otherwise the
                incremental step index. Incremental calls must start at 0 and proceed
                in order, because steps > 0 reuse state cached by the previous call.

        Returns:
            Log-probabilities over the target vocabulary for the next token.
        """
        # Encode the source on a one-shot call or on the first incremental step;
        # later steps reuse the cached encoding.
        if word_by_word_idx == -1 or word_by_word_idx == 0:
            enc = self.encoder(sentence)
            self.past_enc = enc
        else:
            enc = self.past_enc

        # Sub-modules are invoked via __call__ (not .forward) so registered hooks run.
        pre_attn = self.pre_attn_decoder(translated_sequence_so_far, word_by_word_idx)
        attn = self.attention(enc, pre_attn)
        return self.decoder(attn, word_by_word_idx)

In [73]:
# Resume from a previously saved model when one exists, otherwise start from scratch.
# SECURITY NOTE: torch.load unpickles the file; only load checkpoints you created yourself.
try:
    # map_location puts the weights on the current device even if they were saved elsewhere
    tn = torch.load("translation_model.mdl", map_location=device)
except FileNotFoundError:
    tn = Translator().to(device)
except Exception as load_error:
    # A stale or incompatible checkpoint should not go unnoticed (a bare `except:` hid it,
    # and also swallowed KeyboardInterrupt); report it and fall back to a fresh model.
    print(f"Could not load translation_model.mdl ({load_error!r}); starting from a fresh model.")
    tn = Translator().to(device)

print(tn)
print("Total params: ", sum(p.numel() for p in tn.parameters() if p.requires_grad))

# training
# Ignore <pad> targets: otherwise the padded tail of every short sentence counts
# towards the loss and pushes the model to predict padding.
loss_fn = nn.NLLLoss(ignore_index=PAD_FR)
optimizer = torch.optim.Adam(tn.parameters(), lr=1e-5)

def train(model, dataloader):
    """Run one pass over ``dataloader`` using teacher forcing.

    For every target position ``i`` the model receives the ground-truth prefix
    ``y[:, :i]`` and is trained to predict ``y[:, i]``. The per-position losses are
    summed and back-propagated once per batch.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``device``; ``optimizer`` must
    have been built from ``model``'s parameters.

    Args:
        model: the translator module to train.
        dataloader: yields ``(X, y)`` batches of source / target index tensors.
    """
    size = len(dataloader)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        seq_len = y.shape[1]

        loss = 0

        for i in range(seq_len):

            translated_sentence = y[:, 0:i]
            result = y[:, i]

            # Compute prediction error
            pred = model(X, translated_sentence, i)
            loss += loss_fn(pred, result)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping -- on the model being trained, not on the global `tn`
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        if batch % 10 == 0:
            # report the loss normalized by the target length of this batch
            mean_loss = loss.item() / seq_len
            print(f"loss: {mean_loss:>7f}, seq_len: {seq_len:>4d} [{batch:>5d}/{size:>5d}]")


train(tn, dl)

Translator(
  (encoder): Encoder(
    (embedding): Embedding(15962, 64)
    (lstm): LSTM(64, 64, batch_first=True, bidirectional=True)
    (linear): Linear(in_features=128, out_features=64, bias=True)
    (relu): ReLU()
  )
  (pre_attn_decoder): PreAttentionDecoder(
    (embedding): Embedding(27139, 64)
    (lstm): LSTM(64, 64, batch_first=True)
  )
  (attention): ScaledDotProductAttention()
  (decoder): Decoder(
    (gru): GRU(64, 64, num_layers=2, batch_first=True)
    (seq): Sequential(
      (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=64, out_features=128, bias=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=27139, bias=True)
      (5): LogSoftmax(dim=-1)
    )
  )
)
Total params:  6426115
loss: 10.138462, seq_len:   18 [    0/ 1703]
loss: 10.092522, seq_len:   24 [   10/ 1703]
loss: 10.078456, seq

In [74]:
# Persist the whole `tn` module object (not just a state_dict); the training cell
# reads this same file back with torch.load, so the two must stay in sync.
torch.save(tn, "translation_model.mdl")

In [76]:
def translate(str_en):
    """Greedily translate an English sentence with the global model ``tn`` and print the result.

    The sentence is lower-cased and tokenized, then target tokens are generated one at a
    time (arg-max at each step) until ``<eos>`` is produced or the length cap is hit.
    The source tensor shape and the generated tokens are printed; nothing is returned.

    Args:
        str_en: the English sentence as a plain string.
    """
    # Generation cap. Note it is based on the number of *characters* of the raw input,
    # not on its token count.
    cnt = 2 * len(str_en)
    tokens = en_tokenizer(str_en.lower())

    tn.train(False)

    # en -- words never seen during training map to <unk> instead of crashing the lookup
    unk_en = en_vocab['<unk>']
    token_ids = [en_vocab[k] if k in en_vocab else unk_en for k in tokens]
    en = torch.LongTensor(token_ids + [EOS_EN]).to(device)
    # batch dimension
    en = en[None, :]
    print(en.shape)
    ret = []

    with torch.no_grad():
        pred = -1

        while pred != EOS_FR and cnt > 0:
            so_far = torch.LongTensor(ret)[None, :].to(device)
            # keep `pred` a plain int so the loop condition compares ints, not a device tensor
            pred = torch.argmax(tn(en, so_far)).item()
            ret.append(pred)
            cnt -= 1

    print(fr_vocab.lookup_tokens(ret))


translate("Let's go! Let's move from here!")

torch.Size([1, 11])
['être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être', 'être']


In [None]:
def min_max_params():
    """Print the minimum and maximum value of every parameter tensor of the global model ``tn``."""
    for param in tn.parameters():
        highest = param.max().item()
        lowest = param.min().item()
        print(lowest, highest)

min_max_params()

Translator(
  (encoder): Encoder(
    (embedding): Embedding(15962, 64)
    (lstm): LSTM(64, 64, batch_first=True)
  )
  (pre_attn_decoder): PreAttentionDecoder(
    (embedding): Embedding(27139, 64)
    (lstm): LSTM(64, 64, batch_first=True)
  )
  (attention): ScaledDotProductAttention()
  (decoder): Decoder(
    (gru): GRU(64, 64, num_layers=2, batch_first=True)
    (seq): Sequential(
      (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=27139, bias=True)
      (5): LogSoftmax(dim=-1)
    )
  )
)
-4.939054012298584 4.701839447021484
inf -inf
inf -inf
inf -inf
inf -inf
-4.940924167633057 4.736354827880859
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
inf -inf
