In [27]:
%reload_ext autoreload
%autoreload 2

from utils import load_dataset, SpecialTokens
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import lightning.pytorch as pl

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
train_size = 10000
val_size = 1000
train_data, human_vocab, machine_vocab = load_dataset(train_size)
val_data, _, _ = load_dataset(val_size)

100%|██████████| 10000/10000 [00:00<00:00, 31385.40it/s]
100%|██████████| 1000/1000 [00:00<00:00, 28430.56it/s]


In [29]:
human_vocab

{';': 0,
 '?': 1,
 ' ': 2,
 '.': 3,
 '/': 4,
 '0': 5,
 '1': 6,
 '2': 7,
 '3': 8,
 '4': 9,
 '5': 10,
 '6': 11,
 '7': 12,
 '8': 13,
 '9': 14,
 'a': 15,
 'b': 16,
 'c': 17,
 'd': 18,
 'e': 19,
 'f': 20,
 'g': 21,
 'h': 22,
 'i': 23,
 'j': 24,
 'l': 25,
 'm': 26,
 'n': 27,
 'o': 28,
 'p': 29,
 'r': 30,
 's': 31,
 't': 32,
 'u': 33,
 'v': 34,
 'w': 35,
 'y': 36}

In [30]:
machine_vocab

{';': 0,
 '>': 1,
 '<': 2,
 '-': 3,
 '0': 4,
 '1': 5,
 '2': 6,
 '3': 7,
 '4': 8,
 '5': 9,
 '6': 10,
 '7': 11,
 '8': 12,
 '9': 13}

In [31]:
train_data[:20]

[('3/1/74', '>1974-03-01<'),
 ('1/7/99', '>1999-01-07<'),
 ('1/3/95', '>1995-01-03<'),
 ('5/3/93', '>1993-05-03<'),
 ('6/1/16', '>2016-06-01<'),
 ('5/3/83', '>1983-05-03<'),
 ('9/8/79', '>1979-09-08<'),
 ('2/5/02', '>2002-02-05<'),
 ('5/5/05', '>2005-05-05<'),
 ('2/1/72', '>1972-02-01<'),
 ('2/7/71', '>1971-02-07<'),
 ('9/1/74', '>1974-09-01<'),
 ('5/8/94', '>1994-05-08<'),
 ('8/7/82', '>1982-08-07<'),
 ('1/4/86', '>1986-01-04<'),
 ('5/1/07', '>2007-05-01<'),
 ('3/9/93', '>1993-03-09<'),
 ('2/2/84', '>1984-02-02<'),
 ('3/9/75', '>1975-03-09<'),
 ('9/8/21', '>2021-09-08<')]

In [32]:
val_data[:20]

[('1/7/14', '>2014-01-07<'),
 ('5/7/73', '>1973-05-07<'),
 ('5/7/20', '>2020-05-07<'),
 ('5/6/83', '>1983-05-06<'),
 ('1/1/15', '>2015-01-01<'),
 ('2/3/70', '>1970-02-03<'),
 ('4/6/85', '>1985-04-06<'),
 ('3/4/73', '>1973-03-04<'),
 ('8/9/09', '>2009-08-09<'),
 ('1/19/15', '>2015-01-19<'),
 ('2/12/08', '>2008-02-12<'),
 ('2 06 70', '>1970-06-02<'),
 ('1/12/13', '>2013-01-12<'),
 ('2 10 72', '>1972-10-02<'),
 ('10/5/14', '>2014-10-05<'),
 ('6 11 10', '>2010-11-06<'),
 ('3/19/86', '>1986-03-19<'),
 ('3 01 04', '>2004-01-03<'),
 ('5/15/99', '>1999-05-15<'),
 ('1/31/06', '>2006-01-31<')]

In [33]:
class Lang:
    def _get_char(self, ind):
        if isinstance(ind, torch.Tensor):
            return self.inv_vocab[ind.item()]
        else:
            return self.inv_vocab[ind]

    def __init__(self, vocab: dict):
        self.vocab = vocab
        self.inv_vocab = {v:k for k,v in vocab.items()}
        self.vocab_size = len(vocab)

    def str_to_ind(self, str):
        return [self.vocab[c] for c in str]
    
    def ind_to_str(self, ind):
        return ''.join([self._get_char(i) for i in ind])

In [34]:
test = Lang(human_vocab)
date = train_data[0][0]
print(date)
translated_date = test.str_to_ind(date)
print(translated_date)
reversed_translation = test.ind_to_str(translated_date)
print(reversed_translation)

3/1/74
[8, 4, 6, 4, 12, 9]
3/1/74


In [35]:
class TranslationTrainingDataset(Dataset):
    def __init__(self, data, input_vocab, output_vocab):
        self.input_lang = Lang(input_vocab)
        self.target_lang = Lang(output_vocab)

        self.data = data

        self.encoder_inputs = [self.input_lang.str_to_ind(input_sent) for input_sent, _ in self.data]

        targets = [self.target_lang.str_to_ind(target_sent) for _, target_sent in self.data]
        self.decoder_inputs = [target[:-1] for target in targets]
        self.decoder_targets = [target[1:] for target in targets]

    def __getitem__(self, index):
        return self.encoder_inputs[index], self.decoder_inputs[index], self.decoder_targets[index]
    
    def __len__(self):
        return len(self.encoder_inputs)


In [36]:
train_dataset = TranslationTrainingDataset(train_data, human_vocab, machine_vocab)
val_dataset = TranslationTrainingDataset(val_data, human_vocab, machine_vocab)

In [37]:
x,y,z = train_dataset[0]
print(x, train_dataset.input_lang.ind_to_str(x))
print(y, train_dataset.target_lang.ind_to_str(y))
print(z, train_dataset.target_lang.ind_to_str(z))

[8, 4, 6, 4, 12, 9] 3/1/74
[1, 5, 13, 11, 8, 3, 4, 7, 3, 4, 5] >1974-03-01
[5, 13, 11, 8, 3, 4, 7, 3, 4, 5, 2] 1974-03-01<


In [38]:
def collate_batch(data):
    batch = []
    for i in range(len(data[0])):
        batch_data = [torch.tensor(item[i], dtype=torch.int64) for item in data]
        batch_data = nn.utils.rnn.pad_sequence(batch_data, batch_first=True)
        batch.append(batch_data)


    return tuple(batch)

In [39]:
train_loader = DataLoader(dataset=train_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8)
val_loader = DataLoader(dataset=val_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8)

In [40]:
class EncoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers=1, bidirectional=False):
        super(EncoderGRU, self).__init__()
        self.D = 2 if bidirectional else 1
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(
            vocab_size,
            hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x, hidden=None):
        if hidden == None:
            hidden = self.init_hidden(x.shape[0]).to(x.device)

        one_hot = F.one_hot(x, num_classes=self.vocab_size).float().to(x.device)

        output, hidden = self.gru(one_hot, hidden)

        return output, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(
            self.D * self.gru.num_layers,
            batch_size,
            self.hidden_size,
            dtype=torch.float32,
        )

In [41]:
x_enc_batch,  x_dec_batch, y_batch = next(iter(train_loader))
print(x_enc_batch)
print(x_dec_batch)
print(y_batch)
encoder = EncoderGRU(len(human_vocab), hidden_size=32, num_layers=1, bidirectional=False)
print(encoder)
enc_out, enc_hidden = encoder(x_enc_batch)
print(enc_out.shape, enc_hidden.shape)

tensor([[ 8,  4,  6,  4, 12,  9],
        [ 6,  4, 12,  4, 14, 14],
        [ 6,  4,  8,  4, 14, 10],
        [10,  4,  8,  4, 14,  8],
        [11,  4,  6,  4,  6, 11],
        [10,  4,  8,  4, 13,  8],
        [14,  4, 13,  4, 12, 14],
        [ 7,  4, 10,  4,  5,  7],
        [10,  4, 10,  4,  5, 10],
        [ 7,  4,  6,  4, 12,  7],
        [ 7,  4, 12,  4, 12,  6],
        [14,  4,  6,  4, 12,  9],
        [10,  4, 13,  4, 14,  9],
        [13,  4, 12,  4, 13,  7],
        [ 6,  4,  9,  4, 13, 11],
        [10,  4,  6,  4,  5, 12],
        [ 8,  4, 14,  4, 14,  8],
        [ 7,  4,  7,  4, 13,  9],
        [ 8,  4, 14,  4, 12, 10],
        [14,  4, 13,  4,  7,  6],
        [14,  4, 13,  4,  5, 14],
        [ 7,  4, 11,  4, 12, 12],
        [14,  4,  9,  4, 14, 10],
        [11,  4, 11,  4, 14,  5],
        [13,  4,  8,  4, 13, 10],
        [ 7,  4,  9,  4,  5, 10],
        [12,  4,  9,  4, 13,  7],
        [13,  4, 14,  4,  6,  6],
        [ 9,  4, 14,  4,  6,  7],
        [11,  

In [42]:
lang = train_dataset.input_lang
for row in x_enc_batch:
    print(lang.ind_to_str(row))

3/1/74
1/7/99
1/3/95
5/3/93
6/1/16
5/3/83
9/8/79
2/5/02
5/5/05
2/1/72
2/7/71
9/1/74
5/8/94
8/7/82
1/4/86
5/1/07
3/9/93
2/2/84
3/9/75
9/8/21
9/8/09
2/6/77
9/4/95
6/6/90
8/3/85
2/4/05
7/4/82
8/9/11
4/9/12
6/8/78
8/8/22
2/6/71
7/8/71
8/4/16
2/3/79
6/7/16
3/3/80
9/8/01
4/2/06
9/3/13
1/8/79
1/9/03
8/8/92
6/3/08
1/7/90
3/4/94
8/8/86
3/8/91
8/2/91
2/6/91
9/4/94
2/7/19
9/3/99
9/2/96
2/1/88
3/9/74
2/8/07
7/4/76
8/3/09
3/5/72
8/1/80
3/6/12
3/1/91
6/2/21


In [43]:
class DecoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers=1):
        super(DecoderGRU, self).__init__()
        self.hidden_size = hidden_size

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(
            vocab_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        if hidden == None:
            hidden = self.init_hidden(x.shape[0]).to(x.device)
        one_hot = F.one_hot(x, num_classes=self.vocab_size).float().to(x.device)
        output, hidden = self.gru(one_hot, hidden)
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(
            self.gru.num_layers, batch_size, self.hidden_size, dtype=torch.float32
        )

In [44]:
encoder = EncoderGRU(len(human_vocab), hidden_size=32, num_layers=1, bidirectional=False)
decoder = DecoderGRU(len(machine_vocab), hidden_size=32, num_layers=1)
print(decoder)
print(encoder)

DecoderGRU(
  (gru): GRU(14, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=14, bias=True)
)
EncoderGRU(
  (gru): GRU(37, 32, batch_first=True)
)


In [45]:
print("Decoder forward pass\n")

# Teacher forcing
print("Training with teacher forcing")
print(f"Input batch shape: {x_dec_batch.shape}")
dec_out, dec_hid = decoder(x_dec_batch, enc_hidden)
print(f"decoder output shape: {dec_out.shape}\ndecoder hn shape: {dec_hid.shape}")
# loss(dec_out, target)
print()
# Without teacher forcing
print("Training without teacher forcing")
dec_input = x_dec_batch[:,0:1]
print(f"Decoder 1st input shape: {dec_input.shape}")
dec_out, dec_hid = decoder(dec_input, enc_hidden)
print(f"decoder output shape: {dec_out.shape}\ndecoder hn shape: {dec_hid.shape}")
next_input = torch.argmax(dec_out, dim=-1)
print(f"decoder 2nd input shape: {next_input.shape}")


Decoder forward pass

Training with teacher forcing
Input batch shape: torch.Size([64, 11])
decoder output shape: torch.Size([64, 11, 14])
decoder hn shape: torch.Size([1, 64, 32])

Training without teacher forcing
Decoder 1st input shape: torch.Size([64, 1])
decoder output shape: torch.Size([64, 1, 14])
decoder hn shape: torch.Size([1, 64, 32])
decoder 2nd input shape: torch.Size([64, 1])


In [46]:
class Seq2sec(nn.Module):
    def __init__(self, encoder, decoder) -> None:
        super(Seq2sec, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_input_batch, sos_index = 1, dec_input_batch = None, teacher_forcing = False, out_length = 1):
        encoder_output, encoder_hidden = self.encoder(enc_input_batch)
        batch_size = len(enc_input_batch)

        if teacher_forcing:
            decoder_output, _ = self.decoder(dec_input_batch, encoder_hidden)
            return decoder_output
        else:
            decoder_input = (torch.zeros(batch_size, 1, dtype=torch.int64) + sos_index).to(enc_input_batch.device)
            decoder_output = torch.empty(batch_size, out_length, self.decoder.vocab_size).to(enc_input_batch.device)

            hidden = encoder_hidden

            for i in range(out_length):
                decoder_output_i, hidden = self.decoder(decoder_input, hidden)
                decoder_output[:,i:i+1,:] = decoder_output_i
                decoder_input = torch.argmax(decoder_output_i, dim=-1)

            return decoder_output


In [47]:
model = Seq2sec(encoder, decoder)
model

Seq2sec(
  (encoder): EncoderGRU(
    (gru): GRU(37, 32, batch_first=True)
  )
  (decoder): DecoderGRU(
    (gru): GRU(14, 32, batch_first=True)
    (fc): Linear(in_features=32, out_features=14, bias=True)
  )
)

In [48]:
model(x_enc_batch, out_length = 20).shape

torch.Size([64, 20, 14])

In [49]:
model(x_enc_batch, dec_input_batch = x_dec_batch, teacher_forcing = True).shape

torch.Size([64, 11, 14])

In [50]:
print("Model forward pass (input -> encoder -> decoder -> output)\n")

# Teacher forcing
print("Training without teacher forcing (out_length = 20)")
print(f"Input batch shape: {x_enc_batch.shape}")
output =  model(x_enc_batch, out_length = 20)
print(f"Output shape: {output.shape}")

print()
# Without teacher forcing
print("Training with teacher forcing")
print(f"Encoder input batch shape: {x_enc_batch.shape}")
print(f"Decoder input batch shape: {x_dec_batch.shape}")
output = model(x_enc_batch, dec_input_batch = x_dec_batch, teacher_forcing = True)
print(f"Output shape: {output.shape}")

Model forward pass (input -> encoder -> decoder -> output)

Training without teacher forcing (out_length = 20)
Input batch shape: torch.Size([64, 6])
Output shape: torch.Size([64, 20, 14])

Training with teacher forcing
Encoder input batch shape: torch.Size([64, 6])
Decoder input batch shape: torch.Size([64, 11])
Output shape: torch.Size([64, 11, 14])


In [51]:
class Trainer:
    def __init__(
        self,
        model,
        train_dataLoader,
        loss_fn,
        optimizer,
        val_dataLoader=None,
        padding_index=0,
        sos_index=1,
        teacher_forcing_ratio=0.5,
        device=device,
    ) -> None:
        self.model = model.to(device)
        self.train_dataLoader = train_dataLoader
        self.val_dataLoader = val_dataLoader
        self.loss_fn = loss_fn
        self.padding_index = padding_index
        self.sos_index = sos_index
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device
        self.optimizer = optimizer

    def loss_value(self, output, target):
        C = output.shape[-1]

        output_flat = output.view(-1, C)
        target_flat = target.view(-1)

        loss = self.loss_fn(output_flat, target_flat)

        return loss

    def train_batch(self, encoder_input, decoder_input, target):
        teacher_forcing = np.random.random() < self.teacher_forcing_ratio

        output = self.model(
            encoder_input,
            dec_input_batch=decoder_input,
            teacher_forcing=teacher_forcing,
            sos_index=self.sos_index,
            out_length=target.shape[1],
        )

        loss = self.loss_value(output, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train_one_epoch(self):
        model.train()
        epoch_loss = 0

        epoch_loss = 0
        batch_losses = []

        i = 1

        for enc_input, dec_input, target in self.train_dataLoader:
            i += 1
            enc_input = enc_input.to(self.device)
            dec_input = dec_input.to(self.device)
            target = target.to(self.device)

            batch_loss = self.train_batch(
                encoder_input=enc_input, decoder_input=dec_input, target=target
            )

            epoch_loss += batch_loss * len(enc_input)
            batch_losses.append(batch_loss)

        size = len(self.train_dataLoader.dataset)

        return epoch_loss / size, batch_losses

    def validation(self):
        size = len(self.val_dataLoader.dataset)
        model.eval()

        test_loss = 0

        with torch.inference_mode():
            for enc_input, _, target in self.val_dataLoader:
                enc_input = enc_input.to(self.device)
                target = target.to(self.device)

                out = self.model(
                    enc_input,
                    teacher_forcing=False,
                    sos_index=self.sos_index,
                    out_length=target.shape[1],
                )
                test_loss += self.loss_value(out, target).item() * len(enc_input)

            return test_loss / size

    def train(self, n_epochs, verbose = True):
        for epoch in range(n_epochs):
            epoch_loss, batch_losses = self.train_one_epoch()

            if self.val_dataLoader != None:
                val_loss = self.validation()

                if verbose:
                    print(f"Epoch {epoch + 1 :< 4}  training loss: {epoch_loss:>8f} | validation loss: {val_loss:>8f}")\
                    
            else:
                if verbose:
                    print(f"Epoch {epoch + 1 :< 10}  training loss: {epoch_loss:>8f}")

In [52]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=0)

trainer = Trainer(model, train_dataLoader=train_loader, val_dataLoader=val_loader, loss_fn= cross_entropy_loss, optimizer=optimizer)

In [72]:
trainer.train(30)

Epoch  1    training loss: 0.089085 | validation loss: 0.143550
Epoch  2    training loss: 0.086851 | validation loss: 0.141181
Epoch  3    training loss: 0.082644 | validation loss: 0.137157
Epoch  4    training loss: 0.079808 | validation loss: 0.138590
Epoch  5    training loss: 0.078739 | validation loss: 0.126426
Epoch  6    training loss: 0.074671 | validation loss: 0.129002
Epoch  7    training loss: 0.073189 | validation loss: 0.132413
Epoch  8    training loss: 0.073589 | validation loss: 0.129050
Epoch  9    training loss: 0.074706 | validation loss: 0.144906
Epoch  10   training loss: 0.075718 | validation loss: 0.140417
Epoch  11   training loss: 0.073115 | validation loss: 0.140514
Epoch  12   training loss: 0.072580 | validation loss: 0.136022
Epoch  13   training loss: 0.073757 | validation loss: 0.133527
Epoch  14   training loss: 0.070386 | validation loss: 0.135001
Epoch  15   training loss: 0.067909 | validation loss: 0.132079
Epoch  16   training loss: 0.068104 | va

In [92]:
test_loader = DataLoader(dataset=val_dataset, collate_fn=collate_batch, batch_size = 64, num_workers = 8, shuffle=True)

x, z, y = next(iter(test_loader))
x = x.to(device)
y = y.to(device)
z = z.to(device)

y_hat = model(
        x,
        teacher_forcing=False,
        sos_index=1,
        out_length=11,
    )

input_lang = train_dataset.input_lang
output_lang = train_dataset.target_lang

y_hat = y_hat.argmax(axis=-1)

print("Translation test\n")
print("Input                 |   Machine translation | Correct translation")
for i in range(10):
    print(input_lang.ind_to_str(x[i]).replace(';', ' '), output_lang.ind_to_str(y_hat[i]).strip('<>;'), " "*10, output_lang.ind_to_str(y[i]).strip(';<>'))

Translation test

Input                 |   Machine translation | Correct translation
thursday november 18 1999 1999-11-18            1999-11-18
4 mar 2019                2019-03-04            2019-03-04
april 9 2014              2014-04-29            2014-04-09
24 aug 1989               1998-08-24            1989-08-24
9 march 1982              1998-03-29            1982-03-09
sunday september 19 2021  2021-09-05            2021-09-19
28 september 2005         2002-09-28            2005-09-28
6 february 2016           2016-02-06            2016-02-06
sunday september 6 2009   2009-09-06            2009-09-06
22 jun 1989               1998-06-22            1989-06-22
