In [183]:
%reload_ext autoreload
%autoreload 2

from utils import load_dataset
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [222]:
# Size argument handed to load_dataset (presumably the number of examples to generate — confirm in utils).
m = 10000
# load_dataset (from utils) returns the example pairs plus the input-side and
# output-side character -> index vocabularies.
data, human_vocab, machine_vocab = load_dataset(m)

100%|██████████| 10000/10000 [00:00<00:00, 34426.72it/s]

{'0', '2', '6', '4', '-', '1', '9', '3', '5', '7', '8'} 11
{'>': 0, '<': 1, '-': 2, '0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12}





In [223]:
# Inspect the input-side vocabulary (character -> integer index).
human_vocab

{';': 0,
 '?': 1,
 ' ': 2,
 '.': 3,
 '/': 4,
 '0': 5,
 '1': 6,
 '2': 7,
 '3': 8,
 '4': 9,
 '5': 10,
 '6': 11,
 '7': 12,
 '8': 13,
 '9': 14,
 'a': 15,
 'b': 16,
 'c': 17,
 'd': 18,
 'e': 19,
 'f': 20,
 'g': 21,
 'h': 22,
 'i': 23,
 'j': 24,
 'l': 25,
 'm': 26,
 'n': 27,
 'o': 28,
 'p': 29,
 'r': 30,
 's': 31,
 't': 32,
 'u': 33,
 'v': 34,
 'w': 35,
 'y': 36}

In [224]:
# Inspect the output-side vocabulary (character -> integer index).
machine_vocab

{'>': 0,
 '<': 1,
 '-': 2,
 '0': 3,
 '1': 4,
 '2': 5,
 '3': 6,
 '4': 7,
 '5': 8,
 '6': 9,
 '7': 10,
 '8': 11,
 '9': 12}

In [225]:
# Peek at the first 20 (input string, target string) pairs.
data[:20]

[('8/8/71', '>1971-08-08<'),
 ('3/3/81', '>1981-03-03<'),
 ('4/8/96', '>1996-04-08<'),
 ('4/9/78', '>1978-04-09<'),
 ('6/5/73', '>1973-06-05<'),
 ('5/2/89', '>1989-05-02<'),
 ('2/1/98', '>1998-02-01<'),
 ('8/5/79', '>1979-08-05<'),
 ('6/7/18', '>2018-06-07<'),
 ('6/8/89', '>1989-06-08<'),
 ('6/3/18', '>2018-06-03<'),
 ('2/7/21', '>2021-02-07<'),
 ('5/5/20', '>2020-05-05<'),
 ('8/6/88', '>1988-08-06<'),
 ('3/5/06', '>2006-03-05<'),
 ('9/2/71', '>1971-09-02<'),
 ('5/8/92', '>1992-05-08<'),
 ('1/9/99', '>1999-01-09<'),
 ('2/1/92', '>1992-02-01<'),
 ('9/4/83', '>1983-09-04<')]

In [226]:
class Lang:
    """Character-level codec backed by a fixed vocabulary.

    Attributes:
        vocab: the character -> index mapping supplied by the caller.
        inv_vocab: the reverse mapping, index -> character.
        vocab_size: number of entries in ``vocab``.
    """

    def __init__(self, vocab: dict):
        """Store ``vocab`` and derive its inverse mapping and size."""
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.inv_vocab = dict((index, char) for char, index in vocab.items())

    def str_to_ind(self, str):
        """Encode a string as a list of vocabulary indices, one per character.

        Raises:
            KeyError: if a character is missing from the vocabulary.
        """
        encoded = []
        for char in str:
            encoded.append(self.vocab[char])
        return encoded

    def ind_to_str(self, ind):
        """Decode an iterable of vocabulary indices back into a string.

        Raises:
            KeyError: if an index is missing from the inverse vocabulary.
        """
        return ''.join(map(self.inv_vocab.__getitem__, ind))

In [227]:
# Round-trip sanity check: encode the first input string and decode it back.
test = Lang(human_vocab)
date = data[0][0]  # input-side string of the first pair
print(date)
translated_date = test.str_to_ind(date)
print(translated_date)
reversed_translation = test.ind_to_str(translated_date)
print(reversed_translation)  # should print the same text as `date`

8/8/71
[13, 4, 13, 4, 12, 6]
8/8/71


In [228]:
class TranslationDataset(Dataset):
    """Dataset of pre-encoded (input indices, target indices) pairs.

    Every (input string, target string) pair in ``data`` is encoded once at
    construction time using a ``Lang`` built from the matching vocabulary.

    Attributes:
        data: the raw pairs as passed in.
        input_lang / target_lang: ``Lang`` codecs for each side.
        inputs / targets: lists of index lists, aligned by position.
    """

    def __init__(self, data, input_vocab, output_vocab):
        """Build both codecs and eagerly encode every pair in ``data``."""
        self.data = data
        self.input_lang = Lang(input_vocab)
        self.target_lang = Lang(output_vocab)

        encode_source = self.input_lang.str_to_ind
        encode_target = self.target_lang.str_to_ind
        # Two separate passes over the pairs, one per side.
        self.inputs = [encode_source(source) for source, _ in self.data]
        self.targets = [encode_target(target) for _, target in self.data]

    def __len__(self):
        """Number of encoded examples."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return the ``(input indices, target indices)`` tuple at ``index``."""
        encoded_input = self.inputs[index]
        encoded_target = self.targets[index]
        return encoded_input, encoded_target


In [229]:
# Encode every pair once: inputs with human_vocab, targets with machine_vocab.
dataset = TranslationDataset(data, human_vocab, machine_vocab)

In [230]:
# Fetch the first encoded example and decode each side to confirm it round-trips.
x,y = dataset[0]
print(x, dataset.input_lang.ind_to_str(x))
print(y, dataset.target_lang.ind_to_str(y))

[13, 4, 13, 4, 12, 6] 8/8/71
[0, 4, 12, 10, 4, 2, 3, 11, 2, 3, 11, 1] >1971-08-08<


In [231]:
def collate_batch(data):
    """Collate ``(input indices, target indices)`` samples into two int64 tensors.

    Inputs are padded to the longest sequence in the batch with
    ``pad_sequence``'s default padding value; targets are stacked as-is, so
    every target in the batch must have the same length.

    NOTE(review): the default padding value is 0, which is also the index of
    ';' in ``human_vocab`` — padded positions are indistinguishable from that
    character. Confirm this is acceptable.

    Args:
        data: sequence of samples; element 0 of each is the input index list,
            element 1 the target index list.

    Returns:
        ``(input_batch, target_batch)`` with the batch dimension first.
    """
    def as_index_tensor(indices):
        return torch.tensor(indices, dtype=torch.long)

    source_seqs = [as_index_tensor(sample[0]) for sample in data]
    target_seqs = [as_index_tensor(sample[1]) for sample in data]

    padded_sources = nn.utils.rnn.pad_sequence(source_seqs, batch_first=True)
    stacked_targets = torch.stack(target_seqs)
    return padded_sources, stacked_targets

In [232]:
# Batch the dataset through collate_batch (pads the inputs, stacks the targets).
# NOTE(review): 8 worker processes for batches of 2 looks heavy for a dataset
# that is already encoded in memory — confirm this is intended.
loader = DataLoader(dataset=dataset, collate_fn=collate_batch, batch_size = 2, num_workers = 8)

In [233]:
class EncoderGRU(nn.Module):
    """GRU encoder that consumes one-hot encoded token indices.

    Args:
        vocab_size: number of distinct input tokens; also the one-hot width
            fed to the GRU.
        hidden_size: size of the GRU hidden state.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, hidden_size, num_layers = 1):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(vocab_size, hidden_size, num_layers=num_layers, batch_first=True)

    def forward(self, input, hidden = None):
        """Encode a batch of index sequences.

        Args:
            input: integer tensor of token indices, batch dimension first.
            hidden: optional initial hidden state; a zero state is created
                when omitted.

        Returns:
            ``(output, hidden)`` exactly as produced by the GRU: the per-step
            outputs and the final hidden state.
        """
        # Identity test: `== None` on a tensor goes through Tensor.__eq__
        # instead of asking the intended "was an argument supplied?" question.
        if hidden is None:
            # Allocate directly on the input's device rather than on the CPU
            # followed by a copy.
            hidden = self.init_hidden(input.shape[0], device=input.device)
        one_hot = F.one_hot(input, num_classes=self.vocab_size).float()
        output, hidden = self.gru(one_hot, hidden)
        return output, hidden

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero float32 hidden state for ``batch_size`` sequences.

        Args:
            batch_size: number of sequences in the batch.
            device: optional device for the new tensor; ``None`` uses the
                current default device.
        """
        return torch.zeros(
            self.gru.num_layers,
            batch_size,
            self.hidden_size,
            dtype=torch.float32,
            device=device,
        )

In [234]:
# Pull the first batch from the loader to use as a fixture for the shape checks that follow.
x_batch, y_batch = next(iter(loader))
print(x_batch)
print(y_batch)

tensor([[13,  4, 13,  4, 12,  6],
        [ 8,  4,  8,  4, 13,  6]])
tensor([[ 0,  4, 12, 10,  4,  2,  3, 11,  2,  3, 11,  1],
        [ 0,  4, 12, 11,  4,  2,  3,  6,  2,  3,  6,  1]])


In [235]:
# Encoder sized to the input vocabulary, with a hidden state of width 10.
encoder = EncoderGRU(len(human_vocab), 10, num_layers=1)
encoder

EncoderGRU(
  (gru): GRU(37, 10, batch_first=True)
)

In [236]:
# Encode the sample batch; no initial hidden state is passed, so the encoder creates a zero one.
out, hn = encoder(x_batch)
print(out.shape, hn.shape)

torch.Size([2, 6, 10]) torch.Size([1, 2, 10])


In [237]:
# Confirm the one-hot width the encoder was built with.
print(encoder.vocab_size)

37


In [238]:
# Shape of the one-hot encoding the encoder builds internally. The width comes
# from the vocabulary instead of a hard-coded 37 so it cannot drift.
F.one_hot(x_batch, num_classes=len(human_vocab)).shape

torch.Size([2, 6, 37])

In [254]:
class DecoderGRU(nn.Module):
    """GRU decoder mapping one-hot encoded token indices to per-step vocabulary scores.

    Args:
        vocab_size: number of distinct output tokens; used both as the one-hot
            input width and as the width of the projected output.
        hidden_size: size of the GRU hidden state.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, hidden_size, num_layers = 1):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(vocab_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Projects every GRU output step to one score per vocabulary entry.
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input, hidden = None):
        """Decode a batch of index sequences.

        Args:
            input: integer tensor of token indices, batch dimension first.
            hidden: optional initial hidden state (e.g. the encoder's final
                state); a zero state is created when omitted or ``None``.

        Returns:
            ``(output, hidden)``: the linear projection of the GRU outputs
            (raw scores, no softmax applied) and the final hidden state.
        """
        if hidden is None:
            # Size and place the zero state from `input` itself; the earlier
            # code read a notebook-global `x` here.
            hidden = self.init_hidden(input.shape[0], device=input.device)
        one_hot = F.one_hot(input, num_classes=self.vocab_size).float()
        output, hidden = self.gru(one_hot, hidden)
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero float32 hidden state for ``batch_size`` sequences.

        Args:
            batch_size: number of sequences in the batch.
            device: optional device for the new tensor; ``None`` uses the
                current default device.
        """
        return torch.zeros(
            self.gru.num_layers,
            batch_size,
            self.hidden_size,
            dtype=torch.float32,
            device=device,
        )

In [255]:
# Decoder sized to the output vocabulary; hidden width 10 matches the encoder
# so the encoder's final state can be used as the decoder's initial state.
decoder = DecoderGRU(len(machine_vocab), 10, num_layers=1)
decoder

DecoderGRU(
  (gru): GRU(13, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=13, bias=True)
)

In [256]:
# Shape of the stacked target batch that is fed to the decoder.
print(y_batch.shape)

torch.Size([2, 12])


In [257]:
# Run the decoder over the whole target batch, starting from the encoder's final hidden state.
# NOTE(review): this rebinds `y`, which an earlier cell bound to a list of target indices.
y, h = decoder(y_batch, hn)
print(y.shape, h.shape)

torch.Size([2, 12, 13])
torch.Size([2, 12, 13]) torch.Size([1, 2, 10])


In [253]:
# Confirm the one-hot / output width the decoder was built with.
decoder.vocab_size

13