In [1]:
# %pip install torchtext portalocker spacy

In [2]:
import csv
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchtext
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer

# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [3]:
# Load the tab-separated English/French sentence pairs (columns 0 and 1 only).
# quoting=csv.QUOTE_NONE: the fields are raw sentences that may contain literal
# double quotes; with pandas' default quoting those could be stripped or could
# swallow delimiters/newlines and merge rows.
df = pd.read_csv(
    "./.data/datasets/fra.txt",
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
)[[0, 1]]
print(df.head())
# Downstream cells iterate over (english, french) rows as a plain array.
df = df.to_numpy()

     0           1
0  Go.        Va !
1  Go.     Marche.
2  Go.  En route !
3  Go.     Bouge !
4  Hi.     Salut !


In [4]:
# !python3 -m spacy download en_core_web_sm
# !python3 -m spacy download fr_core_news_sm

In [5]:
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')

# Token frequency tables, consumed by the vocabulary cell.
en_counter, fr_counter = Counter(), Counter()

# `_df` keeps the raw (english, french) string pairs; `df` is rebound to a list
# of [english_tokens, french_tokens] pairs that the DataLoader will consume.
# NOTE: re-running this cell without re-running the loading cell would feed
# already-tokenized lists to the tokenizers.
_df, df = df, []
for english, french in _df:
    pair = [en_tokenizer(english), fr_tokenizer(french)]
    en_counter.update(pair[0])
    fr_counter.update(pair[1])
    df.append(pair)

print(len(en_counter), len(fr_counter))

17929 29072


In [6]:
# Build vocabularies with the special tokens at the front, so <unk>=0, <pad>=1,
# <eos>=2 (and <sos>=3 for French).
en_vocab = torchtext.vocab.vocab(en_counter, specials=['<unk>', '<pad>', '<eos>'], special_first=True)
fr_vocab = torchtext.vocab.vocab(fr_counter, specials=['<unk>', '<pad>', '<eos>', '<sos>'], special_first=True)

PAD_EN = en_vocab['<pad>']
PAD_FR = fr_vocab['<pad>']
EOS_EN = en_vocab['<eos>']
EOS_FR = fr_vocab['<eos>']
SOS_FR = fr_vocab['<sos>']

# Map out-of-vocabulary tokens to <unk>. This must be a method *call*:
# assigning to `set_default_index` merely shadows the method, leaving no default
# index, so looking up an unknown token raises.
en_vocab.set_default_index(en_vocab['<unk>'])
fr_vocab.set_default_index(fr_vocab['<unk>'])

# The counters are no longer needed once the vocabularies exist.
del en_counter
del fr_counter

In [7]:
def make_batch(batch):
    """Collate a list of (english_tokens, french_tokens) pairs into padded id tensors.

    Every sequence gets its language's <eos> id appended. The batch is then
    right-padded with that language's <pad> id.

    Returns:
        (en_batch, fr_batch): LongTensors of shape (batch, longest_sequence).
    """
    en_seqs = []
    fr_seqs = []
    for en_tokens, fr_tokens in batch:
        en_seqs.append(torch.tensor([en_vocab[t] for t in en_tokens] + [EOS_EN], dtype=torch.long))
        fr_seqs.append(torch.tensor([fr_vocab[t] for t in fr_tokens] + [EOS_FR], dtype=torch.long))

    # Pad once, after all sequences are collected; padding inside the loop
    # would re-pad the growing list for every pair.
    en_batch = pad_sequence(en_seqs, batch_first=True, padding_value=PAD_EN)
    fr_batch = pad_sequence(fr_seqs, batch_first=True, padding_value=PAD_FR)
    return en_batch, fr_batch


dl = DataLoader(df, batch_size=128, shuffle=True, collate_fn=make_batch)

In [8]:
# Grab one (en_batch, fr_batch) sample for the smoke tests below.
batch = next(iter(dl))
print(batch[0].shape, batch[1].shape)
print(batch[0][0])

torch.Size([128, 21]) torch.Size([128, 26])
tensor([ 278,  322,   40,  811, 4424,    9,    2,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1])


In [9]:
class Encoder(nn.Module):
    """Embed source token ids and run them through a single unidirectional LSTM.

    Args:
        input_size: source vocabulary size (number of embedding rows).
        hidden_size: embedding width, also used as the LSTM hidden size.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Attribute names are kept stable: they are the state_dict keys.
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, input):
        """Return the LSTM output at every time step; the final (h, c) state is discarded."""
        outputs, _final_state = self.lstm(self.embedding(input))
        return outputs

# Smoke test: the encoder output is (batch, src_len, hidden).
enc = Encoder(len(en_vocab), 64)
# Call the module itself (not .forward) so any registered hooks run.
o_encoder = enc(batch[0])

print(o_encoder.shape)

torch.Size([128, 21, 64])


In [10]:
class PreAttentionDecoder(nn.Module):
    """Embed the (teacher-forced) target prefix and run it through an LSTM.

    A start-of-sequence id is prepended to every row, so the outputs have one
    more time step than ``input``. The returned mask is 0 wherever the shifted
    id equals 0 and 1 elsewhere. The training loop in this notebook writes 0
    into not-yet-decoded target positions, so those positions are masked out.
    Note that 0 is the ``<unk>`` index here, not ``<pad>``.

    Args:
        target_vocab: target vocabulary size.
        hidden_size: embedding width and LSTM hidden size.
        sos_index: id of the start-of-sequence token; defaults to ``SOS_FR``.
    """

    def __init__(self, target_vocab, hidden_size, sos_index=None):
        super().__init__()
        self.sos_index = SOS_FR if sos_index is None else sos_index
        self.embedding = nn.Embedding(target_vocab, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=False)

    def forward(self, input):
        """Return ``(lstm_output, mask)`` for the SOS-shifted ``input`` of shape (batch, tgt_len)."""
        batch_size = input.shape[0]
        # Build the <sos> column on the input's own device rather than a global one,
        # so the concatenation works wherever the caller's tensors live.
        sos_column = torch.full((batch_size, 1), self.sos_index, dtype=torch.long, device=input.device)
        shifted = torch.cat([sos_column, input], dim=1)

        mask = torch.where(shifted != 0, 1, 0)
        output, _ = self.lstm(self.embedding(shifted))
        return output, mask

# Smoke test: outputs gain one time step from the prepended <sos> token.
pad = PreAttentionDecoder(len(fr_vocab), 64).to(device)
# Call the module itself (not .forward) so any registered hooks run.
o_preattn, m_preattn = pad(batch[1].to(device))

print(o_preattn.shape, m_preattn.shape)

torch.Size([128, 27, 64]) torch.Size([128, 27])


In [11]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention of decoder states over encoder states.

    The queries come from the pre-attention decoder. The keys and values are
    both the encoder output. Output rows at masked decoder positions are zeroed.
    """

    def forward(self, encoder_output, pre_attention_decoder_output, pre_attention_decoder_mask, encoder_mask=None):
        """Compute attention-weighted encoder context for each decoder step.

        Args:
            encoder_output: (batch, src_len, d) tensor used as both keys and values.
            pre_attention_decoder_output: (batch, tgt_len, d) queries.
            pre_attention_decoder_mask: (batch, tgt_len); a 0 zeroes that output row.
            encoder_mask: optional (batch, src_len); source positions equal to 0
                get no attention weight. ``None`` attends to every position.

        Returns:
            (batch, tgt_len, d) context vectors.
        """
        key = value = encoder_output
        query = pre_attention_decoder_output

        # Plain Python scalar: no per-call tensor allocation and no dependence
        # on a global device.
        scale = key.shape[-1] ** 0.5
        scores = torch.bmm(query, key.transpose(1, 2)) / scale
        if encoder_mask is not None:
            scores = scores.masked_fill(encoder_mask[:, None, :] == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)

        return torch.bmm(weights, value) * pre_attention_decoder_mask[:, :, None]
    

# Smoke test: one context vector per decoder step, (tgt_len, hidden) per example.
dpa = ScaledDotProductAttention().to(device)
# Call the module itself (not .forward) so any registered hooks run.
o_attn = dpa(o_encoder.to(device), o_preattn.to(device), m_preattn.to(device))

print(o_attn[0], o_attn[0].shape)


tensor([[ 0.0800,  0.0390, -0.2563,  ...,  0.1186,  0.1022,  0.0967],
        [ 0.0809,  0.0392, -0.2535,  ...,  0.1179,  0.1019,  0.0969],
        [ 0.0791,  0.0383, -0.2566,  ...,  0.1200,  0.1026,  0.0973],
        ...,
        [ 0.0833,  0.0399, -0.2516,  ...,  0.1151,  0.1019,  0.0940],
        [ 0.0833,  0.0399, -0.2516,  ...,  0.1151,  0.1019,  0.0940],
        [ 0.0833,  0.0399, -0.2516,  ...,  0.1151,  0.1019,  0.0940]],
       device='mps:0', grad_fn=<SelectBackward0>) torch.Size([27, 64])


In [12]:
class Decoder(nn.Module):
    """Summarise the attention output with a GRU and classify the next target token.

    The final hidden state of every GRU layer is concatenated and fed through
    an MLP ending in ``LogSoftmax``, so ``forward`` returns log-probabilities.

    Args:
        hidden_size: feature size of the incoming attention output and of the GRU.
        output_size: number of output classes; defaults to ``len(fr_vocab)``.
        num_layers: number of stacked GRU layers (default 2).
    """

    def __init__(self, hidden_size, output_size=None, num_layers=2):
        super().__init__()
        if output_size is None:
            output_size = len(fr_vocab)

        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=num_layers)

        self.seq = nn.Sequential(
            nn.BatchNorm1d(hidden_size * num_layers),
            nn.Linear(hidden_size * num_layers, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, output_size),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, input):
        """Map (batch, seq, hidden) input to (batch, output_size) log-probabilities."""
        # The GRU's final hidden state is (num_layers, batch, hidden) even with batch_first.
        _, h = self.gru(input)

        # -> (batch, num_layers * hidden): concatenate every layer's final state.
        h = h.permute(1, 0, 2).flatten(start_dim=1)

        return self.seq(h)
    
# Smoke test: one log-probability vector over the French vocabulary per example.
dec = Decoder(64).to(device)
# Call the module itself (not .forward) so any registered hooks run.
o = dec(o_attn)
print(o.shape)
print(torch.argmax(o[0]), torch.max(o[0]), o[0][0:10])


torch.Size([128, 29076])
tensor(18433, device='mps:0') tensor(-8.5180, device='mps:0', grad_fn=<MaxBackward1>) tensor([-10.2750, -10.2051, -10.5931, -10.6331, -10.4142, -10.4906, -10.1292,
         -9.9625, -10.7294, -10.8555], device='mps:0',
       grad_fn=<SliceBackward0>)


In [13]:
# Full model: given the source sentence and the target prefix, predict the next target token.
class Translator(nn.Module):
    """Encoder -> pre-attention decoder -> attention -> decoder.

    Args:
        hidden_size: shared width of every sub-module (default 64).
    """

    def __init__(self, hidden_size=64):
        super().__init__()

        self.encoder = Encoder(len(en_vocab), hidden_size)
        self.pre_attn_decoder = PreAttentionDecoder(len(fr_vocab), hidden_size)
        self.attention = ScaledDotProductAttention()
        self.decoder = Decoder(hidden_size)

    def forward(self, sentence, translated_sequence_so_far):
        """Return the decoder's output (log-probabilities) for the next target token."""
        # Invoke sub-modules directly (not .forward) so their hooks run.
        encoded = self.encoder(sentence)
        pre_attn, mask = self.pre_attn_decoder(translated_sequence_so_far)
        context = self.attention(encoded, pre_attn, mask)
        return self.decoder(context)

In [14]:
tn = Translator().to(device)
print(tn)
print("Total params: ", sum(p.numel() for p in tn.parameters() if p.requires_grad))

# Training objective: the decoder ends in LogSoftmax, hence NLLLoss.
# Padded target positions carry no signal, so they are excluded from the loss.
loss_fn = nn.NLLLoss(ignore_index=PAD_FR)
optimizer = torch.optim.Adam(tn.parameters(), lr=1e-4)

def train(model, dataloader, criterion=None, optim=None, log_every=10):
    """Run one epoch of teacher-forced training.

    For every target position ``i`` the model receives the source batch plus
    the target with positions ``>= i`` overwritten by 0. It is scored on
    predicting ``y[:, i]``. Losses are summed over positions, and one optimizer
    step is taken per batch.

    Args:
        model: called as ``model(X, y_masked)``; assumed to return
            log-probabilities of shape (batch, vocab) -- confirm for new models.
        dataloader: yields ``(X, y)`` LongTensor pairs.
        criterion: loss function; defaults to the notebook-level ``loss_fn``.
        optim: optimizer; defaults to the notebook-level ``optimizer``.
        log_every: print the position-averaged loss every this many batches.
    """
    criterion = loss_fn if criterion is None else criterion
    optim = optimizer if optim is None else optim
    # Train on whatever device the model lives on.
    model_device = next(model.parameters()).device
    size = len(dataloader)
    model.train()

    for batch_idx, (X, y) in enumerate(dataloader):
        X = X.to(model_device)
        y = y.to(model_device)
        seq_len = y.shape[1]
        positions = torch.arange(seq_len, device=model_device)

        loss = 0
        for i in range(seq_len):
            # Keep the first i target tokens; later positions become 0.
            y_masked = y.masked_fill(positions >= i, 0)
            pred = model(X, y_masked)
            loss = loss + criterion(pred, y[:, i])

        # Backpropagation
        optim.zero_grad()
        loss.backward()
        optim.step()

        if batch_idx % log_every == 0:
            # Normalise the summed loss by the number of target positions.
            print(f"loss: {loss.item() / seq_len:>7f}  [{batch_idx:>5d}/{size:>5d}]")


train(model=tn, dataloader=dl)  # one full epoch over `dl`

Translator(
  (encoder): Encoder(
    (embedding): Embedding(17932, 64)
    (lstm): LSTM(64, 64, batch_first=True)
  )
  (pre_attn_decoder): PreAttentionDecoder(
    (embedding): Embedding(29076, 64)
    (lstm): LSTM(64, 64, batch_first=True)
  )
  (attention): ScaledDotProductAttention()
  (decoder): Decoder(
    (gru): GRU(64, 64, num_layers=2, batch_first=True)
    (seq): Sequential(
      (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=29076, bias=True)
      (5): LogSoftmax(dim=-1)
    )
  )
)
Total params:  6892820
loss: 10.340668  [    0/ 1703]


In [None]:
def translate(str_en):
    """Greedily decode an English sentence into French tokens with the global model ``tn``.

    Decoding stops once the model emits ``<eos>`` or after ``2 * len(str_en)``
    steps (a budget based on the character count). The decoded tokens
    (including a final ``<eos>`` if reached) are printed and returned.

    NOTE(review): training feeds the decoder the full-length target with future
    positions set to 0, whereas this loop feeds only the prefix decoded so far.
    The two inputs differ in length -- confirm this mismatch is intended.
    """
    max_steps = 2 * len(str_en)
    en_tokens = en_tokenizer(str_en)

    # Out-of-vocabulary words fall back to <unk> instead of raising.
    unk = en_vocab['<unk>']
    en_ids = [en_vocab[t] if t in en_vocab else unk for t in en_tokens] + [EOS_EN]
    # Add the batch dimension: (1, src_len).
    en = torch.tensor(en_ids, dtype=torch.long, device=device)[None, :]

    was_training = tn.training
    # BatchNorm layers in the decoder cannot run in training mode on a batch of 1.
    tn.eval()
    ret = []
    try:
        with torch.no_grad():
            while len(ret) < max_steps:
                so_far = torch.tensor(ret, dtype=torch.long, device=device)[None, :]
                pred = torch.argmax(tn(en, so_far)).item()
                ret.append(pred)
                if pred == EOS_FR:
                    break
    finally:
        # Leave the model in the mode we found it in.
        tn.train(was_training)

    tokens = fr_vocab.lookup_tokens(ret)
    print(tokens)
    return tokens


translate("Let's go! Let's move from here!")