In [1]:
import pytorch_lightning as pl 


In [3]:
import csv
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import math
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torch import Tensor
import io
import time
from torchtext.utils import unicode_csv_reader
# Seed PyTorch's global random number generator with a fixed value so that
# torch-driven randomness in later cells is repeatable between runs.
torch.manual_seed(0)

<torch._C.Generator at 0x7fa9fc208f90>

In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to token embeddings, then dropout.

    The table is precomputed once for ``maxlen`` positions and stored as a
    non-trainable buffer of shape ``(maxlen, 1, emb_size)``.

    Args:
        emb_size: embedding width. Odd widths are supported; the last sine
            column then has no cosine partner.
        dropout: dropout probability applied to the sum.
        maxlen: number of positions in the precomputed table.
    """

    def __init__(self, emb_size, dropout, maxlen):
        super(PositionalEncoding, self).__init__()
        # One frequency per sine/cosine pair: 10000^(-2i / emb_size).
        den = torch.exp(- torch.arange(0, emb_size, 2)
                        * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den  # (maxlen, ceil(emb_size / 2))

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer odd column than even columns,
        # so trim the angles to the number of cosine slots actually present.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dim 1 of the input.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        """Return ``dropout(token_embedding + encoding)``.

        Positions are taken along dim 0 of ``token_embedding`` (the table is
        sliced by ``token_embedding.size(0)`` and broadcast over dim 1), so the
        input must be laid out sequence-first: ``(seq_len, batch, emb_size)``
        with ``seq_len <= maxlen``.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])


class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens):
        """Look up ``tokens`` (cast to int64 first) and scale the vectors."""
        scale = math.sqrt(self.emb_size)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

In [5]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first tensors.

    Args:
        embed_size: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads

        self.head_dim = embed_size // heads

        assert(self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` to ``keys``/``values``.

        Args:
            values: ``(N, value_len, embed_size)``.
            keys: ``(N, key_len, embed_size)``; ``key_len`` must equal ``value_len``.
            queries: ``(N, query_len, embed_size)``.
            mask: ``None`` or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; entries equal to 0 are
                blocked from receiving attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = queries.shape[0]  # batch size

        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Project Q, K, V, then split the embedding axis into (heads, head_dim).
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(queries).reshape(N, query_len, self.heads, self.head_dim)

        # (N, query_len, heads, head_dim) x (N, key_len, heads, head_dim)
        #   -> (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Each dot product sums head_dim terms, so the scale is sqrt(head_dim)
        # (the per-head key dimension), not sqrt(embed_size).
        energy = energy / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the dtype's own finite minimum: a literal -1e20 does not fit
            # in half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim)
        #   -> (N, query_len, heads, head_dim) -> heads concatenated again
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)


class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer adds a residual connection, applies LayerNorm and then
    dropout (post-norm arrangement).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and feed-forward; the residual path follows ``query``."""
        attended = self.attention(value, key, query, mask)
        hidden = self.norm1(attended + query)
        hidden = self.dropout(hidden)

        projected = self.feed_forward(hidden)
        result = self.norm2(projected + hidden)
        return self.dropout(result)


class Encoder(nn.Module):
    """Source-side stack: token embedding + positional encoding + N blocks.

    Args:
        src_vocab_size: size of the source vocabulary.
        embed_size: model width.
        num_layers: number of ``TransformerBlock`` layers.
        heads: attention heads per block.
        forward_expansion: feed-forward hidden width as a multiple of ``embed_size``.
        dropout: dropout probability.
        max_length: number of positions available in the positional table;
            input sequences must not be longer than this.
    """

    def __init__(self, src_vocab_size, embed_size,
                 num_layers, heads, forward_expansion,
                 dropout, max_length):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.word_embedding = TokenEmbedding(src_vocab_size, embed_size)
        self.positional_embedding = PositionalEncoding(
            embed_size, dropout, max_length)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([TransformerBlock(embed_size, heads,
                                                      dropout=dropout, forward_expansion=forward_expansion) for _ in range(num_layers)])

    def forward(self, x, mask):
        """Encode a batch of source token ids.

        Args:
            x: token ids laid out batch-first, ``(N, src_len)``.
            mask: attention mask handed unchanged to every block.

        Returns:
            Tensor of shape ``(N, src_len, embed_size)``.
        """
        embedded = self.word_embedding(x)  # (N, src_len, embed_size)

        # The positional module indexes positions along dim 0. Feeding it the
        # batch-first tensor directly would give every token of sample i the
        # encoding for "position i", so present the data sequence-first and
        # restore the batch-first layout afterwards.
        embedded = self.positional_embedding(
            embedded.transpose(0, 1)).transpose(0, 1)

        # NOTE(review): the positional module already applies dropout, so this
        # is a second dropout pass; kept as in the original design.
        out = self.dropout(embedded)

        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out


class DecoderBlock(nn.Module):
    """Masked self-attention over the target, then a ``TransformerBlock``.

    The inner ``TransformerBlock`` receives the self-attended target as its
    query and the caller-supplied ``value``/``key`` tensors as memory.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Self-attend over ``x`` with ``trg_mask``, then attend to ``value``/``key`` with ``src_mask``."""
        self_attended = self.attention(x, x, x, trg_mask)
        normed = self.norm(self_attended + x)
        query = self.dropout(normed)
        return self.transformer_block(value, key, query, src_mask)


class Decoder(nn.Module):
    """Target-side stack: embedding + positional encoding + N decoder blocks + vocabulary projection.

    Args:
        trg_vocab_size: size of the target vocabulary (also the output width).
        embed_size: model width.
        num_layers: number of ``DecoderBlock`` layers.
        heads: attention heads per block.
        forward_expansion: feed-forward hidden width as a multiple of ``embed_size``.
        dropout: dropout probability.
        max_length: number of positions available in the positional table;
            input sequences must not be longer than this.
    """

    def __init__(self, trg_vocab_size, embed_size,  num_layers, heads, forward_expansion, dropout, max_length):
        super(Decoder, self).__init__()
        self.word_embedding = TokenEmbedding(trg_vocab_size, embed_size)
        self.positional_embedding = PositionalEncoding(
            embed_size, dropout, max_length)

        self.layers = nn.ModuleList([DecoderBlock(
            embed_size, heads, forward_expansion, dropout) for _ in range(num_layers)])

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target token ids against the encoder output.

        Args:
            x: target token ids laid out batch-first, ``(N, trg_len)``.
            enc_out: encoder output, used as both value and key in every block.
            src_mask: mask applied when attending to ``enc_out``.
            trg_mask: mask applied in the target self-attention.

        Returns:
            Unnormalised vocabulary scores, ``(N, trg_len, trg_vocab_size)``.
        """
        embedded = self.word_embedding(x)  # (N, trg_len, embed_size)

        # The positional module indexes positions along dim 0. Feeding it the
        # batch-first tensor directly would give every token of sample i the
        # encoding for "position i", so present the data sequence-first and
        # restore the batch-first layout afterwards.
        embedded = self.positional_embedding(
            embedded.transpose(0, 1)).transpose(0, 1)

        # NOTE(review): the positional module already applies dropout, so this
        # is a second dropout pass; kept as in the original design.
        x = self.dropout(embedded)

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

In [20]:
# --- Training hyperparameters ---
num_epochs = 10000
learning_rate = 3e-4
batch_size = 4

# --- Model hyperparameters ---
embedding_size = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6
# Original (misspelled) name kept as an alias so any cell that still refers to
# it keeps working.
num_deocder_layers = num_decoder_layers
dropout = 0.3
max_len = 128
forward_expansion = 4

# Both languages are tokenized with the same spaCy pipeline name; two separate
# tokenizer objects are kept, one per language (names suggest gom = source,
# hin = target; confirm against the dataset).
gom_tokenizer = get_tokenizer('spacy', language='xx_sent_ud_sm')
hin_tokenizer = get_tokenizer('spacy', language='xx_sent_ud_sm')


def build_vocab(filepath, tokenizer1, tokenizer2):
    """Build one vocabulary per column of a two-column CSV file.

    Args:
        filepath: path to a UTF-8 CSV file; column 0 holds the first
            language, column 1 the second.
        tokenizer1: callable turning a column-0 string into a token list.
        tokenizer2: callable turning a column-1 string into a token list.

    Returns:
        A ``(vocab1, vocab2)`` tuple of ``Vocab`` objects, each built from the
        token counts of its column with the specials
        ``<unk>, <pad>, <bos>, <eos>``.
    """
    specials = ['<unk>', '<pad>', '<bos>', '<eos>']
    counter1 = Counter()
    counter2 = Counter()
    # Read explicitly as UTF-8 (as data_process does) instead of relying on
    # the platform default encoding, which breaks on non-ASCII text.
    with open(filepath, encoding='utf8') as csv_file:
        for string_ in unicode_csv_reader(csv_file):
            counter1.update(tokenizer1(string_[0]))
            counter2.update(tokenizer2(string_[1]))
    return Vocab(counter1, specials=list(specials)), Vocab(counter2, specials=list(specials))


# Build the gom and hin vocabularies from the two columns of the training CSV;
# data_process looks tokens up in these module-level vocabularies.
gom_vocab, hin_vocab = build_vocab('train.csv', gom_tokenizer, hin_tokenizer)


def data_process(filepath):
    """Convert a two-column CSV of sentence pairs into token-id tensors.

    Relies on the module-level ``gom_vocab``/``hin_vocab`` and
    ``gom_tokenizer``/``hin_tokenizer``.

    Args:
        filepath: path to a UTF-8 CSV file whose rows have exactly two
            columns (gom sentence, hin sentence).

    Returns:
        A list of ``(gom_tensor, hin_tensor)`` pairs of 1-D ``torch.long``
        tensors, one pair per CSV row. No ``<bos>``/``<eos>`` markers are
        added here.
    """
    data = []
    # Context manager so the file handle is released once every row is read
    # (the handle was previously left open).
    with open(filepath, encoding='utf8') as csv_file:
        for (raw_gom, raw_hin) in unicode_csv_reader(csv_file):
            gom_tensor_ = torch.tensor([gom_vocab[token] for token in gom_tokenizer(raw_gom)],
                                       dtype=torch.long)
            hin_tensor_ = torch.tensor([hin_vocab[token] for token in hin_tokenizer(raw_hin)],
                                       dtype=torch.long)
            data.append((gom_tensor_, hin_tensor_))
    return data

# gom_vocab / hin_vocab are already built earlier in this cell, directly after
# build_vocab is defined; building them a second time here only re-read and
# re-tokenized train.csv, so the duplicate call is dropped.
train_data = data_process('train.csv')
test_data = data_process('test.csv')

src_vocab_size = len(gom_vocab)
trg_vocab_size = len(hin_vocab)
src_pad_idx = gom_vocab.stoi['<pad>']
trg_pad_idx = hin_vocab.stoi['<pad>']

# Special-token ids used by generate_batch for BOTH languages.
pad_idx = gom_vocab['<pad>']
bos_idx = gom_vocab['<bos>']
eos_idx = gom_vocab['<eos>']

# generate_batch applies the gom ids above to hin sequences as well, so the
# two vocabularies must agree on where the special tokens live.
assert (hin_vocab['<pad>'], hin_vocab['<bos>'], hin_vocab['<eos>']) == (pad_idx, bos_idx, eos_idx), \
    "gom and hin vocabularies disagree on special-token indices"

def generate_batch(data_batch):
    """Collate ``(gom, hin)`` tensor pairs into two padded batches.

    Every sequence is wrapped as ``[bos_idx] + tokens + [eos_idx]`` and each
    side is padded with ``pad_idx`` to its longest sequence. ``pad_sequence``
    is used with its default layout, so both results are sequence-first:
    ``(max_len, batch)``.
    """
    start = torch.tensor([bos_idx])
    end = torch.tensor([eos_idx])

    def _wrap(sequence):
        # Surround one sequence with the begin/end markers.
        return torch.cat([start, sequence, end], dim=0)

    wrapped_pairs = [(_wrap(gom_item), _wrap(hin_item))
                     for (gom_item, hin_item) in data_batch]
    gom_sequences = [pair[0] for pair in wrapped_pairs]
    hin_sequences = [pair[1] for pair in wrapped_pairs]

    return (pad_sequence(gom_sequences, padding_value=pad_idx),
            pad_sequence(hin_sequences, padding_value=pad_idx))



In [12]:
class Transformer(pl.LightningModule):
    """Encoder-decoder translation model packaged as a LightningModule.

    Batches arrive sequence-first, ``(seq_len, N)``, as produced by
    ``generate_batch``; ``forward`` transposes them to batch-first before they
    reach the encoder and decoder.

    Args:
        src_vocab_size: size of the source vocabulary.
        trg_vocab_size: size of the target vocabulary.
        src_pad_idx: padding id in the source vocabulary (used for the source mask).
        trg_pad_idx: padding id in the target vocabulary (ignored by the loss).
        embed_size: model width.
        num_layers: number of layers in both the encoder and the decoder.
        forward_expansion: feed-forward hidden width as a multiple of ``embed_size``.
        heads: attention heads per layer.
        dropout: dropout probability.
        max_length: number of positions available to the positional encoding.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size=512, num_layers=6, forward_expansion=4, heads=8, dropout=0.3, max_length=128):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, embed_size, num_layers,
                               heads, forward_expansion, dropout, max_length)

        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers,
                               heads, forward_expansion, dropout, max_length)

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        # The loss is computed over target tokens, so it is the *target*
        # padding id that must be excluded (previously the source id was used).
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.trg_pad_idx)

    def make_src_mask(self, src):
        """Return a ``(N, 1, 1, src_len)`` boolean mask that is False at source padding."""
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        """Return a ``(N, 1, trg_len, trg_len)`` lower-triangular (causal) mask."""
        N, trg_len = trg.shape
        # Allocate directly on the module's device instead of building the
        # mask on CPU and copying it over on every step.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=self.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Translate-step forward pass.

        Args:
            src: source ids, sequence-first ``(src_len, N)``.
            trg: target ids, sequence-first ``(trg_len, N)``.

        Returns:
            Vocabulary scores, batch-first ``(N, trg_len, trg_vocab_size)``.
        """
        src = src.transpose(1, 0)  # batch_first
        trg = trg.transpose(1, 0)  # batch_first

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

    def train_dataloader(self):
        """Shuffled loader over the module-level ``train_data``."""
        return DataLoader(train_data, shuffle=True, batch_size=batch_size, collate_fn=generate_batch)

    def val_dataloader(self):
        """Loader over the module-level ``test_data``.

        Not shuffled: ordering does not affect the validation loss, and a
        fixed order keeps validation runs comparable.
        """
        return DataLoader(test_data, shuffle=False, batch_size=batch_size, collate_fn=generate_batch)

    def configure_optimizers(self):
        """Adam plus a plateau scheduler that watches ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.1, patience=10, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

    def _shared_step(self, batch):
        """Teacher-forced cross-entropy loss for one ``(src, trg)`` batch."""
        src, trg = batch
        # Feed the target without its last token; predict the target shifted by one.
        output = self(src, trg[:-1, :])                 # (N, trg_len - 1, vocab)
        output = output.transpose(1, 0)                 # (trg_len - 1, N, vocab)
        output = output.reshape(-1, output.shape[2])    # rows line up with trg[1:] flattened
        expected = trg[1:].reshape(-1)
        return self.loss_fn(output, expected)

    def training_step(self, train_batch, batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        train_loss = self._shared_step(train_batch)
        self.log('train_loss', train_loss, prog_bar=True)
        return train_loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        val_loss = self._shared_step(val_batch)
        self.log('val_loss', val_loss, prog_bar=True)
        return val_loss
    

In [13]:
# Pass the hyperparameters from the configuration cell explicitly. Previously
# the model silently used Transformer's own defaults, so editing the
# configuration values had no effect on the network that was built.
# Transformer takes a single layer count for encoder and decoder alike;
# num_encoder_layers is used for both.
model = Transformer(
    src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx,
    embed_size=embedding_size,
    num_layers=num_encoder_layers,
    forward_expansion=forward_expansion,
    heads=num_heads,
    dropout=dropout,
    max_length=max_len,
)

In [14]:
# Xavier-uniform initialisation for every parameter with more than one
# dimension; 1-D parameters keep their default initialisation.
for weight in filter(lambda param: param.dim() > 1, model.parameters()):
    nn.init.xavier_uniform_(weight)

In [16]:
# `wandb_logger` is not created in any cell visible in this notebook, so a
# fresh-kernel "Run All" would stop here with a NameError. Reuse it when an
# earlier cell has defined it; otherwise fall back to Lightning's default
# logger (logger=True).
logger = globals().get('wandb_logger', True)

trainer = pl.Trainer(gpus=1, gradient_clip_val=1, max_epochs=num_epochs, progress_bar_refresh_rate=20, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Encoder          | 46.2 M
1 | decoder | Decoder          | 128 M 
2 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
174 M     Trainable params
0         Non-trainable params
174 M     Total params
698.334   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






1