### Purpose
This notebook reimplements the PyTorch transformer language-modeling tutorial (https://pytorch.org/tutorials/beginner/transformer_tutorial.html) using PyTorch Lightning.

In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


import numpy as np
import pandas as pd
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer


import torchtext
from torchtext.data.utils import get_tokenizer
import math

from argparse import ArgumentParser

In [57]:
# refer to https://github.com/achinta/machine-learning/blob/master/notebooks/nlp/torchtext-tutorial.ipynb
class LMDataset(Dataset):
    '''Sliding-window language-modelling dataset over a column-split token stream.

    The token stream is trimmed and reshaped into ``bsz`` parallel columns of
    ``nrows`` rows each. Item ``i`` is the window of ``bptt`` consecutive rows
    starting at row ``i`` (shape [bptt, bsz]), paired with the same window
    shifted one row ahead and flattened (shape [bptt * bsz]) as next-token
    targets. Consecutive items overlap (stride 1).
    '''

    def __init__(self, data, bptt, bsz):
        '''
        Args:
            data: tensor of shape [k, 1], where k is the number of tokens in the text.
            bptt: number of rows (time steps) per window.
            bsz: number of parallel columns the stream is split into.
        '''
        self.bptt = bptt
        self.bsz = bsz

        # Number of rows each of the bsz columns gets.
        nrows = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(dim=0, start=0, length=nrows * bsz)

        # Column j holds the j-th contiguous chunk of the stream -> [nrows, bsz].
        # An explicit nrows (instead of -1) keeps this valid when the stream is
        # shorter than bsz and the trimmed tensor is empty.
        self.data = data.view(bsz, nrows).t().contiguous()

    def __getitem__(self, i):
        '''Return (inputs [bptt, bsz], targets [bptt * bsz]) for window ``i``.

        Negative indices count from the end. Out-of-range indices raise
        IndexError; tensor slicing would otherwise return short/empty windows
        and plain iteration (``for x, y in ds``) would never stop.
        '''
        n = len(self)
        idx = i + n if i < 0 else i
        if not 0 <= idx < n:
            raise IndexError(f'index {i} out of range for LMDataset of length {n}')
        data = self.data[idx:idx + self.bptt]
        target = self.data[idx + 1:idx + 1 + self.bptt].view(-1)
        return data, target

    def __len__(self):
        # A window needs bptt input rows plus one extra row for the shifted
        # targets; clamp at 0 when the stream is too short for any window.
        return max(0, self.data.size(0) - self.bptt)
    
class PositionalEncoding(nn.Module):
    '''Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size (last dimension of the input). May be odd.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table covers.
    '''

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only use as many frequencies as there are cosine slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as [max_len, 1, d_model] so it broadcasts over the batch
        # dimension of a [seq_len, batch, d_model] input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        '''Return dropout(x + pe[:seq_len]) for x of shape [seq_len, batch, d_model].

        seq_len must not exceed max_len.
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [107]:
class LMModel(LightningModule):
    '''Transformer-encoder language model on WikiText-2, written as a LightningModule.

    Pipeline: token embedding (scaled by sqrt(ninp)) -> PositionalEncoding ->
    causally masked nn.TransformerEncoder -> linear decoder over the vocabulary.

    The embedding and decoder depend on the vocabulary size, so they are
    created in ``prepare_data`` after the vocabulary is built, not in __init__.

    Args:
        hparams: argparse.Namespace with the options declared in
            ``add_model_specific_args`` (bsz, bptt, ninp, nhead, nhid,
            nlayers, dropout, and optionally lr).
    '''

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): newer PyTorch Lightning releases may not allow assigning
        # `hparams`; there, use self.save_hyperparameters(hparams) — verify.
        self.hparams = hparams
        # Cached causal mask; rebuilt in forward when seq length or device changes.
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(hparams.ninp, hparams.dropout)
        encoder_layers = nn.TransformerEncoderLayer(hparams.ninp, hparams.nhead, hparams.nhid, hparams.dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, hparams.nlayers)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self):
        '''Uniformly initialise embedding and decoder weights in [-0.1, 0.1]; zero the decoder bias.'''
        initrange = 0.1
        self.src_embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    @staticmethod
    def add_model_specific_args(parent_parser):
        '''Return a child ArgumentParser of ``parent_parser`` declaring the model hyperparameters.'''
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('-bsz', default=20, type=int, help='batch_size', )
        parser.add_argument('-bptt', default=35, type=int, help='sentence length')
        parser.add_argument('-ninp', default=256, type=int, help='expected features in the input')
        parser.add_argument('-nhead', default=4, type=int, help='number of attention heads')
        parser.add_argument('-nhid', default=1024, type=int, help='dimension of feed-forward network model')
        parser.add_argument('-nlayers', default=3, type=int, help='number of encoder layers')
        parser.add_argument('-dropout', default=0.2, type=float, help='dropout')
        parser.add_argument('-lr', default=1e-3, type=float, help='learning rate for Adam')
        return parser

    # Backward-compatible alias for the original (misspelled) method name.
    add_model_specifi_args = add_model_specific_args

    def _generate_square_subsequent_mask(self, sz):
        '''Return an additive causal attention mask of shape (sz, sz).

        Entry [i, j] is 0.0 when position i may attend to j (j <= i) and -inf
        otherwise, so each token sees only itself and earlier tokens.
        '''
        # A bool mask marks *disallowed* positions with True. Putting True on the
        # lower triangle therefore hid the past and exposed the future, and fully
        # masked the last row (NaN attention). The float form is unambiguous.
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def prepare_data(self):
        '''Load WikiText-2, build the vocabulary and create the vocab-sized layers.'''
        # NOTE(review): Lightning may run prepare_data on a single process in
        # distributed training, so state assigned here might be missing on other
        # ranks; consider moving it to `setup` — verify for the installed version.
        self.field = torchtext.data.Field(tokenize=get_tokenizer('basic_english'),
                                    init_token='<sos>',
                                    eos_token='<eos>',
                                    lower=True)
        self.train_txt, self.val_txt, self.test_txt = torchtext.datasets.WikiText2.splits(self.field)
        self.field.build_vocab(self.train_txt)

        # Layers whose size depends on the vocabulary.
        self.ntoken = len(self.field.vocab)
        self.src_embedding = nn.Embedding(self.ntoken, self.hparams.ninp)
        self.decoder = nn.Linear(self.hparams.ninp, self.ntoken)
        self.init_weights()

    def train_dataloader(self):
        '''DataLoader yielding one (bptt, bsz) window per batch, in shuffled order.'''
        # Presumably the legacy WikiText2 split holds the whole corpus in
        # examples[0]; only that example is numericalized.
        train_data = self.field.numericalize([self.train_txt.examples[0].text])
        train_ds = LMDataset(train_data, self.hparams.bptt, self.hparams.bsz)
        return DataLoader(train_ds, shuffle=True)

    def configure_optimizers(self):
        '''Adam over all parameters; lr comes from hparams.lr (default 1e-3).'''
        # Without an optimizer Lightning computes the loss but never updates weights.
        lr = getattr(self.hparams, 'lr', 1e-3)
        return torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, src):
        '''Map token ids of shape (seq_len, batch) to logits of shape (seq_len, batch, ntoken).'''
        seq_len = src.size(0)
        if (self.src_mask is None or self.src_mask.size(0) != seq_len
                or self.src_mask.device != src.device):
            self.src_mask = self._generate_square_subsequent_mask(seq_len).to(src.device)

        src = self.src_embedding(src) * math.sqrt(self.hparams.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

    def training_step(self, batch, batch_idx):
        '''Cross-entropy next-token loss for one window.'''
        x, y = batch
        # The DataLoader adds a leading batch dim of 1 around each (bptt, bsz) window.
        x = x.squeeze(0)
        y = y.squeeze(0)
        yhat = self(x)
        # yhat is (bptt, bsz, ntoken); y was flattened in the same row-major order.
        loss = self.criterion(yhat.reshape(-1, self.ntoken), y)
        return {'loss': loss}


In [108]:
# Weight init, dropout and DataLoader shuffling all draw from torch's global RNG;
# seed it so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Use the model's default hyperparameters (no command-line args in a notebook).
parser = ArgumentParser()
parser = LMModel.add_model_specifi_args(parser)
hparams = parser.parse_args([])
lm = LMModel(hparams)

# Trainer() uses Lightning's default settings (e.g. epoch count).
trainer = Trainer()
trainer.fit(lm)

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …




1

### Playground

In [54]:
field = torchtext.data.Field(tokenize=get_tokenizer('basic_english'),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(field)

field.build_vocab(train_txt)

# Numericalize the training text once and reuse it; result shape is [num_tokens, 1].
train_data = field.numericalize([train_txt.examples[0].text])
data = train_data  # same tensor under the name used for earlier exploration
print(train_data.shape)

bptt = 3
bsz = 4
train_ds = LMDataset(train_data, bptt, bsz)
len(train_ds)

torch.Size([2086708, 1])


521674

In [62]:
len(lm.field.vocab)  # vocabulary size; the same value prepare_data stores as lm.ntoken

28785