In [1]:
%reload_ext autoreload
%autoreload 2

from utils import load_dataset, SpecialTokens
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import lightning.pytorch as pl

from models import EncoderGRU, DecoderGRU, Seq2sec
from train import Trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
train_size = 10000
val_size = 1000
train_data, human_vocab, machine_vocab = load_dataset(train_size)
val_data, _, _ = load_dataset(val_size)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:00<00:00, 34490.33it/s]
100%|██████████| 1000/1000 [00:00<00:00, 34343.79it/s]


In [3]:
human_vocab

{';': 0,
 '?': 1,
 ' ': 2,
 '.': 3,
 '/': 4,
 '0': 5,
 '1': 6,
 '2': 7,
 '3': 8,
 '4': 9,
 '5': 10,
 '6': 11,
 '7': 12,
 '8': 13,
 '9': 14,
 'a': 15,
 'b': 16,
 'c': 17,
 'd': 18,
 'e': 19,
 'f': 20,
 'g': 21,
 'h': 22,
 'i': 23,
 'j': 24,
 'l': 25,
 'm': 26,
 'n': 27,
 'o': 28,
 'p': 29,
 'r': 30,
 's': 31,
 't': 32,
 'u': 33,
 'v': 34,
 'w': 35,
 'y': 36}

In [4]:
machine_vocab

{';': 0,
 '>': 1,
 '<': 2,
 '-': 3,
 '0': 4,
 '1': 5,
 '2': 6,
 '3': 7,
 '4': 8,
 '5': 9,
 '6': 10,
 '7': 11,
 '8': 12,
 '9': 13}

In [5]:
train_data[:20]

[('8/8/71', '>1971-08-08<'),
 ('3/3/81', '>1981-03-03<'),
 ('4/8/96', '>1996-04-08<'),
 ('5/5/97', '>1997-05-05<'),
 ('9/7/98', '>1998-09-07<'),
 ('4/4/99', '>1999-04-04<'),
 ('2/8/02', '>2002-02-08<'),
 ('9/1/13', '>2013-09-01<'),
 ('6/2/75', '>1975-06-02<'),
 ('1/1/10', '>2010-01-01<'),
 ('1/8/21', '>2021-01-08<'),
 ('1/6/72', '>1972-01-06<'),
 ('8/8/92', '>1992-08-08<'),
 ('4/2/05', '>2005-04-02<'),
 ('5/4/71', '>1971-05-04<'),
 ('4/8/96', '>1996-04-08<'),
 ('6/2/23', '>2023-06-02<'),
 ('4/6/86', '>1986-04-06<'),
 ('9/1/98', '>1998-09-01<'),
 ('6/6/92', '>1992-06-06<')]

In [6]:
val_data[:20]

[('8/4/21', '>2021-08-04<'),
 ('1/2/80', '>1980-01-02<'),
 ('2/1/88', '>1988-02-01<'),
 ('3/5/76', '>1976-03-05<'),
 ('6/4/90', '>1990-06-04<'),
 ('1/2/74', '>1974-01-02<'),
 ('9/27/85', '>1985-09-27<'),
 ('1 11 89', '>1989-11-01<'),
 ('8 06 88', '>1988-06-08<'),
 ('6/10/23', '>2023-06-10<'),
 ('9/30/06', '>2006-09-30<'),
 ('2/14/00', '>2000-02-14<'),
 ('1/11/12', '>2012-01-11<'),
 ('6 09 16', '>2016-09-06<'),
 ('8 05 16', '>2016-05-08<'),
 ('6/28/78', '>1978-06-28<'),
 ('1 08 84', '>1984-08-01<'),
 ('6/17/06', '>2006-06-17<'),
 ('5 07 09', '>2009-07-05<'),
 ('7/26/79', '>1979-07-26<')]

In [7]:
class Lang:
    def _get_char(self, ind):
        if isinstance(ind, torch.Tensor):
            return self.inv_vocab[ind.item()]
        else:
            return self.inv_vocab[ind]

    def __init__(self, vocab: dict):
        self.vocab = vocab
        self.inv_vocab = {v:k for k,v in vocab.items()}
        self.vocab_size = len(vocab)

    def str_to_ind(self, str):
        return [self.vocab[c] for c in str]
    
    def ind_to_str(self, ind):
        return ''.join([self._get_char(i) for i in ind])

In [8]:
test = Lang(human_vocab)
date = train_data[0][0]
print(date)
translated_date = test.str_to_ind(date)
print(translated_date)
reversed_translation = test.ind_to_str(translated_date)
print(reversed_translation)

8/8/71
[13, 4, 13, 4, 12, 6]
8/8/71


In [9]:
class TranslationTrainingDataset(Dataset):
    def __init__(self, data, input_vocab, output_vocab):
        self.input_lang = Lang(input_vocab)
        self.target_lang = Lang(output_vocab)

        self.data = data

        self.encoder_inputs = [self.input_lang.str_to_ind(input_sent) for input_sent, _ in self.data]

        targets = [self.target_lang.str_to_ind(target_sent) for _, target_sent in self.data]
        self.decoder_inputs = [target[:-1] for target in targets]
        self.decoder_targets = [target[1:] for target in targets]

    def __getitem__(self, index):
        return self.encoder_inputs[index], self.decoder_inputs[index], self.decoder_targets[index]
    
    def __len__(self):
        return len(self.encoder_inputs)


In [10]:
train_dataset = TranslationTrainingDataset(train_data, human_vocab, machine_vocab)
val_dataset = TranslationTrainingDataset(val_data, human_vocab, machine_vocab)

In [11]:
x,y,z = train_dataset[0]
print(x, train_dataset.input_lang.ind_to_str(x))
print(y, train_dataset.target_lang.ind_to_str(y))
print(z, train_dataset.target_lang.ind_to_str(z))

[13, 4, 13, 4, 12, 6] 8/8/71
[1, 5, 13, 11, 5, 3, 4, 12, 3, 4, 12] >1971-08-08
[5, 13, 11, 5, 3, 4, 12, 3, 4, 12, 2] 1971-08-08<


In [12]:
def collate_batch(data):
    batch = []
    for i in range(len(data[0])):
        batch_data = [torch.tensor(item[i], dtype=torch.int64) for item in data]
        batch_data = nn.utils.rnn.pad_sequence(batch_data, batch_first=True)
        batch.append(batch_data)


    return tuple(batch)

In [13]:
train_loader = DataLoader(dataset=train_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8)
val_loader = DataLoader(dataset=val_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8)

In [49]:
x_enc_batch,  x_dec_batch, y_batch = next(iter(train_loader))
print(x_enc_batch)
print(x_dec_batch)
print(y_batch)
encoder = EncoderGRU(len(human_vocab), hidden_size=64, num_layers=1, bidirectional=False)
print(encoder)
enc_out, enc_hidden = encoder(x_enc_batch)
print(enc_out.shape, enc_hidden.shape)

tensor([[13,  4, 13,  4, 12,  6],
        [ 8,  4,  8,  4, 13,  6],
        [ 9,  4, 13,  4, 14, 11],
        [10,  4, 10,  4, 14, 12],
        [14,  4, 12,  4, 14, 13],
        [ 9,  4,  9,  4, 14, 14],
        [ 7,  4, 13,  4,  5,  7],
        [14,  4,  6,  4,  6,  8],
        [11,  4,  7,  4, 12, 10],
        [ 6,  4,  6,  4,  6,  5],
        [ 6,  4, 13,  4,  7,  6],
        [ 6,  4, 11,  4, 12,  7],
        [13,  4, 13,  4, 14,  7],
        [ 9,  4,  7,  4,  5, 10],
        [10,  4,  9,  4, 12,  6],
        [ 9,  4, 13,  4, 14, 11],
        [11,  4,  7,  4,  7,  8],
        [ 9,  4, 11,  4, 13, 11],
        [14,  4,  6,  4, 14, 13],
        [11,  4, 11,  4, 14,  7],
        [ 8,  4,  6,  4, 13, 11],
        [ 9,  4,  6,  4,  5, 11],
        [14,  4,  9,  4, 12,  8],
        [ 7,  4, 11,  4,  7,  7],
        [ 8,  4, 10,  4,  6,  7],
        [11,  4,  6,  4,  6,  7],
        [ 6,  4,  9,  4, 14, 13],
        [14,  4,  8,  4,  6,  8],
        [ 6,  4, 12,  4, 14, 13],
        [ 7,  

In [50]:
lang = train_dataset.input_lang
for row in x_enc_batch:
    print(lang.ind_to_str(row))

8/8/71
3/3/81
4/8/96
5/5/97
9/7/98
4/4/99
2/8/02
9/1/13
6/2/75
1/1/10
1/8/21
1/6/72
8/8/92
4/2/05
5/4/71
4/8/96
6/2/23
4/6/86
9/1/98
6/6/92
3/1/86
4/1/06
9/4/73
2/6/22
3/5/12
6/1/12
1/4/98
9/3/13
1/7/98
2/8/74
2/6/93
8/4/92
2/2/72
4/3/84
8/8/89
6/1/70
2/6/77
7/4/86
9/9/20
6/9/96
2/1/22
8/8/07
3/4/93
9/9/78
6/1/99
1/4/23
3/2/71
7/7/78
5/2/02
7/2/20
8/5/91
1/7/17
6/6/81
9/2/87
3/5/78
9/8/00
3/7/17
7/4/83
7/1/09
7/9/97
7/6/93
8/9/88
4/2/07
6/1/17


In [51]:
len(machine_vocab)

14

In [52]:
encoder = EncoderGRU(len(human_vocab), hidden_size=64, num_layers=1, bidirectional=False)
decoder = DecoderGRU(len(machine_vocab), hidden_size=64, num_layers=1)
print(decoder)
print(encoder)

DecoderGRU(
  (gru): GRU(14, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=14, bias=True)
)
EncoderGRU(
  (gru): GRU(37, 64, batch_first=True)
)


In [53]:
print("Decoder forward pass\n")

# Teacher forcing
print("Training with teacher forcing")
print(f"Input batch shape: {x_dec_batch.shape}")
dec_out, dec_hid = decoder(x_dec_batch, enc_hidden)
print(f"decoder output shape: {dec_out.shape}\ndecoder hn shape: {dec_hid.shape}")
# loss(dec_out, target)
print()
# Without teacher forcing
print("Training without teacher forcing")
dec_input = x_dec_batch[:,0:1]
print(f"Decoder 1st input shape: {dec_input.shape}")
dec_out, dec_hid = decoder(dec_input, enc_hidden)
print(f"decoder output shape: {dec_out.shape}\ndecoder hn shape: {dec_hid.shape}")
next_input = torch.argmax(dec_out, dim=-1)
print(f"decoder 2nd input shape: {next_input.shape}")


Decoder forward pass

Training with teacher forcing
Input batch shape: torch.Size([64, 11])
decoder output shape: torch.Size([64, 11, 14])
decoder hn shape: torch.Size([1, 64, 64])

Training without teacher forcing
Decoder 1st input shape: torch.Size([64, 1])
decoder output shape: torch.Size([64, 1, 14])
decoder hn shape: torch.Size([1, 64, 64])
decoder 2nd input shape: torch.Size([64, 1])


In [54]:
model = Seq2sec(encoder, decoder)
model

Seq2sec(
  (encoder): EncoderGRU(
    (gru): GRU(37, 64, batch_first=True)
  )
  (decoder): DecoderGRU(
    (gru): GRU(14, 64, batch_first=True)
    (fc): Linear(in_features=64, out_features=14, bias=True)
  )
)

In [55]:
print("Model forward pass (input -> encoder -> decoder -> output)\n")

# Teacher forcing
print("Training without teacher forcing (out_length = 20)")
print(f"Input batch shape: {x_enc_batch.shape}")
output =  model(x_enc_batch, out_length = 20)
print(f"Output shape: {output.shape}")

print()
# Without teacher forcing
print("Training with teacher forcing")
print(f"Encoder input batch shape: {x_enc_batch.shape}")
print(f"Decoder input batch shape: {x_dec_batch.shape}")
output = model(x_enc_batch, dec_input_batch = x_dec_batch, teacher_forcing = True)
print(f"Output shape: {output.shape}")

Model forward pass (input -> encoder -> decoder -> output)

Training without teacher forcing (out_length = 20)
Input batch shape: torch.Size([64, 6])
Output shape: torch.Size([64, 20, 14])

Training with teacher forcing
Encoder input batch shape: torch.Size([64, 6])
Decoder input batch shape: torch.Size([64, 11])
Output shape: torch.Size([64, 11, 14])


In [57]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=0)

trainer = Trainer(model, train_dataLoader=train_loader, val_dataLoader=val_loader, loss_fn= cross_entropy_loss, optimizer=optimizer)

In [59]:
trainer.train(50)

Epoch  1    training loss: 0.428725 | validation loss: 0.480993
Epoch  2    training loss: 0.413526 | validation loss: 0.463851
Epoch  3    training loss: 0.398315 | validation loss: 0.444159
Epoch  4    training loss: 0.380441 | validation loss: 0.427556
Epoch  5    training loss: 0.362214 | validation loss: 0.413763
Epoch  6    training loss: 0.346169 | validation loss: 0.403205
Epoch  7    training loss: 0.330280 | validation loss: 0.379970
Epoch  8    training loss: 0.311869 | validation loss: 0.361615
Epoch  9    training loss: 0.296308 | validation loss: 0.351179
Epoch  10   training loss: 0.281257 | validation loss: 0.345222
Epoch  11   training loss: 0.271490 | validation loss: 0.328777
Epoch  12   training loss: 0.255192 | validation loss: 0.306013
Epoch  13   training loss: 0.241457 | validation loss: 0.293268
Epoch  14   training loss: 0.234754 | validation loss: 0.291790
Epoch  15   training loss: 0.226277 | validation loss: 0.288448
Epoch  16   training loss: 0.216527 | va

In [60]:
def print_translations(model, n = 10):
    test_loader = DataLoader(dataset=val_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8, shuffle=True)

    x, z, y = next(iter(test_loader))
    x = x.to(device)
    y = y.to(device)
    z = z.to(device)

    y_hat = model(
            x,
            teacher_forcing=False,
            sos_index=1,
            out_length=11,
        )

    input_lang = train_dataset.input_lang
    output_lang = train_dataset.target_lang

    y_hat = y_hat.argmax(axis=-1)

    print("Translation test\n")
    print("Input                 |   Machine translation | Correct translation")
    for i in range(10):
        print(input_lang.ind_to_str(x[i]).replace(';', ' '), output_lang.ind_to_str(y_hat[i]).strip('<>;'), " "*10, output_lang.ind_to_str(y[i]).strip(';<>'))

In [67]:
print_translations(model)

Translation test

Input                 |   Machine translation | Correct translation
22 december 1999            1997-12-22            1999-12-22
wednesday september 9 2020  2020-09-09            2020-09-09
monday november 29 1982     1982-11-29            1982-11-29
april 18 2013               2029-04-18            2013-04-18
may 14 2001                 2028-05-14            2001-05-14
3 nov 1995                  2006-11-03            1995-11-03
february 18 1981            2010-02-17            1981-02-18
tuesday june 13 1978        1978-06-13            1978-06-13
friday june 19 1992         2010-07-19            1992-06-19
11 april 1978               2018-04-11            1978-04-11


In [52]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)
        
        print(scores.shape)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

In [74]:
hidden_size = 32
output_size = len(machine_vocab)
batch_size = 5


att = BahdanauAttention(hidden_size)

hidden = torch.zeros(1, batch_size, hidden_size)
keys = torch.zeros(batch_size, output_size, hidden_size) # encoder outputs

query = hidden.permute(1, 0, 2)

c, w = att(query, keys)
print(c.shape, w.shape)

torch.Size([5, 1, 14])
torch.Size([5, 1, 32]) torch.Size([5, 1, 14])


In [None]:
class AttentionDecoderGRU(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super(AttentionDecoderGRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, encoder_outputs, encoder_hidden, sos_index = 1, decoder_input = None, teacher_forcing = False, out_length = 1):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(sos_index)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(out_length):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if teacher_forcing:
                # Teacher forcing: Feed the target as the next input
                decoder_input = decoder_input[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        # input_gru = torch.cat((embedded, context), dim=2)
        input_gru = context

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

In [None]:
class AttentionSeq2sec(nn.Module):
    def __init__(self, encoder: EncoderGRU, decoder: AttentionDecoderGRU) -> None:
        super(Seq2sec, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_input_batch, sos_index = 1, decoder_input = None, teacher_forcing = False, out_length = 1):
        encoder_outputs, encoder_hidden = self.encoder(enc_input_batch)
        batch_size = len(enc_input_batch)
        
        decoder_output, _, _ = decoder(encoder_outputs, encoder_hidden, sos_index, decoder_input, teacher_forcing, out_length)
        
        return decoder_output