In [1]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [3]:
fn = 'drive/My Drive/dataset_text.txt'


In [4]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


class DatasetSeq2Seq(Dataset):
    def __init__(self, file_name, train_lang='en', bos: str = '~', eos: str = '#'):

        with open(file_name, 'r') as f:
            train = f.readlines()

        self.input_sequnces_vocab = {'pad': 0, bos: 1, eos: 2}
        self.output_sequnces_vocab = {'pad': 0, bos: 1, eos: 2}

        self.input_sequnces = []
        self.output_sequnces = []

        n_input = 3
        n_output = 3
        for line in train:
            split_line = line.split('\t')

            sequence = [self.input_sequnces_vocab[bos]]

            for char in split_line[0].strip():
                if self.input_sequnces_vocab.get(char) is None:
                    self.input_sequnces_vocab[char] = n_input
                    n_input += 1
                sequence.append(self.input_sequnces_vocab[char])
            sequence.append(self.input_sequnces_vocab[eos])

            target = [self.output_sequnces_vocab[bos]]
            for char in split_line[2].strip():
                if self.output_sequnces_vocab.get(char) is None:
                    self.output_sequnces_vocab[char] = n_output
                    n_output += 1
                target.append(self.output_sequnces_vocab[char])
            target.append(self.output_sequnces_vocab[eos])

            self.input_sequnces.append(sequence)
            self.output_sequnces.append(target)

        self.target_decode = [k for k in self.output_sequnces_vocab.keys()]

    def __len__(self):
        return len(self.input_sequnces)

    def __getitem__(self, index):
        return {
            'data': self.input_sequnces[index],
            'target': self.output_sequnces[index],
        }


def collate_fn(input_data):
    data = []
    targets = []

    for item in input_data:
        data.append(torch.as_tensor(item['data']))
        targets.append(torch.as_tensor(item['target']))

    data = pad_sequence(data, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    data_mask = data > 0
    targets_mask = targets > 0

    return {'data': data, 
            'target': targets, 
            'data_mask': data_mask, 
            'targets_mask': targets_mask,
            }

In [5]:
dataset = DatasetSeq2Seq(fn)

In [6]:
#padding
# seq1 = [1, 2, 3, 4]
# seq2 = [9, 7, 6, 4, 3, 7, 5]
# pad seq1 equal seq2
# seq1 = [1, 2, 3, 4, 0, 0, 0]
# concat(seq1, seq2) [[1, 2, 3, 4, 0, 0, 0],
#                     [9, 7, 6, 4, 3, 7, 5]]

In [13]:
class Encoder(nn.Module):
    def __init__(self, vocab_len, emb_dim, hidden_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        emb = self.emb(x)
        _, context = self.rnn(emb)

        return context

class Decoder(nn.Module):
    def __init__(self, vocab_len, emb_dim, hidden_dim, eos_id):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, vocab_len)
        self.do = nn.Dropout(0.1)
        self.eos = eos_id

    def forward(self, context, target_sequence):
        if self.training:
            emb = self.emb(target_sequence)
            all_hid, _ = self.rnn(emb, context)
            pred_cls = self.classifier(self.do(all_hid))

            return pred_cls
        else:
            predicts = []
            probas = []
            predicted_token = target_sequence #on this step target sequence contains bos only
            i = 0
            while predicted_token.item() != self.eos and i < 20:
                emb = self.emb(predicted_token)
                context, _ = self.rnn(emb, context) # context B x 1 x Hid
                pred = self.classifier(context)
                predicted_token = torch.argmax(pred, dim=-1) # B x 1
                predicts.append(predicted_token)
                probas.append(torch.softmax(pred, dim=-1))
                i += 1
            probas = torch.cat(probas, dim=1)
            
            return torch.cat(predicts, dim=1)

In [14]:
class DateNormalizer(nn.Module):
    def __init__(self, input_vocab_len, target_vocab_len, 
                 emb__dim, hidden_dim, eos_id):
        super().__init__()
        self.encoder = Encoder(input_vocab_len, emb_dim, hidden_dim)
        self.decoder = Decoder(target_vocab_len, emb_dim, hidden_dim, eos_id)

    def forward(self, x, target_sequence):
        context = self.encoder(x)
        pred = self.decoder(context, target_sequence)

        return pred

In [9]:
#hyper params
input_vocab_size = len(dataset.input_sequnces_vocab)
target_vocab_size = len(dataset.output_sequnces_vocab)
eos_id = dataset.output_sequnces_vocab['#']
#TODO try to use other model parameters
emb_dim = 128
hidden = 256
n_epochs = 10
batch_size = 64
cuda_device = -1
batch_size = 100
device = f'cuda:{cuda_device}' if cuda_device != -1 else 'cpu'

In [15]:
model = DateNormalizer(input_vocab_size, target_vocab_size, emb_dim, hidden, eos_id)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [16]:

for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()
        target = batch['target'].to(device)
        predicts = model(batch['data'].to(device), target[:, :-1])
        loss = loss_func(predicts.reshape(-1, target_vocab_size), 
                         target[:, 1:].reshape(-1))
        
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
    test = '14 мая 1978'
    test_encoded = torch.tensor([[dataset.input_sequnces_vocab[c] for c in test]])
    test_encoded = test_encoded.to(device)
    bos_input = torch.tensor([[dataset.output_sequnces_vocab['~']]]).to(device)
    with torch.no_grad():
        model.eval()
        test_pred = model(test_encoded, bos_input)
        model.train()
    decode = list(dataset.output_sequnces_vocab.keys())
    out_str = ''
    for i in test_pred.squeeze().cpu().detach().tolist():
        out_str += decode[i]
    print(out_str)

    torch.save(model.state_dict(), f'./seq2seq_chkpt_{epoch}.pth')

epoch: 0, step: 0, loss: 2.714406728744507
1998-02-12
#
epoch: 1, step: 0, loss: 0.8631631731987
1974-01-14
#
epoch: 2, step: 0, loss: 0.7566551566123962
1977-07-19
#
epoch: 3, step: 0, loss: 0.6143466830253601
1978-01-03
#
epoch: 4, step: 0, loss: 0.5181354284286499
1978-04-03
#
epoch: 5, step: 0, loss: 0.45870569348335266
1978-04-18
#
epoch: 6, step: 0, loss: 0.3724606931209564
1978-03-13
#
epoch: 7, step: 0, loss: 0.3008843660354614
1978-03-19
#
epoch: 8, step: 0, loss: 0.26437878608703613
1978-03-16
#
epoch: 9, step: 0, loss: 0.20521339774131775
1978-05-14
#


In [17]:
dataset.output_sequnces_vocab

{'pad': 0,
 '~': 1,
 '#': 2,
 '1': 3,
 '9': 4,
 '8': 5,
 '-': 6,
 '0': 7,
 '5': 8,
 '\n': 9,
 '2': 10,
 '7': 11,
 '4': 12,
 '6': 13,
 '3': 14}