In [1]:
from enum import IntEnum
import inspect
import logging
import math
import numpy as np

import torch
from torch import nn

In [2]:
logger = logging.getLogger("tensor_shapes")
handler = logging.StreamHandler()
formatter = logging.Formatter('%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(1)

In [3]:
def getclass():
    stack = inspect.stack()
    return stack[3][0].f_locals["self"].__class__

def log_size(tsr: torch.Tensor, name:str):
    cls = getclass()
    logger.log(level=cls.level, msg=f"[{cls.__name__}] {name} size={tsr.shape}")

In [4]:
class TensorLoggingLevels(IntEnum):
    attention = 1
    attention_head = 2
    multihead_attention_block = 3
    enc_dec_block = 4
    enc_dec = 5

In [5]:
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

In [6]:
class ScaledDotProductAttention(nn.Module):
    level = TensorLoggingLevels.attention
    
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1)
        assert q.size(-1) == d_k
        
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))
        attn = attn / math.sqrt(d_k)
        atten = torch.exp(attn)
        log_size(attn, "attention weight")
        
        if mask is not None:
            attn = attn.masked_fill(mask, 0)
        attn /= attn.sum(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        log_size(output, "attention output size")
        return output

In [7]:
attn = ScaledDotProductAttention()

In [8]:
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)

In [10]:
# attn(q, k, v)

In [11]:
class AttentionHead(nn.Module):
    level = TensorLoggingLevels.attention_head
    
    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)
        
    def forward(self, queries, keys, values, mask=None):
        Q = self.query_tfm(queries)
        K = self.key_tfm(keys)
        V = self.value_tfm(values)
        log_size(Q, "queries, keys, vals")
        x = self.attn(Q, K, V)
        return x

In [12]:
# attn_head = AttentionHead(20, 20)
# attn_head(q, k, v)

In [13]:
logger.setLevel(TensorLoggingLevels.attention_head)

In [14]:
class MultiHeadAttention(nn.Module):
    level = TensorLoggingLevels.multihead_attention_block
    
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        assert d_model == d_feature * n_heads
        
        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)
        ])
        self.projection = nn.Linear(d_feature * n_heads, d_model)
        
    def forward(self, queries, keys, values, mask=None):
        log_size(queries, "Input queries")
        x = [
            attn(queries, keys, values, mask=mask) for i, attn in enumerate(self.attn_heads)
        ]
        log_size(x[0], "output of single head")
        
        x = torch.cat(x, dim=Dim.feature)
        log_size(x, "concatenated output")
        x = self.projection(x)
        log_size(x, "projected output")
        return x

In [15]:
# heads = MultiHeadAttention(20*8, 20, 8)
# heads(q.repeat(1,1,8), k.repeat(1,1,8), v.repeat(1,1,8))

In [16]:
logger.setLevel(TensorLoggingLevels.multihead_attention_block)

In [17]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

In [18]:
class EncoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    
    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm2 = LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        log_size(x, "Encoder block input")
        att = self.attn_head(x, x, x, mask=mask)
        log_size(x, "Attention output")
        x += self.dropout(self.layer_norm1(att))
        pos = self.position_wise_feed_forward(x)
        log_size(x, "Feedforward output")
        x += self.dropout(self.layer_norm2(pos))
        log_size(x, "Encoder size output")
        return x

In [19]:
enc = EncoderBlock()

In [20]:
# enc(torch.rand(5, 10,  512))

In [21]:
class TransformerEncoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    
    def __init__(self, n_blocks=6, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        
        self.encoders = nn.ModuleList([
            EncoderBlock(
                d_model=d_model, 
                d_feature=d_model//n_heads, 
                d_ff=d_ff, 
                dropout=dropout
            ) for _ in range(n_blocks)
        ])
        
    def forward(self, x:torch.FloatTensor, mask=None):
        for encoder in self.encoders:
            x = encoder(x)
        return x

In [22]:
class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    
    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x += self.dropout(self.layer_norm1(att))
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=tgt_mask)
        x += self.dropout(self.layer_norm2(att))
        pos = self.position_wise_feed_forward(x)
        x += self.dropout(self.layer_norm3(pos))
        return x

In [23]:
dec = DecoderBlock()

In [24]:
# dec(torch.rand(5, 10, 512), enc(torch.rand(5, 10, 512)))

In [25]:
class TransformerDecoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    
    def __init__(self, n_blocks=6, d_model=512, d_feature=64, d_ff=2048, n_heads=8, 
                 dropout=0.1):
        super().__init__()
        
        self.position_embedding = PositionalEmbedding(d_model)
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model//n_heads, d_ff=d_ff, 
                         dropout=dropout) for _ in range(n_blocks)
        ])
    
    def forward(self, x:torch.FloatTensor, enc_out:torch.FloatTensor, src_mask=None, 
                tgt_mask=None):
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

In [26]:
class PositionalEmbedding(nn.Module):
    level = 1
    
    def __init__(self, d_model, max_len=512):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0,d_model,2).float() * -(math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.weight = nn.Parameter(pe, requires_grad=False)
        
    def forward(self, x):
        return self.weight[:, :x.size(1), :]

In [27]:
class WordPositionEmbedding(nn.Module):
    level = 1
    
    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model)
        
    def forward(self, x:torch.LongTensor, mask=None)->torch.FloatTensor:
        return self.word_embedding(x) + self.position_embedding(x)

In [28]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()

In [29]:
# encoder(emb(torch.randint(1000, (5, 30))))

In [30]:
logger.setLevel(TensorLoggingLevels.enc_dec_block)

In [31]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

In [32]:
src_ids = torch.randint(1000, (5, 30))
tgt_ids = torch.randint(1000, (5, 30))
x = encoder(emb(src_ids))
decoder(emb(tgt_ids), x)

[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder block input size=torch.Size([5, 30, 512])
[EncoderBlock] Attention output size=torch.Size([5, 30, 512])
[EncoderBlock] Feedforward output size=torch.Size([5, 30, 512])
[EncoderBlock] Encoder size output size=t

tensor([[[-6.2177e-01,  8.3054e-01, -2.5055e+00,  ...,  3.2256e+00,
           2.7803e+00, -5.6453e+00],
         [-4.0257e+00,  5.4899e+00,  5.0048e+00,  ...,  7.3244e-01,
          -2.4788e+00, -1.7646e+00],
         [-1.7118e+00,  4.9984e-01, -4.1395e+00,  ...,  2.2858e+00,
           8.4680e+00,  1.1552e+00],
         ...,
         [ 5.1021e-01, -4.0912e+00, -2.0309e-01,  ..., -3.9896e-01,
          -2.9902e+00, -1.3581e+00],
         [ 1.7264e+00,  1.4768e+00,  2.5084e+00,  ..., -2.6467e+00,
           3.6436e+00,  9.3666e-01],
         [-2.4069e+00,  2.6056e+00, -2.2024e+00,  ...,  4.3321e+00,
          -4.9218e+00, -3.9276e+00]],

        [[ 3.1604e-01, -2.0718e+00, -1.1691e+00,  ..., -8.4731e-01,
          -3.0178e+00,  1.4080e+00],
         [-6.2286e+00,  9.3539e-01, -2.9332e+00,  ..., -4.2893e+00,
           6.5248e-01,  2.0102e+00],
         [ 1.9881e+00,  1.3487e+00, -5.2017e+00,  ..., -6.4079e+00,
          -4.2760e+00,  2.8536e+00],
         ...,
         [-1.8641e+00,  4