In [None]:
# Mount Google Drive (Colab only) so the project code, data and checkpoints under
# /content/drive/MyDrive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys
# Make the project package (model.transformer) importable from its Drive location.
# NOTE(review): hardcoded Colab/Drive path — adjust when running elsewhere.
sys.path.insert(0,'/content/drive/MyDrive/NLP/Transformer')

In [None]:
from model.transformer import Transformer

In [None]:
import torch
import io 
from torch import nn
from math import sqrt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.sampler import Sampler
from torch.autograd import Variable
import torch.nn.functional as F
from  torch.optim.lr_scheduler import LambdaLR, StepLR
import torch.optim as optim
import math
import os 
from time import time 
from sklearn.model_selection import train_test_split
import random
from time import time 
import copy

# Seed every RNG this notebook draws from, not only torch:
#   - torch: weight init / dropout,
#   - Python `random`: sequence lengths in get_digit_data and batch order in RandomSortingSampler,
#   - NumPy: digit values in get_digit_data.
torch.random.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Train on the GPU when one is available.
cfg_train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# ---- Locations on Google Drive ----
data_path = '/content/drive/MyDrive/NLP/Transformer'
anki_data_path = '/content/drive/MyDrive/NLP/Transformer/data/anki_rus.txt'

# ---- Experiment switch ----
# 'digits' = synthetic copy task, 'anki' = en->ru translation.
# Any value other than 'digits' gets the anki settings below.
cfg_exp_name = 'anki'

# ---- Per-experiment hyper-parameters ----
if cfg_exp_name == 'digits':
    cfg_train_init_lr = 1.          # multiplier applied to the lr_lbmd schedule
    cfg_train_nrof_epochs = 3
    cfg_train_train_size = 50000    # number of generated training sequences
    cfg_train_val_size = 1000       # number of generated validation sequences
    cfg_train_warmup_steps = 500
else:
    cfg_train_init_lr = 2.          # for CE was 0.05
    cfg_train_nrof_epochs = 15
    cfg_train_train_size = None     # only used by get_digit_data
    cfg_train_val_size = None
    cfg_train_warmup_steps = 4000

# ---- Shared settings ----
cfg_train_ckpt_path = os.path.join(data_path, cfg_exp_name, 'checkpoints', 'checkpoint')
cfg_train_logs_dir = os.path.join(data_path, cfg_exp_name, 'logs')
cfg_train_log_interval = 100        # batches between train-loss prints (same for both experiments)
cfg_train_batch_size = 128
cfg_train_load = True               # Train.train() calls load_model() first when True
cfg_train_dropout_prob = 0.2        # NOTE(review): not referenced elsewhere in the visible notebook
seed_val = 0                        # NOTE(review): not referenced elsewhere; the seeding cell hardcodes 0

## Utils:

In [None]:
def get_mask(batched_sequence, decoding=False, pad=0):
    '''
    Build a boolean attention mask for a batch of padded token-id sequences.

    Entry [b, i, j] is True when query position i and key position j must be
    blocked: either position is padding or, when `decoding` is set, j lies in
    the future of i (j > i). NOTE(review): assumes the model masks True
    entries — confirm in model.transformer.

    Args:
        batched_sequence: integer tensor of shape (b_s, max_seq_len); entries
            equal to `pad` are padding.
        decoding: also apply the causal (no look-ahead) constraint.
        pad: padding id (default 0, the fill value used by pad_sequence).

    Returns:
        bool tensor of shape (b_s, max_seq_len, max_seq_len) on the same device
        as `batched_sequence`.
    '''
    b_s, max_seq_len = batched_sequence.shape
    not_pad = batched_sequence != pad                                  # (b_s, L)
    # Pair (i, j) is blocked if the query i or the key j is padding.
    mask = ~(not_pad.unsqueeze(1) & not_pad.unsqueeze(2))             # (b_s, L, L)
    if decoding:
        # Causal part, created on the input's device so CPU and CUDA tensors are never mixed.
        causal = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool,
                            device=batched_sequence.device).triu(diagonal=1)
        mask = mask | causal
    # A key column blocked for every query belongs to a padding position, and the
    # query row at that same position is then fully blocked as well; a fully
    # blocked row would turn softmax into NaN. Unblock those rows — they only
    # produce outputs for padding positions.
    fully_blocked = mask.all(dim=1)                                   # (b_s, L), indexed by key position
    mask[fully_blocked] = False
    return mask

## Data:

### Digits:

In [None]:
import random

def get_digit_data(train_size=None, val_size=None, min_len=5, max_len=25):
    '''
    Generate random digit sequences for the copy-task sanity check.

    Each sequence is a 1-D tensor: BOS (1), then `seq_len` digits drawn
    uniformly from 3..9, then EOS (2); seq_len is drawn uniformly from
    [min_len, max_len] (both inclusive).

    Args:
        train_size: number of training sequences (default: cfg_train_train_size).
        val_size: number of validation sequences (default: cfg_train_val_size).
        min_len, max_len: inclusive bounds on the number of digits per sequence.

    Returns:
        (train_sequences, val_sequences): two lists of tensors.
    '''
    if train_size is None:
        train_size = cfg_train_train_size
    if val_size is None:
        val_size = cfg_train_val_size

    def make_sequences(count):
        # One random.randint and one np.random.randint call per sequence, train
        # first then val, so seeded runs reproduce the same data as before.
        sequences = []
        for _ in range(count):
            seq_len = random.randint(min_len, max_len)
            digits = np.random.randint(3, 10, seq_len).tolist()
            sequences.append(torch.tensor([1] + digits + [2]))
        return sequences

    train_digit_sequences = make_sequences(train_size)
    val_digit_sequences = make_sequences(val_size)
    return train_digit_sequences, val_digit_sequences

class Decoding:
    '''Identity "tokenizer" for the digits task: token ids are already the readable form.'''

    @staticmethod
    def decode(sequence):
        '''Return `sequence` unchanged (the very same object, no copy).'''
        decoded = sequence
        return decoded

class DigitDataset(Dataset):
    '''
    Copy-task dataset: item k is (data[k], data[k]) — the target equals the source.

    Args:
        data: list of 1-D token tensors (e.g. from get_digit_data()).
        vocab_size: vocabulary size reported by `vocab_size` for both sides
            (digits 3..9 plus pad/BOS/EOS ids 0..2 fit in 10).
    '''
    def __init__(self, data, vocab_size=10):
        super(DigitDataset, self).__init__()
        self.data = data
        self._vocab_size = vocab_size
        # Digits are their own "text", so both sides use the identity decoder.
        self.tokenizers = {'input': Decoding(),
                           'output': Decoding()}

    def __len__(self):
        return len(self.data)

    @property
    def vocab_size(self):
        '''Vocabulary sizes keyed like TranslationDataset.vocab_size.'''
        return {'input': self._vocab_size,
                'output': self._vocab_size}

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sequence = self.data[idx]
        # Same tensor as source and target: the model learns to copy.
        return sequence, sequence

### Anki data:

#### Dataset utils:

In [None]:
!pip install youtokentome

In [None]:
import youtokentome as yttm
from itertools import chain
from random import shuffle

In [None]:
def split_anki_dataset(data_path=None, out_dir='.'):
    '''
    Split the tab-separated Anki corpus into disjoint train / val / test files.

    Each input line is expected to be "<english>\\t<russian>[\\t...]". 80% of the
    lines go to train; the held-out 20% is halved into val and test (80/10/10).

    Writes six files into `out_dir`: en_{train,val,test}_set and
    ru_{train,val,test}_set, one sentence per line, with line k of an en file
    aligned to line k of the matching ru file. Lines with fewer than two
    tab-separated fields are skipped in both files.

    Args:
        data_path: corpus path (default: the notebook-level `anki_data_path`).
        out_dir: output directory (default: the current working directory).
    '''
    from sklearn.model_selection import train_test_split

    if data_path is None:
        data_path = anki_data_path

    with open(data_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    train, held_out = train_test_split(lines, test_size=0.2, random_state=11)
    # Split only the held-out part, so val/test never overlap with train.
    val, test = train_test_split(held_out, test_size=0.5, random_state=11)

    for phase, phase_lines in zip(('train', 'val', 'test'), (train, val, test)):
        ru_path = os.path.join(out_dir, 'ru_' + phase + '_set')
        en_path = os.path.join(out_dir, 'en_' + phase + '_set')
        with open(ru_path, 'w', encoding='utf-8') as ru_file_writer, \
                open(en_path, 'w', encoding='utf-8') as en_file_writer:
            for pair in phase_lines:
                # Strip the terminator first: in a 2-column file the Russian field
                # would otherwise keep its '\n' and emit an extra blank line,
                # desynchronising the en/ru files.
                pair_split = pair.rstrip('\r\n').split('\t')
                if len(pair_split) < 2:
                    continue  # blank/malformed line: skipped in BOTH files to keep alignment
                ru_file_writer.write(pair_split[1] + '\n')
                en_file_writer.write(pair_split[0] + '\n')

def chunks(l, n):
    '''Split sequence `l` into consecutive slices of length `n`; a shorter trailing remainder is dropped.'''
    nrof_full_chunks = len(l) // n
    return [l[k * n:(k + 1) * n] for k in range(nrof_full_chunks)]

def get_yttm_bpe(data_path, data_type):
    '''
    Return a YouTokenToMe BPE tokenizer for `data_type` ('input' or 'output').

    The model is cached in the working directory as "<data_type>_yttm_tokenizer".
    If that file exists it is loaded and `data_path` is ignored; otherwise a
    50000-token BPE is trained on `data_path` with pad=0, unk=1, bos=2, eos=3.
    Because the cache key is only `data_type`, every dataset built after the
    first one (e.g. val/test) reuses the tokenizer trained on the first file.
    '''
    tok_path = data_type + '_yttm_tokenizer'
    if os.path.exists(tok_path):
        tokenizer = yttm.BPE(model=tok_path)
        print('BPE loaded')
        return tokenizer

    print('BPE training started')
    started_at = time()
    tokenizer = yttm.BPE.train(data=data_path, vocab_size=50000, model=tok_path,
                               pad_id=0, unk_id=1, bos_id=2, eos_id=3)
    print('BPE trained after {} sec'.format(time() - started_at))
    return tokenizer


def my_collate(batch):
    '''
    Collate (source, target) token-tensor pairs into padded model inputs.

    - encoder input: source with its first and last tokens (BOS/EOS) removed;
    - decoder input: target without its last token ([:-1]);
    - decoder labels: target without its first token ([1:]).
    Each group is right-padded with 0 via pad_sequence.

    Returns:
        (padded_input, padded_output_input, padded_output_output, input_mask, output_mask),
        the masks coming from get_mask (causal for the decoder side).
    '''
    def pad(seqs):
        return nn.utils.rnn.pad_sequence(seqs, batch_first=True)

    sources, decoder_inputs, decoder_targets = [], [], []
    for pair in batch:
        sources.append(pair[0][1:-1])
        decoder_inputs.append(pair[1][:-1])
        decoder_targets.append(pair[1][1:])

    padded_input = pad(sources)
    padded_output_input = pad(decoder_inputs)
    padded_output_output = pad(decoder_targets)
    return (padded_input, padded_output_input, padded_output_output,
            get_mask(padded_input), get_mask(padded_output_input, decoding=True))

In [None]:
class TranslationDataset(Dataset):
    '''
    Parallel en/ru corpus, BPE-encoded and sorted by length.

    Line k of `en_data_path` is the translation of line k of `ru_data_path`.
    Each side is encoded with a YouTokenToMe BPE (see get_yttm_bpe) with BOS/EOS
    added, and the pairs are sorted by the longer side's length so neighbouring
    indices have similar lengths (what RandomSortingSampler batches on).

    Args:
        en_data_path, ru_data_path: aligned, newline-separated sentence files.
        output: target language, 'ru' (default) or 'en'. The other language is
            the source; items are returned as (source_ids, target_ids).

    Raises:
        ValueError: unknown `output`, or files with different numbers of lines.
    '''
    def __init__(self, en_data_path, ru_data_path, output='ru'):
        super(TranslationDataset, self).__init__()
        if output not in ('en', 'ru'):
            raise ValueError("output must be 'en' or 'ru', got {!r}".format(output))
        # Keys match what read_and_preprocess fills in and what vocab_size reads.
        self.tokenizers = {'input': None, 'output': None}
        self.output = output

        en_type = 'output' if output == 'en' else 'input'
        ru_type = 'output' if output == 'ru' else 'input'
        en_data = self.read_and_preprocess(en_data_path, en_type)
        ru_data = self.read_and_preprocess(ru_data_path, ru_type)

        self.en_dataset, self.ru_dataset = self.sort(en_data, ru_data)

    def read_and_preprocess(self, data_path, data_type):
        '''Read one sentence per line, fit/load the `data_type` BPE and encode every line with BOS/EOS.'''
        with open(data_path, 'r', encoding='utf-8') as f:
            # Every line is kept, including a final one without a trailing newline.
            lines = [line.rstrip('\n') for line in f]

        tokenizer = get_yttm_bpe(data_path, data_type)
        self.tokenizers[data_type] = tokenizer
        # Batch encode: one call for the whole list returns one id list per sentence.
        return tokenizer.encode(lines, output_type=yttm.OutputType.ID, bos=True, eos=True)

    def sort(self, en_data, ru_data):
        '''Sort aligned pairs by max(len(en), len(ru)); return [en_tuple, ru_tuple].'''
        if len(en_data) != len(ru_data):
            # zip() would silently truncate and misalign the pairs.
            raise ValueError('Parallel corpora are misaligned: {} en vs {} ru sentences'.format(
                len(en_data), len(ru_data)))
        en_ru_sorted = sorted(zip(en_data, ru_data),
                              key=lambda pair: max(len(pair[0]), len(pair[1])))
        if not en_ru_sorted:
            return [(), ()]
        return list(zip(*en_ru_sorted))

    def __len__(self):
        return len(self.en_dataset)

    @property
    def vocab_size(self):
        '''Vocabulary sizes of the source ('input') and target ('output') tokenizers.'''
        input_vocab_size = len(self.tokenizers['input'].vocab())
        output_vocab_size = len(self.tokenizers['output'].vocab())
        return {'input': input_vocab_size,
                'output': output_vocab_size}

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        en_ids = torch.tensor(self.en_dataset[idx])
        ru_ids = torch.tensor(self.ru_dataset[idx])
        # (source, target) order, matching the 'input'/'output' tokenizers.
        return (ru_ids, en_ids) if self.output == 'en' else (en_ids, ru_ids)

class RandomSortingSampler(Sampler):
    '''
    Batch sampler over a length-sorted dataset.

    Indices are cut into consecutive blocks of `batch_size`, so each batch holds
    neighbouring (similar-length) examples. With shuffle=True the order of the
    *complete* blocks is reshuffled on every pass and the trailing partial block
    is dropped; with shuffle=False blocks come in order and the last partial
    block is kept.

    Args:
        sorted_data: sized dataset whose indices are sorted by length.
        batch_size: indices per yielded batch.
        shuffle: shuffle block order each pass (also implies drop_last).
    '''
    def __init__(self, sorted_data, batch_size=32, shuffle=False):
        super(RandomSortingSampler, self).__init__(sorted_data)
        self.dataset_len = len(sorted_data)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = shuffle
        self._reset_sampler()

    def _reset_sampler(self):
        '''Build a fresh index order for one pass and store its iterator in self.sampler.'''
        ids = list(range(self.dataset_len))
        if self.shuffle:
            # Complete blocks only. The last start index is dataset_len - batch_size, so the
            # final full block is kept even when dataset_len is a multiple of batch_size.
            blocks = [ids[i:i + self.batch_size]
                      for i in range(0, self.dataset_len - self.batch_size + 1, self.batch_size)]
            random.shuffle(blocks)
            ids = list(chain.from_iterable(blocks))
        self.sampler = iter(ids)

    def __iter__(self):
        # Each pass starts from a fresh order, so a pass abandoned midway
        # (e.g. `break` in a loop) cannot leak leftover indices into the next one.
        self._reset_sampler()
        order = self.sampler
        batch = []
        for idx in order:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return self.dataset_len // self.batch_size  # type: ignore
        return (self.dataset_len + self.batch_size - 1) // self.batch_size

## Train:

### Training:

In [None]:
def lr_lbmd(cur_step, emb_size=512, warmup_steps=cfg_train_warmup_steps):
    '''
    Learning-rate factor of the "Attention Is All You Need" schedule:

        emb_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5),  step = cur_step + 1

    i.e. linear warm-up over `warmup_steps` steps, then inverse-square-root decay.
    Used as the LambdaLR lambda, so the effective LR is cfg_train_init_lr times this value.
    '''
    step = cur_step + 1  # scheduler steps start at 0; avoid 0 ** -0.5
    decay_term = step ** (-0.5)
    warmup_term = step * warmup_steps ** (-1.5)
    return emb_size ** (-0.5) * min(decay_term, warmup_term)

def SCELoss(predicted, target, eps=0.01, pad=0):
    '''
    Cross-entropy with label smoothing, averaged over non-padding target tokens.

    Args:
        predicted: scores in nn.CrossEntropyLoss layout (batch, n_classes, seq_len).
            They are permuted to (batch, seq_len, n_classes) and passed through
            log_softmax over classes (a no-op if they already are log-probabilities).
        target: LongTensor (batch, seq_len) of class ids.
        eps: smoothing mass; target distribution = (1 - eps) * one_hot + eps / n_classes.
        pad: target id whose positions are excluded from numerator and denominator.

    Returns:
        Scalar tensor: smoothed CE summed over non-pad positions divided by their
        count (0 for an all-padding batch instead of NaN).
    '''
    predicted = predicted.permute(0, 2, 1)
    K = predicted.shape[-1]
    pad_mask = target == pad
    with torch.no_grad():
        ohe_target = F.one_hot(target, K).float()
        ohe_target *= (1 - eps)
        ohe_target += eps / K
        ohe_target[pad_mask] = 0
        # clamp: an all-padding batch would otherwise compute 0 / 0.
        nrof_nonzero = (~pad_mask).sum().clamp(min=1)
    return -(F.log_softmax(predicted, -1) * ohe_target).sum() / nrof_nonzero


In [None]:
class Train:
    '''
    Training / evaluation harness for the Transformer.

    Builds the model, an Adam optimiser and a LambdaLR(lr_lbmd) schedule; trains
    with periodic stdout (and optional TensorBoard) logging; validates and
    checkpoints after every epoch; and offers greedy and beam-search decoding.

    Args:
        traindataset: dataset exposing `vocab_size` ({'input': int, 'output': int})
            and `tokenizers` (objects with `.decode`); required despite the None default.
        trainloader, valloader: loaders yielding my_collate 5-tuples
            (source, target_input, target_output, source_mask, target_mask).
        valdataset: kept for reference only.
        to_log: also write TensorBoard scalars under cfg_train_logs_dir.
    '''
    def __init__(self, traindataset=None, trainloader=None,
                 valdataset=None, valloader=None, to_log=False):
        self.trainloader = trainloader
        self.train_dataset = traindataset
        self.valloader = valloader
        self.valdataset = valdataset

        self.model = Transformer(traindataset.vocab_size['input'],
                                 traindataset.vocab_size['output'])
        self.model.to(cfg_train_device)

        # Label-smoothed cross-entropy over non-pad target tokens.
        self.crit = SCELoss
        self.optim = optim.Adam(self.model.parameters(), betas=(0.9, 0.98),
                                eps=1e-9, lr=cfg_train_init_lr)
        # Effective LR = cfg_train_init_lr * lr_lbmd(step).
        self.scheduler = LambdaLR(self.optim, lr_lambda=lr_lbmd)
        self.nrof_epochs = cfg_train_nrof_epochs
        # Optimiser steps per epoch; used to recover the epoch index from a checkpoint.
        if trainloader is not None:
            self.epoch_size = max(1, len(trainloader))
        else:
            self.epoch_size = max(1, math.ceil(len(self.train_dataset) / cfg_train_batch_size))
        self.cur_epoch, self.global_step = 0, 0
        # Always defined: save_model() stores it even when TensorBoard logging is off.
        self.best_loss = float('inf')

        self.to_log = to_log
        if self.to_log:
            self.train_writer = SummaryWriter(os.path.join(cfg_train_logs_dir, 'train'))
            self.val_writer = SummaryWriter(os.path.join(cfg_train_logs_dir, 'val'))

    def decode_sequences(self, sequences, seq_type='output'):
        '''
        Decode a batch of token-id sequences with the training dataset's `seq_type` tokenizer.

        sequences: list of lists of token ids (callers pass `.tolist()` output).
        '''
        tokenizer = self.train_dataset.tokenizers[seq_type]
        sentences = tokenizer.decode(sequences)
        return sentences

    @staticmethod
    def _to_device(batch):
        '''Move every tensor of a my_collate 5-tuple to cfg_train_device.'''
        return tuple(tensor.to(cfg_train_device) for tensor in batch)

    @staticmethod
    def _split_finished(hypotheses, eos):
        '''Partition (tokens, score) hypotheses into (finished, unfinished) by whether the last token is `eos`.'''
        finished = [hyp for hyp in hypotheses if hyp[0][-1] == eos]
        unfinished = [hyp for hyp in hypotheses if hyp[0][-1] != eos]
        return finished, unfinished

    def _print_samples(self, source, outputs, target_output, n=4):
        '''Print decoded sources, targets and argmax predictions for the first `n` batch items.'''
        decoded_inputs = self.decode_sequences(source[:n, :].detach().cpu().numpy().tolist(), seq_type='input')
        decoded_outputs = self.decode_sequences(torch.argmax(outputs[:n, :, :], dim=-1).detach().cpu().numpy().tolist())
        decoded_targets = self.decode_sequences(target_output[:n, :].detach().cpu().numpy().tolist())
        print('decoded_inputs', decoded_inputs)
        print('decoded_targets', decoded_targets)
        print('decoded_outputs', decoded_outputs)

    def save_model(self):
        '''Write step, model/optimizer/scheduler state and best_loss to cfg_train_ckpt_path.'''
        os.makedirs(os.path.dirname(cfg_train_ckpt_path), exist_ok=True)
        torch.save({"step": self.global_step,
                    "model": self.model.state_dict(),
                    "optimizer": self.optim.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    "loss": self.best_loss},
                   cfg_train_ckpt_path)
        print("Model saved...")

    def load_model(self):
        '''Restore the state written by save_model() and position training right after the saved step.'''
        # map_location lets a checkpoint written on GPU be restored on a CPU-only runtime.
        ckpt = torch.load(cfg_train_ckpt_path, map_location=cfg_train_device)
        # "step" counts completed optimiser steps, so it is also the next step's index.
        self.cur_epoch = ckpt["step"] // self.epoch_size
        self.global_step = ckpt["step"]
        self.model.load_state_dict(ckpt["model"])
        self.optim.load_state_dict(ckpt["optimizer"])
        self.scheduler.load_state_dict(ckpt["scheduler"])
        self.best_loss = ckpt["loss"]

    def train_epoch(self):
        '''Run one pass over trainloader, stepping the optimiser and LR schedule after every batch.'''
        self.model.train()
        nrof_samples, cur_loss = 0, 0.0

        for batch_idx, batch in enumerate(self.trainloader):
            source, target_input, target_output, input_mask, target_mask = self._to_device(batch)
            self.optim.zero_grad()
            outputs = self.model(source, target_input, input_mask, target_mask)

            # crit expects the CrossEntropyLoss layout (batch, classes, seq).
            loss = self.crit(outputs.permute(0, 2, 1), target_output)
            loss.backward()

            # loss is a per-token mean for this batch; weighting it by batch size makes
            # cur_loss / nrof_samples an average loss rather than loss / batch_size.
            cur_loss += loss.item() * len(source)
            nrof_samples += len(source)

            self.optim.step()
            self.scheduler.step()
            self.global_step += 1

            if batch_idx % cfg_train_log_interval == 0 and batch_idx != 0:
                print('Batch num:', batch_idx)
                print("Train loss: {:.4f}".format(cur_loss / nrof_samples))
                self._print_samples(source, outputs, target_output)

                if self.to_log:
                    self.train_writer.add_scalar('Loss', cur_loss / nrof_samples, self.global_step)
                    self.train_writer.add_scalar('LR', self.optim.param_groups[0]["lr"], self.global_step)

                nrof_samples, cur_loss = 0, 0.0

    def validate(self):
        '''Return the batch-size-weighted mean validation loss (NaN if valloader is empty); prints samples from the first batch.'''
        self.model.eval()
        nrof_samples, cur_loss = 0, 0.0

        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valloader):
                source, target_input, target_output, input_mask, target_mask = self._to_device(batch)
                outputs = self.model(source, target_input, input_mask, target_mask)

                loss = self.crit(outputs.permute(0, 2, 1), target_output)
                cur_loss += loss.item() * len(source)
                nrof_samples += len(source)

                if batch_idx == 0:
                    self._print_samples(source, outputs, target_output)

        return cur_loss / nrof_samples if nrof_samples else float('nan')

    def train(self):
        '''
        Train for nrof_epochs epochs, validating and checkpointing after each.

        When cfg_train_load is set, training resumes from cfg_train_ckpt_path if it
        exists; otherwise it starts from scratch.
        '''
        if cfg_train_load:
            if os.path.exists(cfg_train_ckpt_path):
                self.load_model()
            else:
                print('No checkpoint at {}, training from scratch'.format(cfg_train_ckpt_path))

        for epoch in range(self.cur_epoch, self.cur_epoch + self.nrof_epochs):
            self.train_epoch()
            val_loss = self.validate()
            print("Epoch {} trained\nVal loss: {:.4f}".format(epoch, val_loss))
            if self.to_log:
                self.val_writer.add_scalar('Loss', val_loss, self.global_step)
            self.best_loss = min(self.best_loss, val_loss)
            self.save_model()

    def _get_eval_out(self, encoded_input, decoded_sequence):
        '''Run the decoder on the prefix `decoded_sequence` (list of ids) and return its output for every position.'''
        prefix = torch.tensor(decoded_sequence).unsqueeze(0)
        # Build the causal mask from the CPU prefix, then move both tensors to the
        # model's device, so get_mask never mixes CPU and CUDA tensors.
        target_mask = get_mask(prefix, decoding=True).to(cfg_train_device)
        prefix = prefix.to(cfg_train_device)

        with torch.no_grad():
            output = self.model.decode(encoded_input, prefix, target_mask)
        return output

    def evaluate_greedy(self, sequence, stop_predict_count=30, bos=1, eos=2):
        '''
        Greedy decoding of one source `sequence` (shape (1, src_len) in this
        notebook's callers, already on cfg_train_device).

        Reloads the checkpoint first. Stops once `eos` is produced or the output
        reaches `stop_predict_count` tokens (including `bos`).
        The output should match evaluate_beam() with beam=1.

        Returns:
            list of token ids starting with `bos`.
        '''
        self.load_model()
        self.model.eval()
        current_symbol, decoded_sequence = bos, [bos]
        input_mask = get_mask(sequence)
        with torch.no_grad():
            encoded_input = self.model.encode(sequence, input_mask)

        while current_symbol != eos and len(decoded_sequence) < stop_predict_count:
            output = self._get_eval_out(encoded_input, decoded_sequence)
            # The next-token prediction sits at the last prefix position.
            current_symbol = int(torch.argmax(output[0, len(decoded_sequence) - 1], dim=-1))
            decoded_sequence.append(current_symbol)

        return decoded_sequence

    def evaluate_beam(self, sequence, stop_predict_count=30, beam=3, bos=1, eos=2):
        '''
        Beam-search decoding of one source `sequence` (same layout as evaluate_greedy).

        Reloads the checkpoint first. Hypothesis scores are running sums of the
        decoder outputs (NOTE(review): assumes model.decode returns log-probabilities
        — confirm in model.transformer) and are ranked by score / length.
        Hypotheses ending in `eos` are set aside as finished and shrink the active
        beam. Search stops when the beam is exhausted or every active hypothesis
        reaches `stop_predict_count` tokens.

        Returns:
            the best token-id list (starting with `bos`) among finished and active hypotheses.
        '''
        self.load_model()
        self.model.eval()
        input_mask = get_mask(sequence)
        with torch.no_grad():
            encoded_input = self.model.encode(sequence, input_mask)
        current_best_dec_probs = [([bos], 0)]
        current_leaves = []
        current_beam = beam

        while current_beam > 0 and any(len(tokens) < stop_predict_count and eos not in tokens
                                       for tokens, _ in current_best_dec_probs):
            all_log_probs = []

            for decoding, log_prob in current_best_dec_probs:
                if len(decoding) >= stop_predict_count or eos in decoding:
                    continue
                output = self._get_eval_out(encoded_input, decoding)
                step_scores = output[0, len(decoding) - 1]
                # All active hypotheses have the same length, so the global top-k by
                # score / length is contained in each hypothesis's own top-k tokens.
                k = min(current_beam, step_scores.shape[-1])
                top_scores, top_ids = torch.topk(step_scores, k)
                all_log_probs.extend((decoding + [token], score + log_prob)
                                     for score, token in zip(top_scores.tolist(), top_ids.tolist()))

            all_log_probs = sorted(all_log_probs, key=lambda x: x[1] / len(x[0]), reverse=True)[:current_beam]
            # Partition instead of popping while enumerating (which skipped finished hypotheses).
            finished, current_best_dec_probs = self._split_finished(all_log_probs, eos)
            current_leaves.extend(finished)
            current_beam = beam - len(current_leaves)

        candidates = current_best_dec_probs + current_leaves
        return max(candidates, key=lambda y: y[1] / len(y[0]))[0]

### Train and evaluate:

#### Digits:

In [None]:
# Data for the copy-task sanity check. Requires cfg_exp_name = 'digits', otherwise
# cfg_train_train_size / cfg_train_val_size are None and get_digit_data() fails.
train_digit_sequences, val_digit_sequences = get_digit_data()
train_dataset = DigitDataset(train_digit_sequences, 10)
val_dataset = DigitDataset(val_digit_sequences, 10)
# Reshuffle training batches every epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=cfg_train_batch_size, shuffle=True, collate_fn=my_collate)
val_loader = DataLoader(val_dataset, batch_size=cfg_train_batch_size, collate_fn=my_collate)
TR = Train(train_dataset, train_loader, val_dataset, val_loader, to_log=True)

In [None]:
# Check train: when cfg_train_load is True, Train.train() first loads cfg_train_ckpt_path,
# then trains cfg_train_nrof_epochs epochs, validating and checkpointing after each.
TR.train()

In [None]:
# Check inference on a random digit sequence: a well-trained copy model should
# reproduce it. BOS=1 / EOS=2 are added only for display; the encoder input has
# no BOS/EOS, matching what my_collate feeds it during training.
sequence = torch.randint(3,10,(1, 20)).to(cfg_train_device)
print('Input sentence: ', [1] + sequence.cpu().numpy().tolist()[0] + [2])
greedy_out = TR.evaluate_greedy(sequence)
beam_out = TR.evaluate_beam(sequence, beam=4)
print('Predicted greedy:', greedy_out)
print('Predicted beam:', beam_out)

#### Anki dataset:

In [None]:
# Data: split the raw corpus into en/ru train/val/test files once. split_anki_dataset()
# writes them to the working directory (assumed to be /content, as on Colab). Existing
# files are reused, so delete them to force a re-split.
if not all(os.path.exists('/content/{}_{}_set'.format(lang, phase))
           for lang in ('en', 'ru') for phase in ('train', 'val', 'test')):
    split_anki_dataset()

# The first dataset trains (or loads) the cached BPE tokenizers; val/test reuse them.
train_dataset = TranslationDataset('/content/en_train_set', '/content/ru_train_set')
val_dataset = TranslationDataset('/content/en_val_set', '/content/ru_val_set')
test_dataset = TranslationDataset('/content/en_test_set', '/content/ru_test_set')
# Length-sorted data + shuffled blocks of neighbouring indices -> compactly padded batches.
train_loader = DataLoader(train_dataset,
                          batch_sampler=RandomSortingSampler(
                              train_dataset, batch_size=cfg_train_batch_size,
                              shuffle=True),
                          collate_fn=my_collate)
val_loader = DataLoader(val_dataset, batch_size=cfg_train_batch_size,
                        collate_fn=my_collate)
test_loader = DataLoader(test_dataset, batch_size=1,
                         collate_fn=my_collate)

In [None]:
# Check train: build a fresh Transformer sized to the BPE vocabularies and train it
# (Train.train() first loads cfg_train_ckpt_path when cfg_train_load is True).
TR = Train(train_dataset, train_loader, val_dataset, val_loader, to_log=True)
TR.train()

In [None]:
# Check inference: translate one English sentence with greedy and beam search.
# yttm ids are pad=0, unk=1, bos=2, eos=3 (see get_yttm_bpe), hence bos=2 / eos=3 below.
# Uncomment the next line to rebuild TR (e.g. in a fresh kernel) before decoding:
#TR = Train(train_dataset, train_loader, val_dataset, val_loader, to_log=True)
sentence = 'It doesn\'t work.'
tokenised_sentece = TR.train_dataset.tokenizers['input'].encode(sentence, output_type=yttm.OutputType.ID, bos=True,
                                                                eos=True)
print('Input tokens:', tokenised_sentece)
# [1:-1] drops BOS/EOS: the encoder is trained on sources without them (see my_collate).
greedy_out = TR.evaluate_greedy(torch.tensor(tokenised_sentece[1:-1]).unsqueeze(0).to(cfg_train_device),
                                stop_predict_count=15, bos=2, eos=3)
beam_out = TR.evaluate_beam(torch.tensor(tokenised_sentece[1:-1]).unsqueeze(0).to(cfg_train_device),
                            stop_predict_count=15, beam=5, bos=2, eos=3)
decoded_greedy_out = TR.train_dataset.tokenizers['output'].decode(greedy_out)
decoded_beam_out = TR.train_dataset.tokenizers['output'].decode(beam_out)
print('Predicted greedy: {}\nPredicted beam: {}'.format(decoded_greedy_out[0].replace('<PAD>',''),
                                                        decoded_beam_out[0].replace('<PAD>','')))

##### Count BLEU score:

In [None]:
from nltk.translate.bleu_score import corpus_bleu

In [None]:
def count_bleu(test_dataset, stop_predict_count=10, beam=3, bos=2, eos=3, log_every=50):
    '''
    Beam-decode every pair of `test_dataset` and report corpus BLEU against the references.

    Sentences are scored as whitespace-separated word tokens. NLTK's corpus_bleu
    expects a list of reference token lists per hypothesis and a token list per
    hypothesis; given raw strings it would silently score characters.
    Every `log_every` pairs it prints a sample and a running BLEU with
    (0.5, 0.5, 0, 0) weights; the final score uses NLTK's default 4-gram weights.

    Builds its Train harness from the notebook-level train/val datasets and loaders;
    evaluate_beam reloads the checkpoint for every sentence.

    Returns:
        the final corpus BLEU (float), or None for an empty dataset.
    '''
    TR = Train(train_dataset, train_loader, val_dataset, val_loader)

    ru_tokenizer = TR.train_dataset.tokenizers['output']
    en_tokenizer = TR.train_dataset.tokenizers['input']

    target_sentences, predicted_sentences = [], []
    references, hypotheses = [], []
    print('test len', len(test_dataset))

    t = time()
    for i, (en_ids, ru_ids) in enumerate(test_dataset):
        out = TR.evaluate_beam(en_ids[1:-1].unsqueeze(0).to(cfg_train_device),
                               stop_predict_count=stop_predict_count, beam=beam, bos=bos, eos=eos)
        # Drop BOS, and EOS only if it was actually produced (decoding may stop at the length limit).
        body = out[1:]
        if body and body[-1] == eos:
            body = body[:-1]
        decoded_out = ru_tokenizer.decode(body) if body else ['']

        target_sentences.append(ru_tokenizer.decode(ru_ids[1:-1].numpy().tolist())[0])
        predicted_sentences.append(decoded_out[0].replace('<PAD>', ''))
        references.append([target_sentences[-1].split()])
        hypotheses.append(predicted_sentences[-1].split())

        if i % log_every == 0:
            print('En input: {}\nRu target: {}\nRu predicted: {}\nTime spent:{}'.format(
                en_tokenizer.decode(en_ids[1:-1].numpy().tolist())[0],
                target_sentences[-1],
                predicted_sentences[-1],
                time() - t
            ))
            t = time()
            bleu_score = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
            print('Bleu score: {:.2f} on {} pairs'.format(bleu_score,
                                                          len(target_sentences)))

    if not hypotheses:
        print('Empty test set: no BLEU score')
        return None
    bleu_score = corpus_bleu(references, hypotheses)
    print('Bleu score: {:.2f}'.format(bleu_score))
    return bleu_score


# Score the held-out test split. Slow: one beam search (and one checkpoint reload) per sentence.
count_bleu(test_dataset)

### Logs:

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir '/content/drive/MyDrive/NLP/Transformer/anki/logs'