# Transformer from Scratch



In [38]:
import torch
import torch.nn as nn
import numpy as np

In [40]:
import logging
# Dedicated logger for tensor-shape tracing; it prints the bare message only.
logger = logging.getLogger("tensor_shapes")
# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack another handler and print each message repeatedly.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# A threshold of 1 admits every record whose level is >= 1.
logger.setLevel(1)

In [43]:
import inspect
def getclass():
    """Return the class of the object whose method is being logged.

    Intended to be called from a logging helper that is itself called from a
    method: the frame of this function and the frame of the helper are
    skipped, then the stack is walked outwards until a frame with a local
    named ``self`` is found.

    Returns:
        The ``__class__`` of the nearest enclosing ``self``.

    Raises:
        KeyError: if no frame above the helper has a local called ``self``.
    """
    frame = inspect.currentframe()
    try:
        # Step over getclass() itself and over the helper that called it.
        for _ in range(2):
            if frame is None:
                break
            frame = frame.f_back
        # Start at the helper's caller rather than a fixed depth, so the
        # lookup works whether or not the method was reached via __call__.
        while frame is not None:
            frame_locals = frame.f_locals
            if "self" in frame_locals:
                return frame_locals["self"].__class__
            frame = frame.f_back
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame
    raise KeyError("self")

def log_size(tsr: torch.Tensor, name: str):
    """Log the shape of ``tsr`` under the label ``name``.

    Helper for checking how tensor sizes change. The message is prefixed with
    the name of the class returned by ``getclass()`` and is emitted at that
    class's ``level`` attribute.

    Args:
        tsr: tensor whose ``shape`` is reported.
        name: human-readable label placed in front of the size.
    """
    # getclass() inspects the call stack, so call it straight from this body.
    owner = getclass()
    message = f"[{owner.__name__}] {name} size={tsr.shape}"
    logger.log(owner.level, message)

SyntaxError: invalid syntax (<ipython-input-43-9dce765217ca>, line 9)

In [24]:
from enum import IntEnum
class TensorLoggingLevels(IntEnum):
    """Logging levels that control how much shape-debugging output is shown.

    A component stores one member as its ``level`` class attribute and
    ``log_size`` emits that component's messages at this level, so raising
    the logger threshold hides the lower-numbered components first.
    """

    attention = 1                   # used by ScaledDotProductAttention
    attention_head = 2              # presumably a single attention head
    multihead_attention_block = 3   # presumably the multi-head block
    enc_dec_block = 4               # presumably one encoder/decoder block
    enc_dec = 5                     # presumably the full encoder/decoder

In [25]:
class Dim(IntEnum):
    """Symbolic names for the axes of a (batch, seq, feature) tensor.

    Members are integers, so they are used directly as axis indices, as in
    ``k.transpose(Dim.seq, Dim.feature)``.
    """

    batch = 0     # index of the batch axis
    seq = 1       # index of the sequence-position axis
    feature = 2   # index of the feature axis

# Components


### Scaled dot-product attention

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

In [34]:
import math 

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Dropout is applied to the normalised attention weights before they are
    used to average the values.
    """

    # Level at which log_size reports tensor shapes coming from this module.
    level = TensorLoggingLevels.attention

    def __init__(self, dropout=0.1):
        """Create the module.

        Args:
            dropout: probability handed to ``nn.Dropout`` for the attention
                weights.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend over ``v`` using the similarity between ``q`` and ``k``.

        Args:
            q: queries, shape (batch, seq_q, d_k).
            k: keys, shape (batch, seq_k, d_k).
            v: values, shape (batch, seq_k, d_v).
            mask: optional boolean tensor broadcastable to
                (batch, seq_q, seq_k); positions where it is True receive
                zero attention weight.

        Returns:
            Tensor of shape (batch, seq_q, d_v).

        Raises:
            AssertionError: if queries and keys differ in feature size.
        """
        d_k = k.size(-1)
        assert q.size(-1) == d_k, "queries and keys must have the same feature size"

        # Dot product between every query and every key, scaled by sqrt(d_k).
        scores = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))
        scores = scores / math.sqrt(d_k)

        log_size(scores, "attention weight")  # (Batch, Seq, Seq)

        if mask is not None:
            # exp(-inf) == 0, so masked positions get exactly zero weight.
            scores = scores.masked_fill(mask, float("-inf"))
        # torch.softmax subtracts the row maximum internally; a bare
        # exp()/sum() overflows to inf (and then NaN) for large scores.
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        log_size(output, "attention output size")  # (Batch, Seq, Feature)
        return output

In [35]:
# Attention module for the demo below, with the default dropout rate spelled out.
attn = ScaledDotProductAttention(dropout=0.1)

In [36]:
# Seed the global RNG so the toy inputs are the same on every
# Restart Kernel -> Run All.
SEED = 0
BATCH_SIZE, SEQ_LEN, FEATURE_DIM = 5, 10, 20
torch.manual_seed(SEED)

# Random queries, keys and values, each of shape (batch, seq, feature).
q = torch.rand(BATCH_SIZE, SEQ_LEN, FEATURE_DIM)
k = torch.rand(BATCH_SIZE, SEQ_LEN, FEATURE_DIM)
v = torch.rand(BATCH_SIZE, SEQ_LEN, FEATURE_DIM)

In [44]:
# Smoke test: run the toy queries, keys and values through the attention
# module; the result has the same shape as v, (5, 10, 20).
# NOTE(review): the TypeError recorded under this cell ("log() missing ...
# 'msg'") does not match the current log_size, which passes both level and
# msg -- presumably stale output from an earlier version; re-run to confirm.
attn(q, k, v)

TypeError: log() missing 1 required positional argument: 'msg'