In [1]:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

torch.manual_seed(0)
torch.set_float32_matmul_precision('high')

DEVICE = torch.device('cuda:0')

# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Place-holders
token_transform = {}
vocab_transform = {}


# Create source and target language tokenizer. Make sure to install the dependencies.
# pip install -U torchdata
# pip install -U spacy
# python -m spacy download en_core_web_sm
# python -m spacy download de_core_news_sm
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set UNK_IDX as the default index. This index is returned when the token is not found.
# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
  vocab_transform[ln].set_default_index(UNK_IDX)

def make_tok_masks(src, tgt):
    src_mask = (src != PAD_IDX)
    tgt_mask = (tgt != PAD_IDX)
    return src_mask, tgt_mask

In [2]:
ds = list(Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE)))

In [3]:
len(ds)

29001

In [4]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool)

    src_padding_mask = (src == PAD_IDX)
    tgt_padding_mask = (tgt == PAD_IDX)
    return src_mask.to(DEVICE), tgt_mask.to(DEVICE), src_padding_mask.to(DEVICE), tgt_padding_mask.to(DEVICE)

# Model

In [288]:
import torch.nn as nn
from torch import Tensor

class InputEmbeddings(nn.Module):
	''' Apply learnable token and position embeddings to input tokens. '''

	def __init__(self, vocab_size: int, emb_size: int, maxlen: int = 5000):
		super().__init__()
		self.token_embedding_table = nn.Embedding(vocab_size, emb_size)
		self.position_embedding_table = nn.Embedding(maxlen, emb_size)
		self.register_buffer('pos_emb_index', torch.arange(maxlen))
	
	def forward(self, x: Tensor):
		B, T = x.shape
		tok_embd = self.token_embedding_table(x)
		pos_embd = self.position_embedding_table(self.pos_emb_index[:T])
		return tok_embd + pos_embd

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = nn.Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = InputEmbeddings(src_vocab_size, emb_size)
        self.tgt_tok_emb = InputEmbeddings(tgt_vocab_size, emb_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = (self.src_tok_emb(src))
        tgt_emb = (self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder((
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder((
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [148]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch

In [7]:
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])

EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS,
                                 num_decoder_layers=NUM_DECODER_LAYERS,
                                 emb_size=EMB_SIZE,
                                 nhead=NHEAD,
                                 src_vocab_size=SRC_VOCAB_SIZE,
                                 tgt_vocab_size=TGT_VOCAB_SIZE,
                                 dim_feedforward=FFN_HID_DIM,
                                 dropout=0.1)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [8]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:, :-1]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:, :-1]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [9]:
from timeit import default_timer as timer
NUM_EPOCHS = 20

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))




Epoch: 1, Train loss: 4.989, Val loss: 3.925, Epoch time = 7.518s
Epoch: 2, Train loss: 3.670, Val loss: 3.332, Epoch time = 6.923s
Epoch: 3, Train loss: 3.209, Val loss: 3.021, Epoch time = 6.945s
Epoch: 4, Train loss: 2.908, Val loss: 2.823, Epoch time = 6.921s
Epoch: 5, Train loss: 2.684, Val loss: 2.687, Epoch time = 7.032s
Epoch: 6, Train loss: 2.501, Val loss: 2.579, Epoch time = 6.990s
Epoch: 7, Train loss: 2.348, Val loss: 2.489, Epoch time = 7.001s
Epoch: 8, Train loss: 2.213, Val loss: 2.420, Epoch time = 6.954s
Epoch: 9, Train loss: 2.093, Val loss: 2.360, Epoch time = 6.922s
Epoch: 10, Train loss: 1.983, Val loss: 2.324, Epoch time = 6.950s
Epoch: 11, Train loss: 1.884, Val loss: 2.285, Epoch time = 6.878s
Epoch: 12, Train loss: 1.785, Val loss: 2.252, Epoch time = 6.848s
Epoch: 13, Train loss: 1.700, Val loss: 2.228, Epoch time = 6.880s
Epoch: 14, Train loss: 1.615, Val loss: 2.210, Epoch time = 6.888s
Epoch: 15, Train loss: 1.534, Val loss: 2.199, Epoch time = 6.888s
Epoc

# Lightning Training

## 1. Using Automatic Optimization

In [38]:
import pytorch_lightning as pl

class Seq2SeqTransformerLNAuto(pl.LightningModule):
		
	def __init__(self, transformer: Seq2SeqTransformer):
		super().__init__()
		self.transformer = transformer
		self.criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

	def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
		return self.transformer(src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
	
	def training_step(self, batch, batch_idx: int):
		src, tgt = batch
		tgt_input = tgt[:, :-1]
		src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
		memory_key_padding_mask = src_padding_mask
		y_pred = self(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
		y_gt = tgt[:, 1:]
		loss = self.criterion(y_pred.view(-1, y_pred.shape[-1]), y_gt.reshape(-1))
		self.log('train_loss', loss, prog_bar=True)
		self.train_losses.append(loss)
		return loss

	def validation_step(self, batch, batch_idx: int):
		src, tgt = batch
		tgt_input = tgt[:, :-1]
		src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
		memory_key_padding_mask = src_padding_mask
		y_pred = self(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
		y_gt = tgt[:, 1:]
		loss = self.criterion(y_pred.view(-1, y_pred.shape[-1]), y_gt.reshape(-1))
		self.log('val_loss', loss, prog_bar=True)
		self.val_losses.append(loss)
		return loss
	
	def on_train_epoch_start(self):
		self.train_losses = []
	
	def on_validation_epoch_start(self):
		self.val_losses = []
	
	def on_train_epoch_end(self):
		loss = sum(self.train_losses) / len(self.train_losses)
		print(f'Epoch {self.trainer.current_epoch} train loss:', loss)

	def on_validation_epoch_end(self):
		loss = sum(self.val_losses) / len(self.val_losses)
		print(f'Epoch {self.trainer.current_epoch} val loss:', loss)

	def configure_optimizers(self):
			return torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)


In [39]:
from torch.utils.data import DataLoader

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dl = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_dl = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

trainer = pl.Trainer(accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [40]:
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])

EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS,
                                 num_decoder_layers=NUM_DECODER_LAYERS,
                                 emb_size=EMB_SIZE,
                                 nhead=NHEAD,
                                 src_vocab_size=SRC_VOCAB_SIZE,
                                 tgt_vocab_size=TGT_VOCAB_SIZE,
                                 dim_feedforward=FFN_HID_DIM,
                                 dropout=0.1)

transformer = Seq2SeqTransformerLNAuto(transformer)

In [41]:
trainer.fit(transformer, train_dl, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | transformer | Seq2SeqTransformer | 38.7 M
1 | criterion   | CrossEntropyLoss   | 0     
---------------------------------------------------
38.7 M    Trainable params
0         Non-trainable params
38.7 M    Total params
154.762   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Epoch 0 val loss: tensor(9.4906, device='cuda:0')


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Epoch 0 val loss: tensor(3.9061, device='cuda:0')
Epoch 0 train loss: tensor(4.9814, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 1 val loss: tensor(3.3512, device='cuda:0')
Epoch 1 train loss: tensor(3.6727, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 2 val loss: tensor(3.0381, device='cuda:0')
Epoch 2 train loss: tensor(3.2213, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 3 val loss: tensor(2.8416, device='cuda:0')
Epoch 3 train loss: tensor(2.9216, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 4 val loss: tensor(2.6962, device='cuda:0')
Epoch 4 train loss: tensor(2.6964, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 5 val loss: tensor(2.5851, device='cuda:0')
Epoch 5 train loss: tensor(2.5159, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 6 val loss: tensor(2.4983, device='cuda:0')
Epoch 6 train loss: tensor(2.3624, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 7 val loss: tensor(2.4204, device='cuda:0')
Epoch 7 train loss: tensor(2.2258, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 8 val loss: tensor(2.3698, device='cuda:0')
Epoch 8 train loss: tensor(2.1035, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 9 val loss: tensor(2.3168, device='cuda:0')
Epoch 9 train loss: tensor(1.9952, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 10 val loss: tensor(2.2725, device='cuda:0')
Epoch 10 train loss: tensor(1.8943, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 11 val loss: tensor(2.2396, device='cuda:0')
Epoch 11 train loss: tensor(1.7968, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 12 val loss: tensor(2.2187, device='cuda:0')
Epoch 12 train loss: tensor(1.7091, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 13 val loss: tensor(2.1948, device='cuda:0')
Epoch 13 train loss: tensor(1.6268, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 14 val loss: tensor(2.1743, device='cuda:0')
Epoch 14 train loss: tensor(1.5528, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 15 val loss: tensor(2.1648, device='cuda:0')
Epoch 15 train loss: tensor(1.4763, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 16 val loss: tensor(2.1540, device='cuda:0')
Epoch 16 train loss: tensor(1.4031, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 17 val loss: tensor(2.1269, device='cuda:0')
Epoch 17 train loss: tensor(1.3399, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 18 val loss: tensor(2.1246, device='cuda:0')
Epoch 18 train loss: tensor(1.2728, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 19 val loss: tensor(2.1210, device='cuda:0')
Epoch 19 train loss: tensor(1.2081, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 20 val loss: tensor(2.1135, device='cuda:0')
Epoch 20 train loss: tensor(1.1469, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 21 val loss: tensor(2.1197, device='cuda:0')
Epoch 21 train loss: tensor(1.0881, device='cuda:0', grad_fn=<DivBackward0>)


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## My Own Model

In [361]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from torch import Tensor
from typing import *
import pytorch_lightning as pl
import sentencepiece as sp
import math

GPU = torch.device('cuda')

import torch
from torch import Tensor
from torch import nn

class PositionalEmbedding(nn.Module):
	''' Sinosuidal positional embedding. '''

	def __init__(self, emb_dim: int, max_len: int):
		super().__init__()

		encoding = torch.zeros(max_len + 2, emb_dim, requires_grad=False)
		pos = torch.arange(0.0, max_len + 2, dtype=torch.float).unsqueeze(dim=1)
		_2i = torch.arange(0, emb_dim, step=2).float()

		encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / emb_dim)))
		encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / emb_dim)))

		self.register_buffer('encoding', encoding)

	def forward(self, x: Tensor):
		return self.encoding[:x.size(0)]

class PosNTokEmbedding(nn.Module):
	''' Apply learnable token and sinosuidal position embeddings to input tokens. '''

	def __init__(self, vocab_size: int, emb_dim: int, max_len: int):
		super().__init__()
		self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
		self.position_embedding_table = PositionalEmbedding(emb_dim, max_len)
		self.max_len = max_len
	
	def forward(self, x: Tensor):
		tok_embd = self.token_embedding_table(x)
		pos_embd = self.position_embedding_table(torch.arange(x.size(1), device=x.device))
		return tok_embd + pos_embd
	
import torch
from torch import nn
from torch import Tensor
import math

class MultiHeadSelfAttention(nn.Module):

	''' Multi-head self attention.
	Implements a somewhat optimized version of the self attention by combining the q, k, v projections.
	
	Inputs:
		`x`: Tensor<Float>[B, T, C] input tensor.
		`tok_mask`: Tensor<Bool>[B, T] per-token mask applied to the `x`, false is masked out, true is preserved - masks both keys and queries.

	Outputs:
		Tensor<Float>[B, T, C] output tensor.
	'''

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, is_causal: bool = False):
		super().__init__()
		self.is_causal = is_causal
		self.n_heads = n_heads
		self.emb_dim = emb_dim
		self.attn_dropout = nn.Dropout(dropout)
		self.resid_dropout = nn.Dropout(dropout)
		# combine q, k, v projections for efficiency
		self.qkv_projection = nn.Linear(emb_dim, 3 * emb_dim, bias=bias)
		# output projection
		self.c_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

	def forward(self, x: Tensor, tok_mask: Tensor):
		B, T, C = x.shape
		# proj q, k, v for all heads
		# the heads are treated as a batch dimension
		q, k, v = self.qkv_projection(x).split(self.emb_dim, dim=2)
		q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		# compute attention
		att_weights = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
		mask = tok_mask.view(B, 1, T) # (B, 1, T)
		mask = mask.tile(1, T, 1) # (B, T, T)
		mask = mask & mask.transpose(-2, -1) # (B, T, T)
		mask = mask.view(B, 1, T, T) # (B, 1, T, T)
		if self.is_causal:
			causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
			mask = mask & causal_mask[None, None, :, :]
		att_weights = att_weights.masked_fill(mask == 0, -1e9)
		att_weights = nn.functional.softmax(att_weights, dim=-1)
		att_weights = self.attn_dropout(att_weights)
		y = att_weights @ v
		# combine heads
		y = y.transpose(1, 2).contiguous().view(B, T, C)
		y = self.resid_dropout(self.c_proj(y))
		return y

class MultiHeadCrossAttention(nn.Module):

	''' Multi-head cross attention.
	Implements a somewhat optimized version of the cross attention by combining the k, v projections.
	
	Inputs:
		`x_q`: Tensor<Float>[B, T_q, C] query input tensor.
		`x_kv`: Tensor<Float>[B, T_kv, C] key and value input tensor.
		`q_tok_mask`: Tensor<Bool>[B, T_q] mask applied to the `x_q`, false is masked out, true is preserved - applies to q only.
		`kv_tok_mask`: Tensor<Bool>[B, T_kv] mask applied to the `x_kv`, false is masked out, true is preserved - applies to k and v.

	Outputs:
		Tensor<Float>[B, T_q, C] output tensor.
	'''

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False):
		super().__init__()
		self.n_heads = n_heads
		self.emb_dim = emb_dim
		self.attn_dropout = nn.Dropout(dropout)
		self.resid_dropout = nn.Dropout(dropout)
		self.q_projection = nn.Linear(emb_dim, emb_dim, bias=bias)
		# combine k, v projections for efficiency
		self.kv_projection = nn.Linear(emb_dim, 2 * emb_dim, bias=bias)
		# output projection
		self.c_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

	def forward(self, x_q: Tensor, x_kv: Tensor, q_tok_mask: Tensor, kv_tok_mask: Tensor):
		# proj query for all heads
		B, T_q, C = x_q.shape
		q = self.q_projection(x_q)
		q = q.view(B, T_q, self.n_heads, C // self.n_heads).transpose(1, 2)
		# proj key & value for all heads
		B, T_kv, C = x_kv.shape
		k, v = self.kv_projection(x_kv).split(self.emb_dim, dim=2)
		k = k.view(B, T_kv, self.n_heads, C // self.n_heads).transpose(1, 2)
		v = v.view(B, T_kv, self.n_heads, C // self.n_heads).transpose(1, 2)
		# compute attention
		att_weights = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
		# merge masks
		q_tok_mask = q_tok_mask.unsqueeze(2) # (N, T_q, 1)
		kv_tok_mask = kv_tok_mask.unsqueeze(1) # (N, 1, T_kv)
		attn_mask = q_tok_mask & kv_tok_mask
		# apply mask
		att_weights = att_weights.masked_fill(attn_mask.unsqueeze(1) == 0, -1e9)
		att_weights = nn.functional.softmax(att_weights, dim=-1)
		att_weights = self.attn_dropout(att_weights)
		y = att_weights @ v
		# combine heads
		y = y.transpose(1, 2).contiguous().view(B, T_q, C)
		y = self.resid_dropout(self.c_proj(y))
		return y

import torch
from torch import nn
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from importlib import import_module

class TransformerFeedFoward(nn.Module):

	def __init__(self, emb_dim: int, dropout: float):
		super().__init__()
		self.net = nn.Sequential(
			nn.Linear(emb_dim, 4 * emb_dim),
			nn.GELU(approximate='tanh'),
			nn.Linear(4 * emb_dim, emb_dim),
		)
		if dropout:
			self.net.append(nn.Dropout(dropout))

	def forward(self, x: Tensor):
		return self.net(x)

class TransformerEncoderBlock(nn.Module):

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		sa_class = MultiHeadSelfAttention
		self.sa_module = sa_class(n_heads, emb_dim, dropout, bias)
		self.fw_module = TransformerFeedFoward(emb_dim, dropout)
		self.ln1 = nn.LayerNorm(emb_dim)
		self.ln2 = nn.LayerNorm(emb_dim)

	
	def forward(self, src: Tensor, src_mask: Tensor):
		x = src + self.sa_module(self.ln1(src), src_mask)
		x = x + self.fw_module(self.ln2(src))
		return x

class TransformerEncoder(nn.Module):

	def __init__(self, n_blocks: int, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, use_grad_ckpt: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		self.blocks = nn.ModuleList([TransformerEncoderBlock(n_heads, emb_dim, dropout, bias, attention_type) for _ in range(n_blocks)])
		self.use_grad_ckpt = use_grad_ckpt
	
	def forward(self, src: Tensor, src_mask: Tensor):
		x = src
		for block in self.blocks:
			if self.use_grad_ckpt:
				forward = lambda *inputs: block(*inputs)
				x = checkpoint(forward, x, src_mask, preserve_rng_state=False)
			else:
				x = block(x, src_mask)
		return x

class TransformerDecoderBlock(nn.Module):

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		sa_class = MultiHeadSelfAttention
		ca_class = MultiHeadCrossAttention
		self.sa_module = sa_class(n_heads, emb_dim, dropout, bias, is_causal=True)
		self.ca_module = ca_class(n_heads, emb_dim, dropout, bias)
		self.fw_module = TransformerFeedFoward(emb_dim, dropout)
		self.ln1 = nn.LayerNorm(emb_dim)
		self.ln2 = nn.LayerNorm(emb_dim)
		self.ln3 = nn.LayerNorm(emb_dim)
	
	def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor):
		x = tgt + self.sa_module(self.ln1(tgt), tgt_mask)
		x = x + self.ca_module(self.ln2(tgt), self.ln2(src), tgt_mask, src_mask)
		x = x + self.fw_module(self.ln3(x))
		return x

class TransformerDecoder(nn.Module):

	def __init__(self, n_blocks: int, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, use_grad_ckpt: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		self.blocks = nn.ModuleList([TransformerDecoderBlock(n_heads, emb_dim, dropout, bias, attention_type) for _ in range(n_blocks)])
		self.use_grad_ckpt = use_grad_ckpt
	
	def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor):
		x = tgt
		for block in self.blocks:
			if self.use_grad_ckpt:
				forward = lambda *inputs: block(*inputs)
				x = checkpoint(forward, src, x, src_mask, tgt_mask, preserve_rng_state=False)
			else:
				x = block(src, x, src_mask, tgt_mask)
		return x

class TransformerLMHead(nn.Module):

	def __init__(self, emb_dim: int, tgt_vocab_size: int):
		super().__init__()
		self.ln = nn.LayerNorm(emb_dim)
		self.logits_head = nn.Linear(emb_dim, tgt_vocab_size, bias=False)
	
	def forward(self, x: Tensor):
		x = self.ln(x)
		x = self.logits_head(x)
		return x
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
import pytorch_lightning as pl
from dataclasses import dataclass
from pytorch_lightning.utilities import grad_norm

### CONFIG ###

@dataclass
class TransformerConfig:
    max_len: int
    src_vocab_size: int
    tgt_vocab_size: int
    n_blocks: int
    n_heads: int
    emb_dim: int
    dropout: float
    bias: bool
    weight_tying: bool
    use_grad_ckpt: bool
    pad_index: int
    optimizer: str
    learning_rate: float
    attention_type: str = 'vanilla'

### INPUT ###

@dataclass
class TransformerInputBatch:
	''' A batch of training data for the Transformer model. '''
	x_src: torch.Tensor
	x_tgt: torch.Tensor
	x_src_mask: torch.Tensor
	x_tgt_mask: torch.Tensor
	y_tgt: torch.Tensor

### MODEL ###

class Transformer(pl.LightningModule):

	def __init__(self, config: TransformerConfig):
		super().__init__()
		self.save_hyperparameters()
		self.config = config
		self.learning_rate = config.learning_rate
		self.attention_type = config.attention_type
		self.criterion = nn.CrossEntropyLoss(ignore_index=config.pad_index)
		# setup sample input for tracing
		self.example_input_array = (torch.zeros(1, config.max_len, dtype=torch.long),
			      					torch.zeros(1, config.max_len, dtype=torch.long),
									torch.zeros(1, config.max_len, dtype=torch.long),
									torch.zeros(1, config.max_len, dtype=torch.long))
		# model parts
		self.src_embeddings = PosNTokEmbedding(config.src_vocab_size, config.emb_dim, config.max_len)
		self.tgt_embeddings = PosNTokEmbedding(config.tgt_vocab_size, config.emb_dim, config.max_len)
		self.encoder = TransformerEncoder(config.n_blocks, config.n_heads, config.emb_dim, config.dropout, config.bias, config.use_grad_ckpt, config.attention_type)
		self.decoder = TransformerDecoder(config.n_blocks, config.n_heads, config.emb_dim, config.dropout, config.bias, config.use_grad_ckpt, config.attention_type)
		self.lm_head = TransformerLMHead(config.emb_dim, config.tgt_vocab_size)
		# weight tying
		if config.weight_tying:
			self.tgt_embeddings.token_embedding_table.weight = self.lm_head.logits_head.weight

	def forward(self, src: Tensor, tgt: Tensor, src_tok_mask: Tensor, tgt_tok_mask: Tensor):
		''' Forward pass through the model.'''
		enc = self.encoder(self.src_embeddings(src), src_tok_mask)
		dec = self.decoder(enc, self.tgt_embeddings(tgt), src_tok_mask, tgt_tok_mask)
		logits = self.lm_head(dec)
		return logits
	
	def calculate_loss(self, y_pred: Tensor, y_true: Tensor, prefix: str):
		''' Calculate and log loss for a batch of predictions.'''
		B, T = y_true.shape
		loss = self.criterion(y_pred.view(B * T, -1), y_true.reshape(B * T))
		# log the loss
		return loss

	def calculate_metrics(self, y_pred: Tensor, y_true: Tensor, prefix: str):
		''' Calculate and log [accuracy] for a batch of predictions.'''
		B, T = y_true.shape
		# flatten the tensors
		y_pred = y_pred.view(B * T, -1).argmax(dim=-1)
		y_true = y_true.reshape(B * T)
		# calculate the metrics
		accuracy = (y_pred == y_true).float().mean()
		# log the metrics
		return {
			f'{prefix}_accuracy': accuracy
		}

	def training_step(self, batch: TransformerInputBatch, batch_idx: int):
		y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
		metrics = self.calculate_metrics(y_pred, batch.y_tgt, 'train')
		metrics['train_loss'] = self.calculate_loss(y_pred, batch.y_tgt, 'train')
		self.log_dict(metrics, prog_bar=True)
		self.train_losses.append(metrics['train_loss'])
		return metrics['train_loss']

	def validation_step(self, batch: TransformerInputBatch, batch_idx: int):
		y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
		metrics = self.calculate_metrics(y_pred, batch.y_tgt, 'val')
		metrics['val_loss'] = self.calculate_loss(y_pred, batch.y_tgt, 'val')
		self.log_dict(metrics, prog_bar=True)
		self.val_losses.append(metrics['val_loss'])
		return metrics['val_loss']

	#def test_step(self, batch: TransformerInputBatch, batch_idx: int):
	#	y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
	#	self.calculate_metrics(y_pred, batch.y_tgt, 'test')
	#	return self.calculate_loss(y_pred, batch.y_tgt, 'test')

	def on_before_optimizer_step(self, optimizer):
		norms = grad_norm(self, 2)
		self.log_dict(norms)
	
	@torch.inference_mode()
	def translate(self, src: Tensor, bos_idx: int, eos_idx: int, temperature: float = 1.0, max_new_tokens: int = 1000):
		''' Generator function that translates a source sentence into a target sentence.
		
		Input:
			`src`: Tensor<Float>[T_in] input src tensor.
		
		Output:
			Tensor<Float>[T_out] output tgt tensor.
		'''
		# put self into eval mode
		self.eval()
		# init inputs
		src = src.to(self.device).unsqueeze(0) # (1, T)
		tgt = torch.tensor([bos_idx], dtype=torch.long, device=self.device).unsqueeze(0) # (1, 1)
		src_mask = torch.ones_like(src, dtype=torch.bool, device=self.device)
		for i in range(max_new_tokens):
			# update tgt mask
			tgt_mask = torch.ones_like(tgt, dtype=torch.bool, device=self.device)
			# get the predictions
			logits = self(src[:, -self.config.max_len:], tgt[:, -self.config.max_len:], src_mask, tgt_mask) # (1, T, C)
			# focus only on the last time step
			logits = logits[:, -1, :] #(1, C)
			logits /= temperature
			probs = F.softmax(logits, dim=-1)
			# sample from the distribution
			idx_next = torch.argmax(probs, dim=-1, keepdim=True) # (1, 1)
			# append sampled index to the running sequence
			tgt = torch.cat((tgt, idx_next), dim=1) # (1, T)
			# yield the current token
			token = int(idx_next[0].cpu().numpy())
			# print(f'{i}:', idx_next[0], token)
			# yield f'{idx_next[0]}: {token} \n'
			yield token
			# stop if the last token is the EOS token
			if token == eos_idx:
				break
		return tgt[0]

	def on_train_epoch_start(self):
		self.train_losses = []
	
	def on_validation_epoch_start(self):
		self.val_losses = []
	
	def on_train_epoch_end(self):
		loss = sum(self.train_losses) / len(self.train_losses)
		print(f'Epoch {self.trainer.current_epoch} train loss:', loss)

	def on_validation_epoch_end(self):
		loss = sum(self.val_losses) / len(self.val_losses)
		print(f'Epoch {self.trainer.current_epoch} val loss:', loss)

	def configure_optimizers(self):
			return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [362]:
@dataclass
class TransformerInputBatch:
	''' A batch of training data for the Transformer model. '''
	x_src: torch.Tensor
	x_tgt: torch.Tensor
	x_src_mask: torch.Tensor
	x_tgt_mask: torch.Tensor
	y_tgt: torch.Tensor

# function to collate data samples into batch tensors
def collate_function(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return TransformerInputBatch(src_batch, tgt_batch[:, :-1], src_batch != PAD_IDX, tgt_batch[:, :-1] != PAD_IDX, tgt_batch[:, 1:])

In [369]:
model_config = TransformerConfig(
    max_len = 256,
    src_vocab_size = SRC_VOCAB_SIZE,
    tgt_vocab_size = TGT_VOCAB_SIZE,
    n_blocks = 8,
    n_heads = 8,
    emb_dim = 512,
    dropout = 0.1,
    bias = False,
    weight_tying = False,
    use_grad_ckpt = False,
    pad_index = PAD_IDX,
    optimizer = 'AdamW',
    learning_rate = 0.0001,
    attention_type = 'vanilla')

model = Transformer(model_config)

In [370]:
from torch.utils.data import DataLoader

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dl = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_function)
val_dl = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_function)

trainer = pl.Trainer(accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [371]:
trainer.fit(model, train_dl, val_dl)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params | In sizes                                           | Out sizes      
-----------------------------------------------------------------------------------------------------------------------------
0 | criterion      | CrossEntropyLoss   | 0      | ?                                                  | ?              
1 | src_embeddings | PosNTokEmbedding   | 9.8 M  | [1, 256]                                           | [1, 256, 512]  
2 | tgt_embeddings | PosNTokEmbedding   | 5.5 M  | [1, 256]                                           | [1, 256, 512]  
3 | encoder        | TransformerEncoder | 25.2 M | [[1, 256, 512], [1, 256]]                          | [1, 256, 512]  
4 | decoder        | TransformerDecoder | 33.6 M | [[1, 256, 512], [1, 256, 512], [1, 256], [1, 256]] | [1, 256, 512]  
5 | lm_head        | TransformerLMHead  | 5.5 M  | [1, 256, 512]                             

Sanity Checking: 0it [00:00, ?it/s]

Epoch 0 val loss: tensor(9.4134, device='cuda:0')


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Epoch 0 val loss: tensor(3.3234, device='cuda:0')
Epoch 0 train loss: tensor(4.3947, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 1 val loss: tensor(2.7482, device='cuda:0')
Epoch 1 train loss: tensor(2.9848, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 2 val loss: tensor(2.4780, device='cuda:0')
Epoch 2 train loss: tensor(2.5046, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 3 val loss: tensor(2.3084, device='cuda:0')
Epoch 3 train loss: tensor(2.1959, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 4 val loss: tensor(2.2112, device='cuda:0')
Epoch 4 train loss: tensor(1.9623, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 5 val loss: tensor(2.1280, device='cuda:0')
Epoch 5 train loss: tensor(1.7663, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 6 val loss: tensor(2.0790, device='cuda:0')
Epoch 6 train loss: tensor(1.5881, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 7 val loss: tensor(2.0730, device='cuda:0')
Epoch 7 train loss: tensor(1.4321, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 8 val loss: tensor(2.0732, device='cuda:0')
Epoch 8 train loss: tensor(1.2913, device='cuda:0', grad_fn=<DivBackward0>)


Validation: 0it [00:00, ?it/s]

Epoch 9 val loss: tensor(2.0558, device='cuda:0')
Epoch 9 train loss: tensor(1.1615, device='cuda:0', grad_fn=<DivBackward0>)


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


: 

In [366]:
trainer.validate(model, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

Epoch 1 val loss: tensor(3.0236, device='cuda:0')
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.validating metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.2608773410320282
        val_loss            3.0202431678771973
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_accuracy': 0.2608773410320282, 'val_loss': 3.0202431678771973}]