In [None]:
import sys

# Make the project's `model` package (stored on Google Drive) importable.
# NOTE(review): this is an absolute Colab path; it only resolves after Drive
# has been mounted at /content/drive.
PROJECT_DIR = '/content/drive/MyDrive/NLP/Transformer'

# Drop any entry left by a previous run of this cell before re-inserting, so
# re-running the cell keeps exactly one copy at the front of sys.path.
sys.path[:] = [entry for entry in sys.path if entry != PROJECT_DIR]
sys.path.insert(0, PROJECT_DIR)

In [None]:
from model.transformer import Transformer

### Setup:

In [None]:
# Colab-only dependency setup, kept disabled by wrapping the shell commands in
# a string literal (the cell therefore only echoes the text and installs
# nothing). To actually install, move the lines out of the string into a cell
# of their own.
'''
!pip install transformers
!pip install pytorch-lightning
!pip3 install Cython
'''

In [None]:
import copy

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, RandomSampler, DataLoader
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from transformers import AutoModelForSeq2SeqLM, BartConfig, AutoConfig
from sklearn.metrics import accuracy_score

In [None]:
# Hyper-parameters handed to `Transformer(model_config)`.
# NOTE(review): the exact meaning of each entry is defined by
# model.transformer.Transformer, which is not visible in this notebook.
model_cfg = dict(
    # Vocabulary sizes; 'output_vocab_size' is also used to flatten the logits
    # before the cross-entropy loss. Presumably this must equal the size of
    # the vocabulary built by WikiTextDataset.prepare -- TODO confirm.
    input_vocab_size=28782,
    output_vocab_size=28782,
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dropout=0.2,
    emb_size=512,
    hidden_size=64,
    nrof_heads=2,
    f_hidden_size=64,
    nrof_layers=2,
)

## Data

In [None]:
class WikiTextDataset(Dataset):
    '''
    WikiText-2 token stream cut into fixed-length rows for next-token
    prediction.

    `prepare()` builds the tokenizer and vocabulary and fills `self.dataset`
    with one 2-D LongTensor per split ('train', 'val', 'test'). Which split
    `__len__` / `__getitem__` read from is selected by `self.phase`
    (see `set_phase`).
    '''

    def __init__(self):
        super().__init__()
        self.phase = 'train'   # split currently served by __len__/__getitem__
        self.dataset = {}      # phase name -> (nrof_seqs, seq_len) LongTensor

    def preprocess(self, raw_text_iter):
        '''
        Tokenize and numericalize every item of `raw_text_iter` and
        concatenate the results into a single 1-D LongTensor.

        Items that produce no tokens are skipped. Requires `self.tokenizer`
        and `self.vocab` to be set (done by `prepare`).
        '''
        pieces = [torch.tensor(self.vocab(self.tokenizer(item)), dtype=torch.long)
                  for item in raw_text_iter]
        pieces = [piece for piece in pieces if piece.numel() > 0]
        if not pieces:
            # torch.cat refuses an empty list; an empty stream is still a
            # valid (if useless) result.
            return torch.empty(0, dtype=torch.long)
        return torch.cat(pieces)

    def batchify(self, data, seq_len=513, device='cpu'):
        '''
        Reshape the 1-D token stream `data` into rows of `seq_len` tokens.

        Trailing tokens that do not fill a whole row are dropped. Returns a
        tensor of shape (data.size(0) // seq_len, seq_len) moved to `device`.
        '''
        nrof_seqs = data.size(0) // seq_len
        data = data[:nrof_seqs * seq_len]
        data = data.view(nrof_seqs, seq_len)
        return data.to(device)

    def prepare(self):
        '''
        Build tokenizer and vocabulary from the WikiText2 train split and
        store the batchified train/val/test splits in `self.dataset`.

        The split tensors are placed on the GPU when one is available, so
        anything combined with them later (e.g. in a collate function) has
        to live on the same device.
        '''
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = get_tokenizer('basic_english')
        self.vocab = build_vocab_from_iterator(map(
            self.tokenizer, WikiText2(split='train')), specials=['<unk>'])
        # Out-of-vocabulary tokens map to '<unk>'.
        self.vocab.set_default_index(self.vocab['<unk>'])

        train_iter, val_iter, test_iter = WikiText2()
        self.dataset['train'] = self.batchify(self.preprocess(train_iter),
                                              device=device)
        self.dataset['val'] = self.batchify(self.preprocess(val_iter),
                                            device=device)
        self.dataset['test'] = self.batchify(self.preprocess(test_iter),
                                             device=device)

    def set_phase(self, phase='train'):
        '''Select which split (`self.dataset[phase]`) the dataset serves.'''
        self.phase = phase

    def __len__(self):
        # Every row already contains both its input and its shifted target
        # (see __getitem__), so all rows are usable; the former "- 1" silently
        # dropped the last sequence of each split.
        return len(self.dataset[self.phase])

    def __getitem__(self, idx):
        '''
        Return `(encoder_input, decoder_input, target)` for row `idx`.

        Encoder and decoder input are the same tensor: the row without its
        last token. The target is the row shifted left by one token.
        '''
        # returns both: inputs for encoder and decoder
        if torch.is_tensor(idx):
            idx = idx.item()

        row = self.dataset[self.phase][idx]
        item = row[:-1]
        target = row[1:]
        return item, item, target


In [None]:
def get_mask(batched_sequence, decoding=False, pad_idx=0):
    '''
    Build a boolean attention mask from a batch of token-id sequences.

    batched_sequence: integer tensor of shape (b_s, max_seq_len) with token ids
    decoding: if True, additionally flag every "future" position (key index
        greater than query index)
    pad_idx: token id treated as padding (default 0)

    Returns a bool tensor of shape (b_s, max_seq_len, max_seq_len), created on
    the same device as `batched_sequence`. Entry [b, i, j] is True when
    position i or j of sequence b is padding, or (decoding only) when j > i.
    Rows that belong to a padding position are reset to all-False at the end.
    NOTE(review): presumably True means "do not attend" and the reset avoids
    fully masked softmax rows -- confirm against model.transformer.
    '''
    b_s, max_seq_len = batched_sequence.shape
    not_pad = batched_sequence != pad_idx                      # (b_s, L)
    # True wherever the query (dim 1) or the key (dim 2) position is padding.
    mask = ~(not_pad.unsqueeze(2) & not_pad.unsqueeze(1))      # (b_s, L, L)
    if decoding:
        # Strictly-upper-triangular causal mask, built on the input's device
        # so it can be combined with data that lives on the GPU.
        positions = torch.arange(max_seq_len, device=batched_sequence.device)
        causal = positions.unsqueeze(0) > positions.unsqueeze(1)  # [i, j] = j > i
        mask = mask | causal
    # A column that is True for every row marks a padding position; clear the
    # row with the same index so it is not entirely masked.
    mask[mask.all(dim=1)] = False
    return mask

def my_collate(batch):
    '''
    Collate `(encoder_input, decoder_input, target)` samples into a batch.

    Each of the three fields is stacked row-wise across the batch. Returns
    `(data, target)` where `data` is a dict with the stacked encoder and
    decoder inputs plus the masks computed for them by `get_mask`.
    '''
    enc_batch, dec_batch, tgt_batch = (torch.vstack(column)
                                       for column in zip(*batch))
    inputs = {
        'input_encoder': enc_batch,
        'input_decoder': dec_batch,
        'encoder_mask': get_mask(enc_batch),
        # the decoder mask additionally hides future positions
        'decoder_mask': get_mask(dec_batch, decoding=True),
    }
    return inputs, tgt_batch

def get_dataloader(dataset, sampler=None, phase='train', batch_size=4):
    '''
    Wrap `dataset` in a DataLoader that collates with `my_collate`.

    For phase 'train' a RandomSampler is used unless `sampler` is given, and
    the last incomplete batch is dropped. For any other phase the given
    sampler (or none) is used and incomplete batches are kept.
    Note that `phase` only configures the loader; it does not change which
    split `dataset` itself serves.
    '''
    is_train = phase == 'train'
    if sampler is None and is_train:
        sampler = RandomSampler(dataset)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=my_collate,
        drop_last=is_train,
    )

## Train with PyTorch Lightning:

In [None]:
class TransformerLM(pl.LightningModule):
    '''
    LightningModule that trains the project `Transformer` as a language model
    on `WikiTextDataset`.

    model_config: dict passed unchanged to `Transformer`; this class itself
        reads the keys 'device' and 'output_vocab_size'.
    '''

    def __init__(self, model_config):
        super(TransformerLM, self).__init__()
        self.model_config = model_config
        self.model = Transformer(model_config)
        # Use the config that was passed in rather than the notebook-global
        # `model_cfg`, so the module does not depend on hidden global state.
        self.model = self.model.to(model_config['device'])
        self.dataset = WikiTextDataset()

    def prepare_data(self):
        '''Build vocabulary and split tensors once; later calls are no-ops.'''
        # `dataset.dataset` is the (initially empty) phase -> tensor dict.
        # Skipping when it is filled avoids re-tokenizing the corpus when this
        # hook is invoked more than once.
        if not self.dataset.dataset:
            self.dataset.prepare()

    def forward(self, **data):
        '''Forward the collated batch dict to the wrapped model as keywords.'''
        return self.model(**data)

    def _phase_dataset(self, phase):
        '''
        Return an independent view of the dataset fixed to `phase`.

        The shallow copy shares the underlying split tensors but carries its
        own `phase`, so building the validation/test loader can no longer
        switch the split served by the training loader.
        '''
        view = copy.copy(self.dataset)
        view.set_phase(phase)
        return view

    def training_step(self, batch, batch_idx):
        '''Compute the token-level cross-entropy loss for one training batch.'''
        data, target = batch
        output = self(**data)
        # Flatten to (tokens, vocab) vs. (tokens,) for cross_entropy.
        loss = F.cross_entropy(output.view(
            -1, self.model_config['output_vocab_size']), target.view(-1))
        # NOTE(review): the 'log' / 'progress_bar' result keys are the older
        # Lightning logging convention; newer releases use self.log(...) --
        # confirm against the installed version.
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        '''Return loss and token accuracy for one validation batch.'''
        data, target = batch
        output = self(**data)
        loss = F.cross_entropy(output.view(
            -1, self.model_config['output_vocab_size']), target.view(-1))
        # Most likely token per position.
        _, preds = torch.max(output, -1)
        val_acc = accuracy_score(preds.cpu().view(-1), target.cpu().view(-1))
        val_acc = torch.tensor(val_acc)

        return {'val_loss': loss, 'val_acc': val_acc}

    def validation_epoch_end(self, outputs):
        '''Average the per-batch validation loss and accuracy.'''
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': avg_loss, 'avg_val_acc': avg_val_acc}
        return {'val_loss': avg_loss, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        '''Adam over all trainable parameters.'''
        return torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    def train_dataloader(self):
        return get_dataloader(self._phase_dataset('train'))

    def val_dataloader(self):
        return get_dataloader(self._phase_dataset('val'), phase='val')

    def test_dataloader(self):
        return get_dataloader(self._phase_dataset('test'), phase='test')

In [None]:
# Fix all RNG seeds so runs are repeatable.
seed_everything(11)

logger = pl.loggers.TensorBoardLogger("tb_logs", name="transformer_ml")

transformer_lm = TransformerLM(model_cfg)
transformer_lm.prepare_data()

# Request a GPU only when one is actually available; model_cfg['device'] and
# the dataset already fall back to the CPU, so the Trainer should not fail on
# a CPU-only runtime either.
trainer = pl.Trainer(max_epochs=10,
                     gpus=1 if torch.cuda.is_available() else None,
                     logger=logger)
trainer.fit(transformer_lm)

In [None]:
from shutil import rmtree
import os

# to clean logs dir while debugging
# NOTE(review): the Trainer in this notebook logs to "tb_logs/transformer_ml",
# not to this default Lightning directory -- confirm which one should be wiped.
LOGS_DIR = '/content/lightning_logs'

# The directory only exists after a run with the default logger; skip quietly
# otherwise instead of failing in os.listdir.
if os.path.isdir(LOGS_DIR):
    for version in os.listdir(LOGS_DIR):
        version_path = os.path.join(LOGS_DIR, version)
        # rmtree only handles directories; leave stray files untouched.
        if os.path.isdir(version_path):
            rmtree(version_path)

In [None]:
# Load the TensorBoard notebook extension and open it on the directory the
# TensorBoardLogger("tb_logs", name="transformer_ml") writes to.
%load_ext tensorboard
%tensorboard --logdir tb_logs/transformer_ml/