In [1]:
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k

import json

from typing import Dict, Tuple, List

import sys
sys.path.insert(0, "..")
from model import TransformerForSeqToSeq

In [2]:
# Use Apple's Metal (MPS) backend when PyTorch reports it as available;
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Bare expression so the notebook displays the selected device.
device

'mps'

In [3]:
# Load the model/training hyperparameters from the JSON config file.
# The path is relative to the directory the notebook is run from.
with open("../config/base_config.json") as config_file:
    config = json.loads(config_file.read())

# Bare expression so the notebook displays the loaded settings.
config

{'batch_size': 64,
 'block_size': 64,
 'context_length': 512,
 'd_model': 384,
 'dropout': 0.1,
 'head_dim': 64,
 'learning_rate': 0.0003,
 'n_decoders': 6,
 'n_encoders': 6,
 'n_heads': 6,
 'train_iters': 10}

In [4]:
# Reserved vocabulary ids, listed in index order.
# These values have to match the position of each symbol in the `specials`
# list handed to `build_vocab_from_iterator` (with `special_first=True`):
# "[PAD]", "[SOS]", "[EOS]", "[UNK]".
PAD_TOKEN = 0  # padding filler for batching
SOS_TOKEN = 1  # prepended to every sequence
EOS_TOKEN = 2  # appended to every sequence
UNK_TOKEN = 3  # default index for out-of-vocabulary tokens

In [5]:
# Multi30k training and validation splits (default language pair of the
# dataset constructor); each item is indexed as a (source, target) pair below.
train_iter, val_iter = (Multi30k(split=split_name) for split_name in ("train", "valid"))

In [6]:
# spaCy-backed tokenizers. The source side uses the `de_core_news_sm` pipeline
# and the target side uses `en_core_web_sm`; both spaCy models must be
# installed in the environment.
src_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")
tgt_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

In [7]:
def build_vocab(tokenizer, idx, data_iter):
    """Lazily yield the tokenised field ``idx`` of every item in ``data_iter``.

    Despite the name, this does not build a vocabulary itself: it is the token
    stream that gets fed to ``build_vocab_from_iterator``.

    Args:
        tokenizer: Callable applied to one field of each item.
        idx: Position of the field to tokenise inside each item
            (0 and 1 are used by the callers in this notebook).
        data_iter: Iterable of indexable items.

    Yields:
        Whatever ``tokenizer`` returns for ``item[idx]``, one result per item.
    """
    yield from (tokenizer(item[idx]) for item in data_iter)

def _make_vocab(tokenizer, field):
    """Build the vocabulary for one side (``field`` 0 or 1) of the training pairs."""
    vocab = build_vocab_from_iterator(
        build_vocab(tokenizer, field, train_iter),
        min_freq=2,
        specials=["[PAD]", "[SOS]", "[EOS]", "[UNK]"],
        special_first=True,
    )
    # Tokens missing from the vocabulary resolve to the unknown-token id.
    vocab.set_default_index(UNK_TOKEN)
    return vocab


# Field 0 of each training pair feeds the source vocabulary, field 1 the target.
src_vocab = _make_vocab(src_tokenizer, 0)
tgt_vocab = _make_vocab(tgt_tokenizer, 1)

# Bare expression so the notebook displays both vocabulary sizes.
len(src_vocab), len(tgt_vocab)



(8014, 6191)

In [8]:
# Sanity check of the target-side encoding: tokenise one sentence, map the
# tokens to ids, and wrap the ids in the SOS/EOS markers.
tgt = "Hey, hi there"
tgt_tokens = tgt_tokenizer(tgt)
tgt_ids = tgt_vocab(tgt_tokens)

# Bare expression so the notebook displays the resulting id tensor.
torch.tensor([SOS_TOKEN, *tgt_ids, EOS_TOKEN])

tensor([  1,   3,  15,   3, 601,   2])

In [9]:
def _encode_sentence(text, tokenizer, vocab):
    """Convert one raw sentence into a 1-D int64 tensor ``[SOS, ids..., EOS]``."""
    token_ids = vocab(tokenizer(text.rstrip("\n")))
    # Build a single list and pin the dtype. `torch.tensor([])` on its own
    # defaults to float32, so a blank sentence could otherwise produce a float
    # sample.
    return torch.tensor([SOS_TOKEN, *token_ids, EOS_TOKEN], dtype=torch.long)


def collate_fn(batch):
    """Collate raw (source, target) sentence pairs into two padded id tensors.

    Each sentence has trailing newlines stripped, is tokenised with its own
    tokenizer, mapped to ids with its own vocabulary, and wrapped in
    ``SOS_TOKEN`` / ``EOS_TOKEN``. Sequences are then right-padded with
    ``PAD_TOKEN`` to the longest sequence in the batch.

    Args:
        batch: Iterable of ``(src_text, tgt_text)`` string pairs.

    Returns:
        ``(src_batch, tgt_batch)``: int64 tensors laid out batch-first, i.e.
        ``(len(batch), longest_sequence_in_batch)``. The two tensors are padded
        independently, so their second dimensions generally differ.
    """
    src_sequences, tgt_sequences = [], []

    for src_text, tgt_text in batch:
        src_sequences.append(_encode_sentence(src_text, src_tokenizer, src_vocab))
        tgt_sequences.append(_encode_sentence(tgt_text, tgt_tokenizer, tgt_vocab))

    src_batch = pad_sequence(src_sequences, batch_first=True, padding_value=PAD_TOKEN)
    tgt_batch = pad_sequence(tgt_sequences, batch_first=True, padding_value=PAD_TOKEN)
    return src_batch, tgt_batch

In [10]:
# Settings shared by both loaders: batch size from the config, and the custom
# collate function that tokenises, numericalises and pads each batch.
loader_kwargs = dict(batch_size=config["batch_size"], collate_fn=collate_fn)

# Only the training loader is asked to shuffle.
train_dataloader = DataLoader(train_iter, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_iter, shuffle=False, **loader_kwargs)

In [11]:
# Instantiate the sequence-to-sequence model with vocabulary sizes taken from
# the vocabularies built above, then move it to the selected device.
transformer = TransformerForSeqToSeq(
    config=config,
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    padding_idx=PAD_TOKEN,
)
transformer = transformer.to(device=device)

In [12]:
# torchinfo is a third-party package used here only to inspect the model.
# NOTE(review): this import lives outside the notebook's top import cell.
from torchinfo import summary
# Bare expression so the notebook renders the per-layer parameter table.
summary(transformer)

Layer (type:depth-idx)                                  Param #
TransformerForSeqToSeq                                  --
├─Embedding: 1-1                                        --
│    └─Embedding: 2-1                                   3,077,376
│    └─Embedding: 2-2                                   196,608
│    └─Dropout: 2-3                                     --
├─Embedding: 1-2                                        --
│    └─Embedding: 2-4                                   2,377,344
│    └─Embedding: 2-5                                   196,608
│    └─Dropout: 2-6                                     --
├─Encoder: 1-3                                          --
│    └─ModuleList: 2-7                                  --
│    │    └─EncoderBlock: 3-1                           1,774,464
│    │    └─EncoderBlock: 3-2                           1,774,464
│    │    └─EncoderBlock: 3-3                           1,774,464
│    │    └─EncoderBlock: 3-4                           1,774,464

In [13]:
def train_step(transformer: nn.Module,
               dataloader: DataLoader,
               optimizer: torch.optim.Optimizer,
               device: torch.device = "cpu") -> float:
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    The model is put in training mode and called as
    ``transformer(src_batch, tgt_batch)``; it must return a ``(logits, loss)``
    pair. Only ``loss`` is used here: it is back-propagated and followed by one
    optimizer step per batch.

    Args:
        transformer: Model whose forward pass returns ``(logits, loss)``,
            e.g. ``TransformerForSeqToSeq``.
        dataloader: Iterable of ``(src_batch, tgt_batch)`` tensor pairs
            (presumably the padded id tensors produced by ``collate_fn``;
            this is not checked here).
        optimizer: Optimizer bound to ``transformer``'s parameters.
        device: Device (or device string) each batch is moved to.

    Returns:
        The arithmetic mean of the per-batch losses, or ``nan`` if the
        dataloader produced no batches.
    """
    transformer.train()

    total_loss = 0.0
    # Count batches as they are consumed instead of calling `len(dataloader)`,
    # which raises TypeError for iterable-style datasets without `__len__`.
    n_batches = 0
    for src_batch, tgt_batch in dataloader:
        src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
        _, loss = transformer(src_batch, tgt_batch)

        # Clear gradients left over from the previous step before accumulating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Nothing was processed; avoid a ZeroDivisionError and signal "no data".
        return float("nan")
    return total_loss / n_batches

def eval_step(transformer: nn.Module,
              dataloader: DataLoader,
              device: torch.device = "cpu") -> float:
    """Evaluate ``transformer`` on ``dataloader`` and return the mean batch loss.

    The model is switched to eval mode and run under ``torch.inference_mode()``,
    so no gradients are tracked and no parameters are updated. It is called as
    ``transformer(src_batch, tgt_batch)`` and must return ``(logits, loss)``;
    only ``loss`` is used.

    Args:
        transformer: Model whose forward pass returns ``(logits, loss)``,
            e.g. ``TransformerForSeqToSeq``.
        dataloader: Iterable of ``(src_batch, tgt_batch)`` tensor pairs.
        device: Device (or device string) each batch is moved to.

    Returns:
        The arithmetic mean of the per-batch losses, or ``nan`` if the
        dataloader produced no batches.
    """
    transformer.eval()

    total_loss = 0.0
    # Count batches as they are consumed instead of calling `len(dataloader)`,
    # which raises TypeError for iterable-style datasets without `__len__`.
    n_batches = 0
    with torch.inference_mode():
        for src_batch, tgt_batch in dataloader:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
            _, loss = transformer(src_batch, tgt_batch)

            total_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        # Nothing was processed; avoid a ZeroDivisionError and signal "no data".
        return float("nan")
    return total_loss / n_batches


def train(transformer: TransformerForSeqToSeq,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          optimizer: torch.optim.Optimizer,
          config: Dict[str, int],
          device: torch.device="cpu") -> Tuple[List[float], List[float]]:
    """Train for ``config["train_iters"]`` epochs, validating after each one.

    Every epoch runs ``train_step`` on ``train_dataloader`` followed by
    ``eval_step`` on ``val_dataloader``, prints both losses, and records them.

    Args:
        transformer: Model to optimise; it is moved to ``device`` first.
        train_dataloader: Batches used for the optimisation pass.
        val_dataloader: Batches used for the evaluation pass.
        optimizer: Optimizer bound to ``transformer``'s parameters.
        config: Settings mapping; only the ``"train_iters"`` key is read here.
        device: Device (or device string) used for the model and the batches.

    Returns:
        ``(train_losses, eval_losses)``: two lists with one mean loss per
        epoch, in epoch order.
    """
    transformer.to(device=device)

    loss_history_train: List[float] = []
    loss_history_val: List[float] = []

    n_epochs = config["train_iters"]
    # Epochs are numbered from 1 so the printed log is human-friendly.
    for epoch in range(1, n_epochs + 1):
        train_loss = train_step(transformer, train_dataloader, optimizer, device)
        val_loss = eval_step(transformer, val_dataloader, device)

        print(f"epoch {epoch}: train_loss: {train_loss:.4f} val_loss: {val_loss: .4f}")

        loss_history_train.append(train_loss)
        loss_history_val.append(val_loss)

    return loss_history_train, loss_history_val

In [14]:
# Rebind `transformer` to a freshly initialised model for the training run.
# It is not moved to a device here; `train()` does that before the first epoch.
transformer = TransformerForSeqToSeq(
    config=config,
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    padding_idx=PAD_TOKEN,
)

# AdamW over all model parameters, with the learning rate from the config.
optimizer = torch.optim.AdamW(transformer.parameters(), lr=config["learning_rate"])

In [None]:
# Launch training; arguments are passed in `train`'s declared parameter order.
# The per-epoch loss histories are kept for later inspection.
train_losses, eval_losses = train(
    transformer, train_dataloader, val_dataloader, optimizer, config, device
)