In [17]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from torch import Tensor
from typing import *
import pytorch_lightning as pl
import sentencepiece as sp
import math

GPU = torch.device('cuda')

import torch
from torch import Tensor
from torch import nn

# allow tf32
torch.backends.cuda.matmul.allow_tf32 = True

class PositionalEmbedding(nn.Module):
	''' Sinosuidal positional embedding. '''

	def __init__(self, emb_dim: int, max_len: int):
		super().__init__()

		encoding = torch.zeros(max_len + 2, emb_dim, requires_grad=False)
		pos = torch.arange(0.0, max_len + 2, dtype=torch.float).unsqueeze(dim=1)
		_2i = torch.arange(0, emb_dim, step=2).float()

		encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / emb_dim)))
		encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / emb_dim)))

		self.register_buffer('encoding', encoding)

	def forward(self, x: Tensor):
		return self.encoding[:x.size(0)]

class PosNTokEmbedding(nn.Module):
	''' Apply learnable token and sinosuidal position embeddings to input tokens. '''

	def __init__(self, vocab_size: int, emb_dim: int, max_len: int):
		super().__init__()
		self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
		self.position_embedding_table = PositionalEmbedding(emb_dim, max_len)
		self.max_len = max_len
	
	def forward(self, x: Tensor):
		tok_embd = self.token_embedding_table(x)
		pos_embd = self.position_embedding_table(torch.arange(x.size(1), device=x.device))
		return tok_embd + pos_embd
	
import torch
from torch import nn
from torch import Tensor
import math

class MultiHeadSelfAttention(nn.Module):

	''' Multi-head self attention.
	Implements a somewhat optimized version of the self attention by combining the q, k, v projections.
	
	Inputs:
		`x`: Tensor<Float>[B, T, C] input tensor.
		`tok_mask`: Tensor<Bool>[B, T] per-token mask applied to the `x`, false is masked out, true is preserved - masks both keys and queries.

	Outputs:
		Tensor<Float>[B, T, C] output tensor.
	'''

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, is_causal: bool = False):
		super().__init__()
		self.is_causal = is_causal
		self.n_heads = n_heads
		self.emb_dim = emb_dim
		self.attn_dropout = nn.Dropout(dropout)
		self.resid_dropout = nn.Dropout(dropout)
		# combine q, k, v projections for efficiency
		self.qkv_projection = nn.Linear(emb_dim, 3 * emb_dim, bias=bias)
		# output projection
		self.c_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

	def forward(self, x: Tensor, tok_mask: Tensor):
		B, T, C = x.shape
		# proj q, k, v for all heads
		# the heads are treated as a batch dimension
		q, k, v = self.qkv_projection(x).split(self.emb_dim, dim=2)
		q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
		# compute attention
		att_weights = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
		mask = tok_mask.view(B, 1, T) # (B, 1, T)
		mask = mask.tile(1, T, 1) # (B, T, T)
		mask = mask & mask.transpose(-2, -1) # (B, T, T)
		mask = mask.view(B, 1, T, T) # (B, 1, T, T)
		if self.is_causal:
			causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
			mask = mask & causal_mask[None, None, :, :]
		att_weights = att_weights.masked_fill(mask == 0, -1e9)
		att_weights = nn.functional.softmax(att_weights, dim=-1)
		att_weights = self.attn_dropout(att_weights)
		y = att_weights @ v
		# combine heads
		y = y.transpose(1, 2).contiguous().view(B, T, C)
		y = self.resid_dropout(self.c_proj(y))
		return y

class MultiHeadCrossAttention(nn.Module):

	''' Multi-head cross attention.
	Implements a somewhat optimized version of the cross attention by combining the k, v projections.
	
	Inputs:
		`x_q`: Tensor<Float>[B, T_q, C] query input tensor.
		`x_kv`: Tensor<Float>[B, T_kv, C] key and value input tensor.
		`q_tok_mask`: Tensor<Bool>[B, T_q] mask applied to the `x_q`, false is masked out, true is preserved - applies to q only.
		`kv_tok_mask`: Tensor<Bool>[B, T_kv] mask applied to the `x_kv`, false is masked out, true is preserved - applies to k and v.

	Outputs:
		Tensor<Float>[B, T_q, C] output tensor.
	'''

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False):
		super().__init__()
		self.n_heads = n_heads
		self.emb_dim = emb_dim
		self.attn_dropout = nn.Dropout(dropout)
		self.resid_dropout = nn.Dropout(dropout)
		self.q_projection = nn.Linear(emb_dim, emb_dim, bias=bias)
		# combine k, v projections for efficiency
		self.kv_projection = nn.Linear(emb_dim, 2 * emb_dim, bias=bias)
		# output projection
		self.c_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

	def forward(self, x_q: Tensor, x_kv: Tensor, q_tok_mask: Tensor, kv_tok_mask: Tensor):
		# proj query for all heads
		B, T_q, C = x_q.shape
		q = self.q_projection(x_q)
		q = q.view(B, T_q, self.n_heads, C // self.n_heads).transpose(1, 2)
		# proj key & value for all heads
		B, T_kv, C = x_kv.shape
		k, v = self.kv_projection(x_kv).split(self.emb_dim, dim=2)
		k = k.view(B, T_kv, self.n_heads, C // self.n_heads).transpose(1, 2)
		v = v.view(B, T_kv, self.n_heads, C // self.n_heads).transpose(1, 2)
		# compute attention
		att_weights = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
		# merge masks
		q_tok_mask = q_tok_mask.unsqueeze(2) # (N, T_q, 1)
		kv_tok_mask = kv_tok_mask.unsqueeze(1) # (N, 1, T_kv)
		attn_mask = q_tok_mask & kv_tok_mask
		# apply mask
		att_weights = att_weights.masked_fill(attn_mask.unsqueeze(1) == 0, -1e9)
		att_weights = nn.functional.softmax(att_weights, dim=-1)
		att_weights = self.attn_dropout(att_weights)
		y = att_weights @ v
		# combine heads
		y = y.transpose(1, 2).contiguous().view(B, T_q, C)
		y = self.resid_dropout(self.c_proj(y))
		return y

import torch
from torch import nn
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from importlib import import_module

class TransformerFeedFoward(nn.Module):

	def __init__(self, emb_dim: int, dropout: float):
		super().__init__()
		self.net = nn.Sequential(
			nn.Linear(emb_dim, 4 * emb_dim),
			nn.GELU(approximate='tanh'),
			nn.Linear(4 * emb_dim, emb_dim),
		)
		if dropout:
			self.net.append(nn.Dropout(dropout))

	def forward(self, x: Tensor):
		return self.net(x)

class TransformerEncoderBlock(nn.Module):

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		sa_class = MultiHeadSelfAttention
		self.sa_module = sa_class(n_heads, emb_dim, dropout, bias)
		self.fw_module = TransformerFeedFoward(emb_dim, dropout)
		self.ln1 = nn.LayerNorm(emb_dim)
		self.ln2 = nn.LayerNorm(emb_dim)

	
	def forward(self, src: Tensor, src_mask: Tensor):
		x = src + self.sa_module(self.ln1(src), src_mask)
		x = x + self.fw_module(self.ln2(src))
		return x

class TransformerEncoder(nn.Module):

	def __init__(self, n_blocks: int, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, use_grad_ckpt: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		self.blocks = nn.ModuleList([TransformerEncoderBlock(n_heads, emb_dim, dropout, bias, attention_type) for _ in range(n_blocks)])
		self.use_grad_ckpt = use_grad_ckpt
	
	def forward(self, src: Tensor, src_mask: Tensor):
		x = src
		for block in self.blocks:
			if self.use_grad_ckpt:
				forward = lambda *inputs: block(*inputs)
				x = checkpoint(forward, x, src_mask, preserve_rng_state=False)
			else:
				x = block(x, src_mask)
		return x

class TransformerDecoderBlock(nn.Module):

	def __init__(self, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		sa_class = MultiHeadSelfAttention
		ca_class = MultiHeadCrossAttention
		self.sa_module = sa_class(n_heads, emb_dim, dropout, bias, is_causal=True)
		self.ca_module = ca_class(n_heads, emb_dim, dropout, bias)
		self.fw_module = TransformerFeedFoward(emb_dim, dropout)
		self.ln1 = nn.LayerNorm(emb_dim)
		self.ln2 = nn.LayerNorm(emb_dim)
		self.ln3 = nn.LayerNorm(emb_dim)
	
	def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor):
		x = tgt + self.sa_module(self.ln1(tgt), tgt_mask)
		x = x + self.ca_module(self.ln2(tgt), self.ln2(src), tgt_mask, src_mask)
		x = x + self.fw_module(self.ln3(x))
		return x

class TransformerDecoder(nn.Module):

	def __init__(self, n_blocks: int, n_heads: int, emb_dim: int, dropout: float, bias: bool = False, use_grad_ckpt: bool = False, attention_type: str = 'vanilla'):
		super().__init__()
		self.blocks = nn.ModuleList([TransformerDecoderBlock(n_heads, emb_dim, dropout, bias, attention_type) for _ in range(n_blocks)])
		self.use_grad_ckpt = use_grad_ckpt
	
	def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor):
		x = tgt
		for block in self.blocks:
			if self.use_grad_ckpt:
				forward = lambda *inputs: block(*inputs)
				x = checkpoint(forward, src, x, src_mask, tgt_mask, preserve_rng_state=False)
			else:
				x = block(src, x, src_mask, tgt_mask)
		return x

class TransformerLMHead(nn.Module):

	def __init__(self, emb_dim: int, tgt_vocab_size: int):
		super().__init__()
		self.ln = nn.LayerNorm(emb_dim)
		self.logits_head = nn.Linear(emb_dim, tgt_vocab_size, bias=False)
	
	def forward(self, x: Tensor):
		x = self.ln(x)
		x = self.logits_head(x)
		return x
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
import pytorch_lightning as pl
from dataclasses import dataclass
from pytorch_lightning.utilities import grad_norm

### CONFIG ###

@dataclass
class TransformerConfig:
    max_len: int
    src_vocab_size: int
    tgt_vocab_size: int
    n_blocks: int
    n_heads: int
    emb_dim: int
    dropout: float
    bias: bool
    weight_tying: bool
    use_grad_ckpt: bool
    pad_index: int
    optimizer: str
    learning_rate: float
    attention_type: str = 'vanilla'

### INPUT ###

@dataclass
class TransformerInputBatch:
	''' A batch of training data for the Transformer model. '''
	x_src: torch.Tensor
	x_tgt: torch.Tensor
	x_src_mask: torch.Tensor
	x_tgt_mask: torch.Tensor
	y_tgt: torch.Tensor

### MODEL ###

class Transformer(pl.LightningModule):

	def __init__(self, config: TransformerConfig):
		super().__init__()
		self.save_hyperparameters()
		self.config = config
		self.learning_rate = config.learning_rate
		self.attention_type = config.attention_type
		self.criterion = nn.CrossEntropyLoss(ignore_index=config.pad_index)
		# setup sample input for tracing
		self.example_input_array = (torch.zeros(1, config.max_len, dtype=torch.long),
			      					torch.zeros(1, config.max_len, dtype=torch.long),
									torch.zeros(1, config.max_len, dtype=torch.long),
									torch.zeros(1, config.max_len, dtype=torch.long))
		# model parts
		self.src_embeddings = PosNTokEmbedding(config.src_vocab_size, config.emb_dim, config.max_len)
		self.tgt_embeddings = PosNTokEmbedding(config.tgt_vocab_size, config.emb_dim, config.max_len)
		self.encoder = TransformerEncoder(config.n_blocks, config.n_heads, config.emb_dim, config.dropout, config.bias, config.use_grad_ckpt, config.attention_type)
		self.decoder = TransformerDecoder(config.n_blocks, config.n_heads, config.emb_dim, config.dropout, config.bias, config.use_grad_ckpt, config.attention_type)
		self.lm_head = TransformerLMHead(config.emb_dim, config.tgt_vocab_size)
		# weight tying
		if config.weight_tying:
			self.tgt_embeddings.token_embedding_table.weight = self.lm_head.logits_head.weight

	def forward(self, src: Tensor, tgt: Tensor, src_tok_mask: Tensor, tgt_tok_mask: Tensor):
		''' Forward pass through the model.'''
		enc = self.encoder(self.src_embeddings(src), src_tok_mask)
		dec = self.decoder(enc, self.tgt_embeddings(tgt), src_tok_mask, tgt_tok_mask)
		logits = self.lm_head(dec)
		return logits
	
	def calculate_loss(self, y_pred: Tensor, y_true: Tensor, prefix: str):
		''' Calculate and log loss for a batch of predictions.'''
		B, T = y_true.shape
		loss = self.criterion(y_pred.view(B * T, -1), y_true.reshape(B * T))
		# log the loss
		return loss

	def calculate_metrics(self, y_pred: Tensor, y_true: Tensor, prefix: str):
		''' Calculate and log [accuracy] for a batch of predictions.'''
		B, T = y_true.shape
		# flatten the tensors
		y_pred = y_pred.view(B * T, -1).argmax(dim=-1)
		y_true = y_true.reshape(B * T)
		# calculate the metrics
		accuracy = (y_pred == y_true).float().mean()
		# log the metrics
		return {
			f'{prefix}_accuracy': accuracy
		}

	def training_step(self, batch: TransformerInputBatch, batch_idx: int):
		y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
		metrics = self.calculate_metrics(y_pred, batch.y_tgt, 'train')
		metrics['train_loss'] = self.calculate_loss(y_pred, batch.y_tgt, 'train')
		self.log_dict(metrics, prog_bar=True)
		self.train_losses.append(metrics['train_loss'])
		return metrics['train_loss']

	def validation_step(self, batch: TransformerInputBatch, batch_idx: int):
		y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
		metrics = self.calculate_metrics(y_pred, batch.y_tgt, 'val')
		metrics['val_loss'] = self.calculate_loss(y_pred, batch.y_tgt, 'val')
		self.log_dict(metrics, prog_bar=True)
		self.val_losses.append(metrics['val_loss'])
		return metrics['val_loss']

	#def test_step(self, batch: TransformerInputBatch, batch_idx: int):
	#	y_pred = self(batch.x_src, batch.x_tgt, batch.x_src_mask, batch.x_tgt_mask)
	#	self.calculate_metrics(y_pred, batch.y_tgt, 'test')
	#	return self.calculate_loss(y_pred, batch.y_tgt, 'test')

	def on_before_optimizer_step(self, optimizer):
		norms = grad_norm(self, 2)
		self.log('grad_2.0_norm_total', norms['grad_2.0_norm_total'], prog_bar=True)
	
	@torch.inference_mode()
	def translate(self, src: Tensor, bos_idx: int, eos_idx: int, temperature: float = 1.0, max_new_tokens: int = 1000):
		''' Generator function that translates a source sentence into a target sentence.
		
		Input:
			`src`: Tensor<Float>[T_in] input src tensor.
		
		Output:
			Tensor<Float>[T_out] output tgt tensor.
		'''
		# put self into eval mode
		self.eval()
		# init inputs
		src = src.to(self.device).unsqueeze(0) # (1, T)
		tgt = torch.tensor([bos_idx], dtype=torch.long, device=self.device).unsqueeze(0) # (1, 1)
		src_mask = torch.ones_like(src, dtype=torch.bool, device=self.device)
		for i in range(max_new_tokens):
			# update tgt mask
			tgt_mask = torch.ones_like(tgt, dtype=torch.bool, device=self.device)
			# get the predictions
			logits = self(src[:, -self.config.max_len:], tgt[:, -self.config.max_len:], src_mask, tgt_mask) # (1, T, C)
			# focus only on the last time step
			logits = logits[:, -1, :] #(1, C)
			logits /= temperature
			probs = F.softmax(logits, dim=-1)
			# sample from the distribution
			idx_next = torch.argmax(probs, dim=-1, keepdim=True) # (1, 1)
			# append sampled index to the running sequence
			tgt = torch.cat((tgt, idx_next), dim=1) # (1, T)
			# yield the current token
			token = int(idx_next[0].cpu().numpy())
			# print(f'{i}:', idx_next[0], token)
			# yield f'{idx_next[0]}: {token} \n'
			yield token
			# stop if the last token is the EOS token
			if token == eos_idx:
				break
		return tgt[0]

	def on_train_epoch_start(self):
		self.train_losses = []
	
	def on_validation_epoch_start(self):
		self.val_losses = []
	
	def on_train_epoch_end(self):
		loss = sum(self.train_losses) / len(self.train_losses)
		print(f'Epoch {self.trainer.current_epoch} train loss:', loss)

	def on_validation_epoch_end(self):
		loss = sum(self.val_losses) / len(self.val_losses)
		print(f'Epoch {self.trainer.current_epoch} val loss:', loss)

	def configure_optimizers(self):
			return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [18]:
@dataclass
class TransformerInputBatch:
	''' A batch of training data for the Transformer model. '''
	x_src: torch.Tensor
	x_tgt: torch.Tensor
	x_src_mask: torch.Tensor
	x_tgt_mask: torch.Tensor
	y_tgt: torch.Tensor

In [19]:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from torch.nn.utils.rnn import pad_sequence
from typing import Iterable, List
from dataclasses import dataclass

''' Currently only de -> en is supported as lots of stuff are hardcoded. '''

# Define special symbols and indices
UNK_IDX, BOS_IDX, EOS_IDX, PAD_IDX = 0, 1, 2, 3

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {'de': 0, 'en': 1}

    for data_sample in data_iter:
        yield TranslationDatasetSpacyMulti30K.token_transform[language](data_sample[language_index[language]])

@dataclass
class TranslationDatasetSpacyMulti30KConfig:
  src_language: str
  tgt_language: str
  split: str

class TranslationDatasetSpacyMulti30K(torch.utils.data.Dataset):

    _INIT_RESOURCES_DONE = False
    token_transform = None
    vocab_transform = None
    text_transform = None

    def __init__(self, config: TranslationDatasetSpacyMulti30KConfig):
        TranslationDatasetSpacyMulti30K._init_resources()
        super().__init__()
        self.src_lang = config.src_language
        self.tgt_lang = config.tgt_language
        self.split = config.split
        self.dataset = list(Multi30k(split=self.split, language_pair=(self.src_lang, self.tgt_lang)))

    @staticmethod
    def _init_resources():
        if TranslationDatasetSpacyMulti30K._INIT_RESOURCES_DONE:
            return

        # We need to modify the URLs for the dataset since the links to the original dataset are broken
        # Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
        multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
        multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

        # Place-holders
        token_transform = {}
        vocab_transform = {}

        SRC_LANG = 'de'
        TGT_LANG = 'en'

        token_transform[SRC_LANG] = get_tokenizer('spacy', language='de_core_news_sm')
        token_transform[TGT_LANG] = get_tokenizer('spacy', language='en_core_web_sm')

        TranslationDatasetSpacyMulti30K.token_transform = token_transform

        # Make sure the tokens are in order of their indices to properly insert them in vocab
        special_symbols = ['<unk>', '<bos>', '<eos>', '<pad>']

        for ln in [SRC_LANG, TGT_LANG]:
            # Training data Iterator
            train_iter = Multi30k(split='train', language_pair=(SRC_LANG, TGT_LANG))
            # Create torchtext's Vocab object
            vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                            min_freq=1,
                                                            specials=special_symbols,
                                                            special_first=True)
            
        # Set UNK_IDX as the default index. This index is returned when the token is not found.
        # If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
        for ln in [SRC_LANG, TGT_LANG]:
            vocab_transform[ln].set_default_index(UNK_IDX)

        TranslationDatasetSpacyMulti30K.vocab_transform = vocab_transform

        # src and tgt language text transforms to convert raw strings into tensors indices
        text_transform = {}
        for ln in [SRC_LANG, TGT_LANG]:
            text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                                    vocab_transform[ln], #Numericalization
                                                    tensor_transform) # Add BOS/EOS and create tensor
        
        TranslationDatasetSpacyMulti30K.text_transform = text_transform
        TranslationDatasetSpacyMulti30K._INIT_RESOURCES_DONE = True

        # print vocab sizes
        print(f"Vocab size for {SRC_LANG}: {len(vocab_transform[SRC_LANG])}")
        print(f"Vocab size for {TGT_LANG}: {len(vocab_transform[TGT_LANG])}")

    def __len__(self):
        if self.split == 'train':
            return 29000
        elif self.split == 'valid':
            return 1014
        
    def __getitem__(self, idx):
        src, dst = self.dataset[idx]
        src = TranslationDatasetSpacyMulti30K.text_transform['de'](src.rstrip("\n"))
        dst = TranslationDatasetSpacyMulti30K.text_transform['en'](dst.rstrip("\n"))
        return src, dst

    @staticmethod
    def get_collate_function():
        def collate_fn(batch):
            src_batch, tgt_batch = zip(*batch)
            src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
            tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
            return TransformerInputBatch(src_batch, tgt_batch[:, :-1].clone(), src_batch != PAD_IDX, tgt_batch[:, :-1] != PAD_IDX, tgt_batch[:, 1:].clone())
        return collate_fn

In [26]:
model_config = TransformerConfig(
    max_len = 256,
    src_vocab_size = 19214,
    tgt_vocab_size = 10837,
    n_blocks = 8,
    n_heads = 8,
    emb_dim = 512,
    dropout = 0.1,
    bias = False,
    weight_tying = True,
    use_grad_ckpt = True,
    pad_index = PAD_IDX,
    optimizer = 'AdamW',
    learning_rate = 0.0001,
    attention_type = 'vanilla')

model = Transformer(model_config)

In [27]:
from torch.utils.data import DataLoader

train_iter = TranslationDatasetSpacyMulti30K(TranslationDatasetSpacyMulti30KConfig('de', 'en', 'train'))
val_iter = TranslationDatasetSpacyMulti30K(TranslationDatasetSpacyMulti30KConfig('de', 'en', 'valid'))
train_dl = DataLoader(train_iter, batch_size=64, collate_fn=TranslationDatasetSpacyMulti30K.get_collate_function())
val_dl = DataLoader(val_iter, batch_size=128, collate_fn=TranslationDatasetSpacyMulti30K.get_collate_function())

trainer = pl.Trainer(accelerator='gpu', devices=1,
                     accumulate_grad_batches=4,
                     gradient_clip_val=1.0)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
trainer.fit(model, train_dl, val_dl)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params | In sizes                                           | Out sizes      
-----------------------------------------------------------------------------------------------------------------------------
0 | criterion      | CrossEntropyLoss   | 0      | ?                                                  | ?              
1 | src_embeddings | PosNTokEmbedding   | 9.8 M  | [1, 256]                                           | [1, 256, 512]  
2 | tgt_embeddings | PosNTokEmbedding   | 5.5 M  | [1, 256]                                           | [1, 256, 512]  
3 | encoder        | TransformerEncoder | 25.2 M | [[1, 256, 512], [1, 256]]                          | [1, 256, 512]  
4 | decoder        | TransformerDecoder | 33.6 M | [[1, 256, 512], [1, 256, 512], [1, 256], [1, 256]] | [1, 256, 512]  
5 | lm_head        | TransformerLMHead  | 5.5 M  | [1, 256, 512]                             

Sanity Checking: 0it [00:00, ?it/s]

Epoch 0 val loss: tensor(9.3873, device='cuda:0')


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


: 

In [None]:
trainer.validate(model, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

Epoch 1 val loss: tensor(3.0236, device='cuda:0')
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.validating metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.2608773410320282
        val_loss            3.0202431678771973
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_accuracy': 0.2608773410320282, 'val_loss': 3.0202431678771973}]