In [1]:
import torch
from torch import nn
import torchtext
import numpy as np
from typing import Optional, List, Dict, Tuple, Iterable
from tqdm.auto import tqdm

from torch.nn.utils.rnn import pad_sequence

In [2]:
# Multi30k train/test splits. Each instance is indexed below as
# instance[0] (source sentence) and instance[1] (target sentence).
train_iter = torchtext.datasets.Multi30k(split="train")
test_iter = torchtext.datasets.Multi30k(split="test")

In [3]:
# spaCy-backed tokenizers: the source side uses the German model,
# the target side uses the English model.
src_tokenizer = torchtext.data.utils.get_tokenizer("spacy", language="de_core_news_sm")
tgt_tokenizer = torchtext.data.utils.get_tokenizer("spacy", language="en_core_web_sm")

In [4]:
# Sanity check: the target tokenizer returns a list of string tokens.
tgt_tokenizer("Hi there")

['Hi', 'there']

In [5]:
# Reserved vocabulary ids. The vocabularies are built with special_first=True,
# so the specials receive ids 0..3 in list order; the list below must therefore
# be ordered so that specials[PAD_IDX] == '<pad>', specials[UNK_IDX] == '<unk>',
# and so on. (Previously '<unk>' and '<pad>' were swapped relative to the
# constants, so padding used the '<unk>' id and unknown words decoded as '<pad>'.)
PAD_IDX, UNK_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
specials = ['<pad>', '<unk>', '<sos>', '<eos>']

def build_vocab(data_iter, tokenizer, lang_idx):
    """Lazily yield one token list per instance of ``data_iter``.

    Args:
        data_iter: iterable of indexable instances (e.g. sentence pairs).
        tokenizer: callable mapping a sentence to a list of tokens.
        lang_idx: position within each instance of the sentence to tokenize.

    Yields:
        ``tokenizer(instance[lang_idx])`` for every instance, in order.
    """
    yield from (tokenizer(pair[lang_idx]) for pair in data_iter)


# Both vocabularies are built from the training split with identical settings:
# every token is kept (min_freq=1) and the special tokens are placed first.
vocab_options = dict(min_freq=1, specials=specials, special_first=True)

src_vocab = torchtext.vocab.build_vocab_from_iterator(
    build_vocab(train_iter, src_tokenizer, 0), **vocab_options
)
tgt_vocab = torchtext.vocab.build_vocab_from_iterator(
    build_vocab(train_iter, tgt_tokenizer, 1), **vocab_options
)

# Out-of-vocabulary tokens are looked up as UNK_IDX instead of raising.
for vocab in (src_vocab, tgt_vocab):
    vocab.set_default_index(UNK_IDX)



In [6]:
# Worked example of the encoding pipeline on one English sentence:
# tokenize -> map tokens to vocabulary ids -> wrap with SOS_IDX / EOS_IDX.
a = "Hey hi, how are you?"
a_tok = tgt_tokenizer(a)
a_voc = tgt_vocab(a_tok)
torch.cat((torch.tensor([SOS_IDX]), torch.tensor(a_voc), torch.tensor([EOS_IDX])))

tensor([   2,    1, 8920,   15,  889,   17, 1328, 2470,    3])

In [7]:
def transform(sentence: str, lang_idx: int) -> torch.Tensor:
    """Encode a sentence as a 1-D tensor of ids: SOS_IDX, token ids, EOS_IDX.

    Args:
        sentence: raw text to encode.
        lang_idx: 0 selects the source tokenizer/vocabulary; any other value
            selects the target tokenizer/vocabulary.

    Returns:
        A ``torch.long`` tensor of length ``len(tokens) + 2``.
    """
    if lang_idx == 0:
        tokenizer, vocab = src_tokenizer, src_vocab
    else:
        tokenizer, vocab = tgt_tokenizer, tgt_vocab

    token_ids = vocab(tokenizer(sentence))
    # Build the tensor in one go with an explicit dtype so that a sentence with
    # no tokens still yields an integer tensor (an empty list would otherwise
    # default to a float tensor).
    return torch.tensor([SOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

In [8]:
# Same sentence as above, encoded through transform() with the target-side
# tokenizer and vocabulary (lang_idx=1).
transform(a, 1)

tensor([   2,    1, 8920,   15,  889,   17, 1328, 2470,    3])

In [9]:
# pad_sequence demo: right-pads the shorter sequence with PAD_IDX so both rows
# have equal length. A dedicated name is used so the example sentence `a`
# defined earlier is not overwritten and the earlier cells remain re-runnable.
ragged_example = [
    torch.tensor([0, 1, 2]),
    torch.tensor([0, 1]),
]
pad_sequence(ragged_example, padding_value=PAD_IDX, batch_first=True)

tensor([[0, 1, 2],
        [0, 1, 0]])

In [10]:
def collate_fn(batch):
    """Turn a list of (source, target) sentence pairs into two padded id tensors.

    Trailing newlines are stripped from each sentence before it is encoded
    with ``transform``; the encoded sequences are then right-padded with
    ``PAD_IDX`` to the longest sequence in the batch (batch dimension first).

    Returns:
        ``(src_batch, tgt_batch)`` as padded tensors.
    """
    encoded = [
        (transform(src.rstrip("\n"), 0), transform(tgt.rstrip("\n"), 1))
        for src, tgt in batch
    ]
    src_padded = pad_sequence([src_ids for src_ids, _ in encoded],
                              batch_first=True, padding_value=PAD_IDX)
    tgt_padded = pad_sequence([tgt_ids for _, tgt_ids in encoded],
                              batch_first=True, padding_value=PAD_IDX)
    return src_padded, tgt_padded

In [11]:
# Batches of 32 sentence pairs; collate_fn encodes and pads each batch.
train_dataloader = torch.utils.data.DataLoader(train_iter, 
                                               batch_size=32,
                                               collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_iter,
                                              batch_size=32,
                                              collate_fn=collate_fn)

In [12]:
# Shape of the padded source tensor of the first training batch.
next(iter(train_dataloader))[0].shape

torch.Size([32, 21])

In [13]:
# Scan the whole training set once to find the widest padded batch on each
# side; these maxima are used below to size the embeddings' max_seq_len.
max_seq_len_src, max_seq_len_tgt = 0, 0

for src_batch, tgt_batch in train_dataloader:
    # The last dimension of a padded batch is its (padded) sequence length.
    max_seq_len_src = max(max_seq_len_src, src_batch.shape[-1])
    max_seq_len_tgt = max(max_seq_len_tgt, tgt_batch.shape[-1])

max_seq_len_src, max_seq_len_tgt

(46, 43)

In [14]:
# NOTE(review): the module is imported as `tranformer` (sic) — confirm this
# matches the actual file name and is not a typo for `transformer`.
from tranformer import Embedding, Transformer

# Separate embedding modules for source and target ids. max_seq_len adds 10
# positions on top of the longest training batch — presumably headroom for
# longer test-time sentences; TODO confirm.
src_embedding = Embedding(d_model=512,
                          vocab_size=len(src_vocab),
                          max_seq_len=max_seq_len_src+10)
tgt_embedding = Embedding(d_model=512,
                          vocab_size=len(tgt_vocab),
                          max_seq_len=max_seq_len_tgt+10)
# The model's output size is the target vocabulary size.
transformer = Transformer(output_size=len(tgt_vocab))

In [15]:
from torchinfo import summary
# Summarize the model for two inputs: embedded source (32, max_seq_len_src, 512)
# and embedded target (32, max_seq_len_tgt, 512).
summary(transformer,
        input_size=[(32, max_seq_len_src, 512), (32, max_seq_len_tgt, 512)])

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                                  Output Shape              Param #
Transformer                                             [32, 43, 10837]           --
├─TransformerEncoder: 1-1                               [32, 46, 512]             --
│    └─ModuleList: 2-1                                  --                        --
│    │    └─TransformerEncoderLayer: 3-1                [32, 46, 512]             3,152,384
│    │    └─TransformerEncoderLayer: 3-2                [32, 46, 512]             3,152,384
│    │    └─TransformerEncoderLayer: 3-3                [32, 46, 512]             3,152,384
│    │    └─TransformerEncoderLayer: 3-4                [32, 46, 512]             3,152,384
│    │    └─TransformerEncoderLayer: 3-5                [32, 46, 512]             3,152,384
│    │    └─TransformerEncoderLayer: 3-6                [32, 46, 512]             3,152,384
├─TransformerDecoder: 1-2                               [32, 43, 512]             --
│    └─ModuleList:

In [16]:
def create_src_mask(src):
    """Return a boolean mask that is True wherever ``src`` is not PAD_IDX.

    Two singleton dimensions are inserted at positions 1 and 2, so a
    ``(batch, seq)`` input yields a ``(batch, 1, 1, seq)`` mask.
    """
    not_padding = src.ne(PAD_IDX)
    return not_padding[:, None, None]

def create_trg_mask(trg):
    """Build the lower-triangular (causal) mask for the decoder input.

    The mask covers ``trg.size(-1) - 1`` positions, one fewer than the length
    of ``trg``, matching a decoder input that drops the final token.

    Args:
        trg: batch of target ids; only its batch size, last-dimension length
            and device are used.

    Returns:
        A float tensor of shape ``(batch, 1, size, size)`` holding ones on and
        below the diagonal and zeros above it, allocated on ``trg``'s device.
    """
    size = trg.size(-1) - 1
    # Padding positions are intentionally not masked here (the original
    # padding mask was left disabled):
    # trg_padding_mask = (trg!=PAD_IDX).unsqueeze(1).unsqueeze(2)
    # Allocate on trg's device so the mask can be combined with tensors that
    # live on the same device as the batch.
    causal = torch.tril(torch.ones(size, size, device=trg.device))
    trg_mask = causal.expand(trg.size(0), 1, size, size)
    return trg_mask

def create_mask(src, trg):
    """Return the ``(source mask, target mask)`` pair for one batch."""
    src_mask = create_src_mask(src)
    trg_mask = create_trg_mask(trg)
    return src_mask, trg_mask

In [17]:
def train(model: Transformer,
          dataloader: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer):
    """Run one training epoch and return the mean per-batch loss.

    The decoder is fed the target without its last token and the loss is
    computed against the target without its first token. Every fifth batch,
    the arg-max decoding of the first sequence in the batch is printed.

    Relies on notebook globals: ``create_mask``, ``src_embedding``,
    ``tgt_embedding`` and ``tgt_vocab``.

    Args:
        model: the network to train; called as
            ``model(src_embedded, tgt_embedded[:, :-1], src_mask, tgt_mask)``.
        dataloader: yields ``(src_batch, tgt_batch)`` padded id tensors.
        loss_fn: criterion applied to flattened logits and target ids.
        optimizer: optimizer stepped once per batch.

    Returns:
        Sum of batch losses divided by the number of batches processed
        (0.0 if the dataloader yields no batches).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    # Fetch the id -> token table once instead of once per printed token.
    itos = tgt_vocab.get_itos()

    for batch, (src_batch, tgt_batch) in enumerate(tqdm(dataloader)):

        src_mask, tgt_mask = create_mask(src_batch, tgt_batch)
        src_embedded = src_embedding(src_batch)
        tgt_embedded = tgt_embedding(tgt_batch)

        # Use the `model` argument (not the global `transformer`) so the
        # function trains whatever model the caller passes in.
        logits = model(src_embedded, tgt_embedded[:, :-1], src_mask, tgt_mask)
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_batch[:, 1:].reshape(-1))
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch%5==0:
            # reshape(-1)/tolist() works for single-token outputs and for
            # tensors that are not on the CPU.
            predicted_ids = torch.argmax(logits[0], dim=-1).reshape(-1).tolist()
            sample_translation = " ".join(itos[i] for i in predicted_ids)
            print(f"Batch: {batch} sample translation: {sample_translation}")

    # Count batches while iterating rather than re-running the whole
    # dataloader (tokenization included) just to measure its length.
    return total_loss / max(num_batches, 1)

In [18]:
# Padding positions in the target do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# The embeddings are separate objects from `transformer`, so their weights must
# be handed to the optimizer explicitly or they stay at their initial values.
# NOTE(review): assumes Embedding is an nn.Module exposing .parameters() — confirm.
trainable_params = [
    *transformer.parameters(),
    *src_embedding.parameters(),
    *tgt_embedding.parameters(),
]
optimizer = torch.optim.Adam(trainable_params, lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

losses = train(transformer, train_dataloader, loss_fn, optimizer)

0it [00:00, ?it/s]



Batch: 0 sample translation: catamaran Warrior touch played Golden devil referees Olympics Pipe attentively pants contains slopes otuside otuside otuside haired operates otuside otuside barricade otuside ropes
Batch: 5 sample translation: A a . a a a . a a a <eos> a a a a a a a a a a a a a a
Batch: 10 sample translation: A . a a . a a . . <eos> a . a a . . . a a . .
Batch: 15 sample translation: A a . . a . . a . . . . . . . . . . . . . . . .
Batch: 20 sample translation: A man a a a a a a a a <eos> a a a a a . a a a a a
Batch: 25 sample translation: A . . . . . . . . . . . . <eos> . . . . . . .
Batch: 30 sample translation: A man . . a . . a . . . . . . . . . . . . . . .
Batch: 35 sample translation: A . . a a red . . <eos> . . . . . . . . . . . . . . .
Batch: 40 sample translation: A . . . . . . . <eos> . . . . . . . . . . . . . .
Batch: 45 sample translation: A . in . . in in a . . a . . <eos> . . . . . . . . . . . .
Batch: 50 sample translation: A . in . a a . <eos> . . . . . . . .