In [1]:
from torch.nn import Transformer
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torchtext.utils import download_from_url, extract_archive
import io
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from einops import rearrange
import math
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import time

In [2]:
url_base = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/'
train_urls = ('train.de', 'train.en')
val_urls = ('newstest2015.de', 'newstest2015.en')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

In [6]:
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.de.gz', 'train.en.gz')
val_urls = ('val.de.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

In [7]:
train_filepaths = [download_from_url(url_base + url) for url in train_urls]
val_filepaths = [download_from_url(url_base + url) for url in val_urls]
test_filepaths = [download_from_url(url_base + url) for url in test_urls]


In [8]:
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

In [84]:
train_filepaths

['.data/train.de', '.data/train.en']

In [9]:
train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

In [10]:
def build_vocab(filepath,tokenizer):
    counter = Counter()
    with io.open(filepath,encoding="utf8") as f:
        for string_ in f:
            counter.update(tokenizer(string_))
    return Vocab(counter,specials=['<unk>','<pad>','<sos>','<eos>'])
de_vocab = build_vocab(train_filepaths[0],de_tokenizer)
en_vocab = build_vocab(train_filepaths[1],en_tokenizer)

In [11]:
len(de_vocab)

19215

In [12]:
def build_vocab(filepath,tokenizer):
    counter = Counter()
    with io.open(filepath,encoding="utf8") as f:
        for string_ in f:
            counter.update(tokenizer(string_))
    return Vocab(counter,specials=['<unk>','<pad>','<sos>','<eos>'])
de_vocab = build_vocab(train_filepaths[0],de_tokenizer)
en_vocab = build_vocab(train_filepaths[1],en_tokenizer)


In [13]:
def data_process(filepaths):
    raw_de_itr = iter(io.open(filepaths[0],encoding="utf8"))
    raw_en_itr = iter(io.open(filepaths[1],encoding="utf8"))
    data =[]
    for(raw_de,raw_en) in zip(raw_de_itr,raw_en_itr):
        de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)],dtype=torch.long)
        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
                            dtype=torch.long)
        data.append((de_tensor_, en_tensor_))
    return data

In [14]:
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

In [90]:
len(train_data)

29000

In [15]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
BATCH_SIZE = 128
PAD_IDX = de_vocab['<pad>']
BOS_IDX = de_vocab['<sos>']
EOS_IDX = de_vocab['<eos>']

In [17]:
def generate_batch(data_batch):
    de_batch, en_batch = [], []
    src_pad_masks,tgt_pad_masks = [],[]
    for (de_item, en_item) in data_batch:
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(en_item)
    de_batch = pad_sequence(de_batch, batch_first=True,padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, batch_first=True,padding_value=PAD_IDX)
    for en_item in en_batch:
        curr_mask = en_item == PAD_IDX
        src_pad_masks.append(curr_mask)
    src_pad_masks = torch.stack(src_pad_masks)
    for de_item in de_batch:
        curr_mask = torch.logical_or(de_item == PAD_IDX,de_item == EOS_IDX)[:-1]
        tgt_pad_masks.append(curr_mask)
    tgt_pad_masks = torch.stack(tgt_pad_masks)
    return en_batch,de_batch,src_pad_masks,tgt_pad_masks

In [18]:
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch,num_workers=4)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch,num_workers=4)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=True, collate_fn=generate_batch,num_workers=4)

In [19]:
def gen_nopeek_mask(length):
    mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h')
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

    return mask  

In [20]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [21]:
class TransformerTranlator(nn.Module):
    def __init__(self,in_token,out_token,ninp,nhead,nhid,nlayers,dropout=0.5):
        super(TransformerTranlator, self).__init__()
        self.ninp = ninp
        self.encoder_embedding = nn.Embedding(in_token,ninp)
        self.decoder_embedding = nn.Embedding(out_token,ninp)
        self.pos_encoder = PositionalEncoding(ninp,dropout)
        self.transformer = nn.Transformer(d_model=ninp,nhead=nhead,num_encoder_layers=nlayers,num_decoder_layers=nlayers,dim_feedforward=nhid,dropout=dropout)
        self.fc = nn.Linear(ninp,out_token)
        
    def generate(self,out):
        output = rearrange(out,'t n e -> n t e')
        output = self.fc(output)
        return output

    def forward(self,src,src_pad_mask,tgt,tgt_mask,tgt_pad_mask,mem_mask):
        src = self.encoder_embedding(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src = rearrange(src,'n s e-> s n e')
        
        tgt = self.decoder_embedding(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        tgt = rearrange(tgt, 'n t e-> t n e')
        
        out = self.transformer(src=src,src_key_padding_mask=src_pad_mask,tgt=tgt,tgt_mask=tgt_mask,tgt_key_padding_mask=tgt_pad_mask,memory_key_padding_mask=mem_mask)
        return self.generate(out)

In [22]:
#Hyperparameters
in_token = len(en_vocab)
out_token = len(de_vocab)

emsize = 768
nhid = 256
nlayers = 6
nhead = 6
dropout = 0.2
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [23]:
model = TransformerTranlator(in_token,out_token,emsize,nhead,nhid,nlayers,dropout).to(device)
optim = torch.optim.SGD(model.parameters(),lr=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [24]:
len(valid_iter)

8

In [25]:
def evaluate(model,epoch):
    model.eval()
    total_loss = 0
    start_time = time.time()
    with torch.no_grad():
        for i,(en_batch,de_batch,src_pad_masks,tgt_pad_masks) in enumerate(valid_iter):
            en_batch = en_batch.to(device)
            de_batch = de_batch.to(device)
            src_pad_masks = src_pad_masks.to(device)
            tgt_pad_masks = tgt_pad_masks.to(device)
            optim.zero_grad()
            tgt_in = de_batch[:,:-1]
            tgt_out = de_batch[:,1:]

            tgt_mask = gen_nopeek_mask(tgt_in.shape[1]).to(device)
            mem_pad_mask = src_pad_masks.clone()

            output = model(en_batch,src_pad_masks,tgt_in,tgt_mask,tgt_pad_masks,mem_pad_mask)
            loss = criterion(rearrange(output, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))
            total_loss += loss.item()
        elapsed = time.time() - start_time

        curr_loss = total_loss / len(valid_iter)
        print("Eval Epoch: {:d} Loss: {:.2f} | Batches/sec: {:.2f}".format(epoch,curr_loss,len(valid_iter) / elapsed))
        writer.add_scalar('Evaluation Loss',curr_loss,epoch)
        writer.add_scalar('Evaluation Speed',len(valid_iter)/elapsed,epoch)

In [26]:
len(train_iter)

227

In [27]:
def train(model,writer,epochs):
    log_interval = 100
    total_loss = 0
    start_time = time.time()
    num_batch = len(train_iter)
    for epoch in range(epochs):
        print("Epoch: ",epoch)
        model.train()
        for i,(en_batch,de_batch,src_pad_masks,tgt_pad_masks) in enumerate(train_iter):
            en_batch = en_batch.to(device)
            de_batch = de_batch.to(device)
            src_pad_masks = src_pad_masks.to(device)
            tgt_pad_masks = tgt_pad_masks.to(device)
            optim.zero_grad()
            tgt_in = de_batch[:,:-1]
            tgt_out = de_batch[:,1:]

            tgt_mask = gen_nopeek_mask(tgt_in.shape[1]).to(device)
            mem_pad_mask = src_pad_masks.clone()

            output = model(en_batch,src_pad_masks,tgt_in,tgt_mask,tgt_pad_masks,mem_pad_mask)
            loss = criterion(rearrange(output, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))
            total_loss += loss.detach().item()
            if(i % log_interval == 0 and i != 0):
                elapsed = time.time() - start_time
                total_batches = epoch*len(train_iter) + i
                curr_loss = total_loss / log_interval
                print("Training Loss: {:.2f} | Batches/sec: {:.2f} | Total batches: {:d}".format(curr_loss,log_interval / elapsed,total_batches))
                
                writer.add_scalar('Training Loss',curr_loss,total_batches)
                writer.add_scalar('Training Speed',log_interval/elapsed,total_batches)
                total_loss = 0
                start_time = time.time()
            loss.backward()
            optim.step()
        evaluate(model,epoch)

In [28]:
currTime = datetime.now().strftime("%d%m%Y%H%M%S")
writer = SummaryWriter('runs/'+ currTime)
train(model,writer,40)

Epoch:  0
Training Loss: 4.66 | Batches/sec: 1.94 | Total batches: 100
Training Loss: 3.47 | Batches/sec: 1.79 | Total batches: 200
Eval Epoch: 0 Loss: 3.05 | Batches/sec: 4.62
Epoch:  1
Training Loss: 4.00 | Batches/sec: 1.40 | Total batches: 327
Training Loss: 3.01 | Batches/sec: 1.83 | Total batches: 427
Eval Epoch: 1 Loss: 2.68 | Batches/sec: 4.49
Epoch:  2
Training Loss: 3.66 | Batches/sec: 1.39 | Total batches: 554
Training Loss: 2.82 | Batches/sec: 1.85 | Total batches: 654
Eval Epoch: 2 Loss: 2.56 | Batches/sec: 4.51
Epoch:  3
Training Loss: 3.51 | Batches/sec: 1.40 | Total batches: 781
Training Loss: 2.69 | Batches/sec: 1.82 | Total batches: 881
Eval Epoch: 3 Loss: 2.50 | Batches/sec: 4.51
Epoch:  4
Training Loss: 3.38 | Batches/sec: 1.39 | Total batches: 1008
Training Loss: 2.64 | Batches/sec: 1.81 | Total batches: 1108
Eval Epoch: 4 Loss: 2.52 | Batches/sec: 4.60
Epoch:  5
Training Loss: 3.39 | Batches/sec: 1.41 | Total batches: 1235
Training Loss: 2.59 | Batches/sec: 1.82 |