In [None]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [None]:
# Mount Google Drive (google.colab is only available inside Colab) so the
# dataset file referenced below becomes readable under /content/drive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [None]:
# Absolute path under the Drive mount point, so reading the dataset does not
# depend on the kernel's current working directory.
fn = '/content/drive/My Drive/dataset_text.txt'


In [None]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


class DatasetSeq2Seq(Dataset):
    """Character-level seq2seq dataset read from a tab-separated text file.

    Each non-blank line is split on tabs. Column 0 is the source text and
    column 2 is the target text (column 1 is not used here). Every character
    gets an integer id in a per-side vocabulary that is built on the fly, and
    each sequence is stored as ``[bos_id, char ids..., eos_id]``.

    Ids 0, 1 and 2 are reserved for ``'pad'``, ``bos`` and ``eos`` in both
    vocabularies; ``collate_fn`` pads with 0.

    Args:
        file_name: path to the UTF-8 text file.
        train_lang: unused; kept for interface compatibility.
        bos: begin-of-sequence token.
        eos: end-of-sequence token.

    Raises:
        ValueError: if a non-blank line has fewer than three tab-separated columns.
    """

    def __init__(self, file_name, train_lang='en', bos: str = '~', eos: str = '#'):
        # Init vocabs with special tokens; id 0 is also the padding value.
        self.input_sequnces_vocab = {'pad': 0, bos: 1, eos: 2}
        self.output_sequnces_vocab = {'pad': 0, bos: 1, eos: 2}

        self.input_sequnces = []
        self.output_sequnces = []

        # Explicit UTF-8: the inputs are non-ASCII (the example dates in this
        # notebook are Russian), so do not rely on the platform default.
        with open(file_name, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # blank line, e.g. a trailing newline at EOF
                split_line = line.split('\t')
                if len(split_line) < 3:
                    raise ValueError(
                        f'{file_name}:{line_no}: expected at least 3 '
                        f'tab-separated columns, got {len(split_line)}'
                    )
                self.input_sequnces.append(
                    self._encode(split_line[0].strip(), self.input_sequnces_vocab, bos, eos))
                self.output_sequnces.append(
                    self._encode(split_line[2].strip(), self.output_sequnces_vocab, bos, eos))

        # Position i holds the token whose id is i (ids are assigned in insertion order).
        self.target_decode = list(self.output_sequnces_vocab.keys())

    @staticmethod
    def _encode(text, vocab, bos, eos):
        """Map ``text`` to ``[bos_id, char ids..., eos_id]``, adding unseen chars to ``vocab``."""
        ids = [vocab[bos]]
        for char in text:
            if char not in vocab:
                # Next free id: one past the largest id handed out so far.
                vocab[char] = max(vocab.values()) + 1
            ids.append(vocab[char])
        ids.append(vocab[eos])
        return ids

    def __len__(self):
        return len(self.input_sequnces)

    def __getitem__(self, index):
        return {
            'data': self.input_sequnces[index],
            'target': self.output_sequnces[index],
        }


def collate_fn(input_data):
    """Pad a list of dataset items into one batch.

    Args:
        input_data: list of dicts with ``'data'`` and ``'target'`` id lists.

    Returns:
        dict with ``'data'`` / ``'target'`` tensors of shape (batch, max_len),
        right-padded with 0, plus boolean ``'data_mask'`` / ``'targets_mask'``
        that are True where the id is > 0 (non-padding positions).
    """
    def pad(key):
        # Right-pad every sequence stored under `key` to the longest in the batch.
        return pad_sequence([torch.as_tensor(item[key]) for item in input_data],
                            batch_first=True, padding_value=0)

    data, targets = pad('data'), pad('target')
    return {
        'data': data,
        'target': targets,
        'data_mask': data > 0,
        'targets_mask': targets > 0,
    }

In [None]:
# Build vocabularies and id sequences from the mounted dataset file.
dataset = DatasetSeq2Seq(file_name=fn)

In [None]:
# Example raw input: '17 ноября 2022 г.' (Russian for "17 November 2022"); target: '2022-11-17'.

class Encoder(nn.Module):
    """GRU encoder that compresses a batch of token ids into one context vector per row.

    Args:
        vocab_len: number of input tokens (embedding rows).
        emb_dim: embedding size.
        hidden_dim: GRU hidden size.
        pad_id: padding id (``collate_fn`` pads with 0). Trailing positions equal
            to ``pad_id`` are excluded from the recurrence.
    """

    def __init__(self, vocab_len, emb_dim, hidden_dim, pad_id=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.pad_id = pad_id

    def forward(self, x):
        """Return the final GRU hidden state, shape ``(1, batch, hidden_dim)`` for 2-D ``x``.

        ``x`` is ``(batch, seq_len)`` with padding only at the end of each row (as
        ``pad_sequence`` produces). Rows are packed so each state is taken at that
        row's last real token rather than after trailing pad steps. This keeps padded
        training batches consistent with unpadded single-sample inference.
        """
        emb = self.emb(x)
        if x.dim() != 2:
            # Unbatched input: nothing to pack, behave like a plain nn.GRU call.
            _, context = self.rnn(emb)
            return context
        # Real tokens per row; clamp so an all-pad row still contributes one step.
        lengths = (x != self.pad_id).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False)
        _, context = self.rnn(packed)
        return context
# The encoder returns a single context vector per input: the final GRU hidden state.


class Decoder(nn.Module):
    """GRU decoder initialised with the encoder context.

    Args:
        vocab_len: target vocabulary size.
        emb_dim: embedding size.
        hidden_dim: GRU hidden size (must match the encoder's).
        eos_id: token id that ends greedy decoding.
        n_iter: maximum number of decoding steps at inference.
    """

    def __init__(self, vocab_len, emb_dim, hidden_dim, eos_id, n_iter = 20):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, vocab_len)
        self.do = nn.Dropout(0.1)
        self.eos_id = eos_id
        self.n_iter = n_iter

    def forward(self, context, target_sequence):
        """Decode from ``context`` (the GRU initial hidden state).

        Training mode (teacher forcing): ``target_sequence`` is the target shifted
        right and starting with bos (e.g. ``~2022-11-17``). Returns logits of
        shape ``(batch, seq_len, vocab_len)``.

        Eval mode (greedy): ``target_sequence`` holds only bos, shape ``(batch, 1)``.
        Returns predicted ids ``(batch, steps)``. Decoding stops once every row has
        produced ``eos_id`` or after ``n_iter`` steps; rows that finished early keep
        emitting tokens after their eos, so callers should cut at the first eos.
        """
        if self.training:
            emb = self.emb(target_sequence)
            dec_out, _ = self.rnn(emb, context)
            return self.clf(self.do(dec_out))

        predicts = []
        predicted_token = target_sequence  # (batch, 1) bos
        finished = torch.zeros(predicted_token.size(0), dtype=torch.bool,
                               device=predicted_token.device)
        for _ in range(self.n_iter):
            emb = self.emb(predicted_token)
            # Carry the hidden state (h_n) forward, not the step output.
            dec_out, context = self.rnn(emb, context)
            pred = self.clf(self.do(dec_out))
            predicted_token = torch.argmax(pred, dim=-1)  # (batch, 1)
            predicts.append(predicted_token)
            finished |= predicted_token[:, -1] == self.eos_id
            if bool(finished.all()):
                break

        return torch.cat(predicts, dim=-1)

# Example decoder output: '2022-11-17#' ('#' is the eos token).

In [None]:
class DateNormalizer(nn.Module):
    """Encoder-decoder GRU model mapping a tokenized date string to its normalized form.

    Args:
        input_vocab_len: size of the source character vocabulary.
        target_vocab_len: size of the target character vocabulary.
        emb__dim: embedding size used by encoder and decoder. The unusual
            double-underscore name is kept so existing keyword callers still work.
        hidden_dim: GRU hidden size used by encoder and decoder.
        eos_id: target id that stops greedy decoding at inference.
    """

    def __init__(self, input_vocab_len, target_vocab_len,
                 emb__dim, hidden_dim, eos_id):
        super().__init__()
        # Use the constructor argument, not a module-level `emb_dim`, so the
        # embedding size actually follows what the caller passed.
        self.encoder = Encoder(input_vocab_len, emb__dim, hidden_dim)
        self.decoder = Decoder(target_vocab_len, emb__dim, hidden_dim, eos_id)

    def forward(self, x, target_sequence):
        """Encode ``x`` and decode.

        In training mode this returns per-step logits (teacher forcing on
        ``target_sequence``); in eval mode it returns greedily decoded token ids.
        """
        context = self.encoder(x)
        return self.decoder(context, target_sequence)

In [None]:
# Hyper-parameters. Vocabulary sizes and the eos id come from the dataset.
input_vocab_size = len(dataset.input_sequnces_vocab)
target_vocab_size = len(dataset.output_sequnces_vocab)
eos_id = dataset.output_sequnces_vocab['#']
# TODO: try other model parameters.
emb_dim = 128
hidden = 256
n_epochs = 20
batch_size = 100
seed = 42
cuda_device = 0  # set to -1 to force CPU
# Fall back to CPU when CUDA is unavailable instead of failing later at .to(device).
device = f'cuda:{cuda_device}' if cuda_device != -1 and torch.cuda.is_available() else 'cpu'
# Seed torch's global RNG so weight init, dropout and shuffling are reproducible.
torch.manual_seed(seed)

In [None]:
# Plain (no-attention) model, optimizer and loss.
model = DateNormalizer(input_vocab_size, target_vocab_size, emb_dim, hidden, eos_id).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
# Padded target positions (id of 'pad', i.e. collate_fn's padding value) must not
# contribute to the loss.
loss_func = nn.CrossEntropyLoss(ignore_index=dataset.output_sequnces_vocab['pad'])

In [None]:

# Train the plain seq2seq model. After every epoch, greedily decode one probe
# sentence to eyeball progress, then save a checkpoint.
bos_token, eos_token = '~', '#'
input_vocab = dataset.input_sequnces_vocab
output_vocab = dataset.output_sequnces_vocab
decode = list(output_vocab.keys())  # id -> token

test = '17 ноября 2016 г.'
# Wrap the probe in bos/eos exactly like DatasetSeq2Seq wraps training inputs.
test_ids = [input_vocab[bos_token]] + [input_vocab[c] for c in test] + [input_vocab[eos_token]]
test_encoded = torch.tensor([test_ids], device=device)
bos_input = torch.tensor([[output_vocab[bos_token]]], device=device)

# With shuffle=True the DataLoader draws a fresh permutation on every pass,
# so a single instance can be reused across epochs.
dataloader = DataLoader(dataset,
                        batch_size,
                        shuffle=True,
                        collate_fn=collate_fn,
                        drop_last=True,
                        )
for epoch in range(n_epochs):
    for step, batch in enumerate(dataloader):
        optim.zero_grad()
        target = batch['target'].to(device)
        # Teacher forcing: feed target[:-1] (starts with bos), predict target[1:].
        predicts = model(batch['data'].to(device), target[:, :-1])
        loss = loss_func(predicts.reshape(-1, target_vocab_size),
                         target[:, 1:].reshape(-1))

        loss.backward()
        optim.step()
        if step % 100 == 0:
            print(f'epoch: {epoch}, step: {step}, loss: {loss.item()}')

    with torch.no_grad():
        model.eval()
        test_pred = model(test_encoded, bos_input)
        model.train()
    # Index row 0 instead of squeeze(): a single predicted token would otherwise
    # collapse to a 0-d tensor whose tolist() is not iterable.
    print(''.join(decode[token_id] for token_id in test_pred[0].tolist()))

    torch.save(model.state_dict(), f'./seq2seq_chkpt_{epoch}.pth')

epoch: 0, step: 0, loss: 2.6435515880584717
1986-06-09#
epoch: 1, step: 0, loss: 0.9323967695236206
2016-06-06#
epoch: 2, step: 0, loss: 0.8316783308982849
2017-01-19#
epoch: 3, step: 0, loss: 0.6996486186981201
2016-07-12#
epoch: 4, step: 0, loss: 0.6092305183410645
2016-09-21#
epoch: 5, step: 0, loss: 0.5716256499290466
2016-06-14#
epoch: 6, step: 0, loss: 0.5558145642280579
2016-07-13#
epoch: 7, step: 0, loss: 0.5374111533164978
2016-11-21#
epoch: 8, step: 0, loss: 0.4318406581878662
2016-11-21#
epoch: 9, step: 0, loss: 0.3435434103012085
2016-11-17#
epoch: 10, step: 0, loss: 0.2922822833061218
2016-12-11#
epoch: 11, step: 0, loss: 0.29315754771232605
2016-12-16#
epoch: 12, step: 0, loss: 0.15326042473316193
2016-11-21#
epoch: 13, step: 0, loss: 0.11320129781961441
2016-11-17#
epoch: 14, step: 0, loss: 0.061911772936582565
2016-11-21#
epoch: 15, step: 0, loss: 0.039976563304662704
2016-11-21#
epoch: 16, step: 0, loss: 0.024007968604564667
2016-11-21#
epoch: 17, step: 0, loss: 0.0254

In [None]:
# Normalize one free-form date with the most recently trained model.
test = '21 11 2019'
input_vocab = dataset.input_sequnces_vocab
# Wrap in bos/eos the same way DatasetSeq2Seq wraps training inputs.
test_ids = [input_vocab['~']] + [input_vocab[c] for c in test] + [input_vocab['#']]
test_encoded = torch.tensor([test_ids], device=device)
bos_input = torch.tensor([[dataset.output_sequnces_vocab['~']]], device=device)
with torch.no_grad():
    model.eval()
    test_pred = model(test_encoded, bos_input)
    model.train()
decode = list(dataset.output_sequnces_vocab.keys())
# Row 0 rather than squeeze(), which breaks when only one token is predicted.
out_str = ''.join(decode[token_id] for token_id in test_pred[0].tolist())
print(out_str)

2018-11-27#


In [None]:
# Number of (input, target) examples loaded from the file.
n_examples = len(dataset)
n_examples

5000

In [None]:
import torch
import torch.nn as nn
import numpy as np


class ScaledDotProductAttention(nn.Module):
    """Dot-product attention: softmax(q @ k^T / temperature) @ v.

    Args:
        temperature: divisor applied to the raw scores (default 1, i.e. unscaled).
        attn_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, temperature=1, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` to keys ``k`` / values ``v``.

        Shapes: q is (B, T1, D), k is (B, T2, D), v is (B, T2, Dv). ``mask``, if
        given, is boolean, must broadcast to (B, T1, T2), and is True where
        attention is allowed.

        Returns:
            (output, weights): output is (B, T1, Dv), weights is (B, T1, T2).
            Blocked positions get weight 0, so a fully blocked row yields zeros
            instead of NaN.
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature  # (B, T1, T2)

        if mask is None:
            weights = self.softmax(scores)
        else:
            blocked = ~mask
            # -inf before softmax removes blocked keys; zeroing afterwards also
            # replaces the NaNs produced by an all-blocked row.
            weights = self.softmax(scores.masked_fill(blocked, -np.inf))
            weights = weights.masked_fill(blocked, 0.)

        weights = self.dropout(weights)
        return torch.bmm(weights, v), weights  # (B, T1, T2) @ (B, T2, Dv)

In [None]:

class EncoderAttn(nn.Module):
    """GRU encoder that returns both the final hidden state and every step's output.

    Args:
        vocab_len: number of input tokens (embedding rows).
        emb_dim: embedding size.
        hidden_dim: GRU hidden size.
        pad_id: padding id (``collate_fn`` pads with 0). Trailing positions equal
            to ``pad_id`` are excluded from the recurrence.
    """

    def __init__(self, vocab_len, emb_dim, hidden_dim, pad_id=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.pad_id = pad_id

    def forward(self, x):
        """Return ``(context, full_context)``.

        For 2-D ``x`` of shape ``(batch, seq_len)``, with padding only at the end of
        each row:

        - ``context`` is the final hidden state ``(1, batch, hidden_dim)``, taken at
          each row's last real token.
        - ``full_context`` is ``(batch, seq_len, hidden_dim)``, with zeros at padded
          positions.

        Packing keeps padded batches consistent with unpadded inference.
        """
        emb = self.emb(x)
        if x.dim() != 2:
            # Unbatched input: nothing to pack, behave like a plain nn.GRU call.
            full_context, context = self.rnn(emb)
            return context, full_context
        # Real tokens per row; clamp so an all-pad row still contributes one step.
        lengths = (x != self.pad_id).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False)
        packed_out, context = self.rnn(packed)
        # total_length keeps the time axis equal to x's, so attention shapes match.
        full_context, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1))
        return context, full_context
# EncoderAttn returns the final hidden state plus the per-step outputs the decoder attends over.


class DecoderAttn(nn.Module):
    """GRU decoder with dot-product attention over the encoder outputs.

    At each step the GRU output is the attention query over ``enc_context``. The
    attended vector is concatenated with the GRU output (hence ``2 * hidden_dim``
    classifier inputs) and projected to ``vocab_len`` logits.

    Args:
        vocab_len: target vocabulary size.
        emb_dim: embedding size.
        hidden_dim: GRU hidden size (must match the encoder's).
        eos_id: token id that ends greedy decoding.
        n_iter: maximum number of decoding steps at inference.
    """

    def __init__(self, vocab_len, emb_dim, hidden_dim, eos_id, n_iter = 20):
        super().__init__()
        self.emb = nn.Embedding(vocab_len, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(2 * hidden_dim, vocab_len)
        self.do = nn.Dropout(0.1)
        self.eos_id = eos_id
        self.n_iter = n_iter
        # NOTE(review): built with its default temperature; dividing scores by
        # sqrt(hidden_dim) would give the usual "scaled" dot product.
        self.attn = ScaledDotProductAttention()

    def forward(self, context, target_sequence, enc_context, enc_mask=None):
        """Decode from ``context`` (GRU initial hidden state), attending over ``enc_context``.

        Args:
            context: encoder final hidden state, used as the GRU's h_0.
            target_sequence: in training mode, the target shifted right and starting
                with bos (teacher forcing); in eval mode, bos only, shape (batch, 1).
            enc_context: encoder outputs, (batch, T_enc, hidden_dim).
            enc_mask: optional boolean mask that is True at real encoder positions.
                A (batch, T_enc) mask (e.g. collate_fn's ``data_mask``) is broadcast
                over decoder steps; a 3-D mask is used as-is.

        Returns:
            Training: logits (batch, seq_len, vocab_len).
            Eval: greedy ids (batch, steps). Decoding stops when every row has
            produced ``eos_id`` or after ``n_iter`` steps; rows that finished early
            keep emitting tokens after their eos.
        """
        attn_mask = None
        if enc_mask is not None:
            attn_mask = enc_mask.bool()
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(1)  # (B, T_enc) -> (B, 1, T_enc)

        if self.training:
            emb = self.emb(target_sequence)
            dec_out, _ = self.rnn(emb, context)
            attn, _ = self.attn(dec_out, enc_context, enc_context, attn_mask)
            return self.clf(self.do(torch.cat((dec_out, attn), dim=-1)))

        predicts = []
        predicted_token = target_sequence  # (batch, 1) bos
        finished = torch.zeros(predicted_token.size(0), dtype=torch.bool,
                               device=predicted_token.device)
        for _ in range(self.n_iter):
            emb = self.emb(predicted_token)
            # Carry the hidden state (h_n) forward, not the step output.
            dec_out, context = self.rnn(emb, context)
            attn, _ = self.attn(dec_out, enc_context, enc_context, attn_mask)
            pred = self.clf(self.do(torch.cat((dec_out, attn), dim=-1)))
            predicted_token = torch.argmax(pred, dim=-1)  # (batch, 1)
            predicts.append(predicted_token)
            finished |= predicted_token[:, -1] == self.eos_id
            if bool(finished.all()):
                break

        return torch.cat(predicts, dim=-1)


In [None]:
class DateNormalizerAttn(nn.Module):
    """Encoder-decoder date normalizer whose decoder attends over all encoder outputs.

    Args:
        input_vocab_len: size of the source character vocabulary.
        target_vocab_len: size of the target character vocabulary.
        emb__dim: embedding size used by encoder and decoder. The unusual
            double-underscore name is kept so existing keyword callers still work.
        hidden_dim: GRU hidden size used by encoder and decoder.
        eos_id: target id that stops greedy decoding at inference.
    """

    def __init__(self, input_vocab_len, target_vocab_len,
                 emb__dim, hidden_dim, eos_id):
        super().__init__()
        # Use the constructor argument, not a module-level `emb_dim`, so the
        # embedding size actually follows what the caller passed.
        self.encoder = EncoderAttn(input_vocab_len, emb__dim, hidden_dim)
        self.decoder = DecoderAttn(target_vocab_len, emb__dim, hidden_dim, eos_id)

    def forward(self, x, target_sequence, mask=None):
        """Encode ``x`` and decode with attention over every encoder step.

        NOTE(review): ``mask`` is accepted but currently not used, so attention
        also covers padded encoder positions.
        """
        context, all_hid = self.encoder(x)
        return self.decoder(context, target_sequence, all_hid)

In [None]:
# Attention model, optimizer and loss.
model = DateNormalizerAttn(input_vocab_size, target_vocab_size, emb_dim, hidden, eos_id).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
# Padded target positions (id of 'pad', i.e. collate_fn's padding value) must not
# contribute to the loss.
loss_func = nn.CrossEntropyLoss(ignore_index=dataset.output_sequnces_vocab['pad'])

In [None]:

# Train the attention model. After every epoch, greedily decode one probe
# sentence, then save a checkpoint under a name distinct from the plain model's
# so its checkpoints are not overwritten.
bos_token, eos_token = '~', '#'
input_vocab = dataset.input_sequnces_vocab
output_vocab = dataset.output_sequnces_vocab
decode = list(output_vocab.keys())  # id -> token

test = '17 ноября 1117 г.'
# Wrap the probe in bos/eos exactly like DatasetSeq2Seq wraps training inputs.
test_ids = [input_vocab[bos_token]] + [input_vocab[c] for c in test] + [input_vocab[eos_token]]
test_encoded = torch.tensor([test_ids], device=device)
bos_input = torch.tensor([[output_vocab[bos_token]]], device=device)

# With shuffle=True the DataLoader draws a fresh permutation on every pass,
# so a single instance can be reused across epochs.
dataloader = DataLoader(dataset,
                        batch_size,
                        shuffle=True,
                        collate_fn=collate_fn,
                        drop_last=True,
                        )
for epoch in range(n_epochs):
    for step, batch in enumerate(dataloader):
        optim.zero_grad()
        target = batch['target'].to(device)
        # Teacher forcing: feed target[:-1] (starts with bos), predict target[1:].
        predicts = model(batch['data'].to(device), target[:, :-1])
        loss = loss_func(predicts.reshape(-1, target_vocab_size),
                         target[:, 1:].reshape(-1))

        loss.backward()
        optim.step()
        if step % 100 == 0:
            print(f'epoch: {epoch}, step: {step}, loss: {loss.item()}')

    with torch.no_grad():
        model.eval()
        test_pred = model(test_encoded, bos_input)
        model.train()
    # Index row 0 instead of squeeze(): a single predicted token would otherwise
    # collapse to a 0-d tensor whose tolist() is not iterable.
    print(''.join(decode[token_id] for token_id in test_pred[0].tolist()))

    torch.save(model.state_dict(), f'./seq2seq_attn_chkpt_{epoch}.pth')

epoch: 0, step: 0, loss: 0.0718703642487526
17797-11-17#
epoch: 1, step: 0, loss: 0.06557723134756088
1777-11-17#
epoch: 2, step: 0, loss: 0.04026917368173599
1777-11-17#
epoch: 3, step: 0, loss: 0.05634636804461479
1777-11-17#
epoch: 4, step: 0, loss: 0.0672856941819191
1777-11-17#
epoch: 5, step: 0, loss: 0.047594889998435974
1977-11-11#
epoch: 6, step: 0, loss: 0.056803375482559204
1977-11-11#
epoch: 7, step: 0, loss: 0.03923920914530754
1977-11-11#
epoch: 8, step: 0, loss: 0.04330570995807648
1977-11-11#
epoch: 9, step: 0, loss: 0.055095501244068146
1977-11-11#
epoch: 10, step: 0, loss: 0.04330611601471901
1977-11-11#
epoch: 11, step: 0, loss: 0.039025381207466125
1977-11-11#
epoch: 12, step: 0, loss: 0.04305577278137207
1977-11-11#
epoch: 13, step: 0, loss: 0.05774302780628204
1977-11-11#
epoch: 14, step: 0, loss: 0.03917933255434036
1977-11-11#
epoch: 15, step: 0, loss: 0.04619215056300163
1977-11-11#
epoch: 16, step: 0, loss: 0.03463716804981232
1977-11-11#
epoch: 17, step: 0, l