In [8]:
from enum import IntEnum
import inspect
import logging
import math
import numpy as np

import torch
from torch import nn

In [2]:
# Shared logger for tensor-shape messages; each module logs at its own level.
logger = logging.getLogger("tensor_shapes")
formatter = logging.Formatter('%(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
# getLogger() returns the same logger object on every call, so re-running this cell
# would otherwise stack another StreamHandler and print each message once per run.
# Detach whatever a previous run attached before adding the fresh handler.
while logger.handlers:
    logger.removeHandler(logger.handlers[0])
logger.addHandler(handler)
# 1 is the lowest TensorLoggingLevels value, so initially every module's messages pass.
logger.setLevel(1)

In [4]:
def getclass():
    """Return the class of the nearest calling ``self`` that defines ``level``.

    Walks outward from the caller's frame and returns ``self.__class__`` for the
    first frame whose local ``self`` has a class-level ``level`` attribute. Called
    from ``log_size`` inside a module's ``forward``, this finds that module
    regardless of how many frames separate them, so a direct
    ``module.forward(...)`` call works as well as ``module(...)``.

    Returns:
        The class of the owning object.

    Raises:
        LookupError: if no calling frame has a ``self`` with a ``level`` attribute.
    """
    # currentframe()/f_back is far cheaper than inspect.stack(), which builds a
    # FrameInfo (with source context) for every frame on every call.
    frame = inspect.currentframe()
    try:
        caller = frame.f_back if frame is not None else None
        while caller is not None:
            owner = caller.f_locals.get("self")
            if owner is not None and hasattr(owner.__class__, "level"):
                return owner.__class__
            caller = caller.f_back
    finally:
        # Drop our own frame reference promptly to avoid a frame <-> locals cycle.
        del frame
    raise LookupError("getclass(): no calling frame has a `self` with a `level` attribute")

def log_size(tsr: torch.Tensor, name: str) -> None:
    """Log the shape of ``tsr``, tagged with the calling module's class name.

    The record goes to ``logger`` at the calling class's ``level`` attribute (the
    class is found by ``getclass``), so ``logger.setLevel`` selects which modules'
    messages are shown.

    Args:
        tsr: tensor whose shape is reported.
        name: short label for the tensor.
    """
    cls = getclass()
    # list(...) renders torch.Size as e.g. "[5, 10, 20]" instead of "torch.Size([5, 10, 20])".
    logger.log(level=cls.level, msg=f"[{cls.__name__}] {name} size={list(tsr.shape)}")

In [20]:
class TensorLoggingLevels(IntEnum):
    """Logging levels used for tensor-shape messages, one per kind of module.

    Each module class stores one of these in its ``level`` attribute, and
    ``log_size`` logs at that level. Raising the logger's threshold therefore
    silences modules from the innermost (lowest value) outward.
    """

    attention = 1                  # ScaledDotProductAttention
    attention_head = 2             # AttentionHead
    multihead_attention_block = 3  # MultiHeadAttention
    enc_dec_block = 4              # EncoderBlock
    enc_dec = 5                    # presumably the full encoder/decoder stack -- confirm

In [7]:
class Dim(IntEnum):
    """Named axis indices for the 3-D (batch, seq, feature) tensors in this notebook."""

    batch = 0    # batch axis
    seq = 1      # sequence-position axis
    feature = 2  # feature axis

In [9]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        dropout: dropout probability applied to the normalised attention weights.
    """
    level = TensorLoggingLevels.attention

    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` and aggregate ``v``.

        Args:
            q: queries, (batch, q_len, d_k); ``torch.bmm`` requires 3-D inputs.
            k: keys, (batch, k_len, d_k).
            v: values, (batch, k_len, d_v).
            mask: optional boolean tensor broadcastable to the (batch, q_len, k_len)
                score matrix; ``True`` entries receive zero attention weight.

        Returns:
            Tensor of shape (batch, q_len, d_v).
        """
        d_k = k.size(-1)
        assert q.size(-1) == d_k

        # (batch, q_len, d_k) @ (batch, d_k, k_len) -> (batch, q_len, k_len) scores.
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))
        attn = attn / math.sqrt(d_k)
        log_size(attn, "attention weight")

        if mask is not None:
            # exp(-inf) == 0, so masked keys get exactly zero weight after softmax.
            # A row whose keys are all masked comes out as NaN.
            attn = attn.masked_fill(mask, float("-inf"))
        # softmax subtracts the row max (no overflow on large scores) and, unlike
        # exp() followed by an in-place division, does not modify a tensor that
        # autograd saved for backward().
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        log_size(output, "attention output size")
        return output

In [10]:
# Standalone attention layer (default dropout spelled out) for trying on q, k, v.
attn = ScaledDotProductAttention(dropout=0.1)

In [11]:
# Toy inputs, read as (batch=5, seq=10, feature=20) per the Dim enum's axis order.
# A dedicated seeded generator makes them reproducible on Restart & Run All
# without touching torch's global RNG.
_toy_rng = torch.Generator().manual_seed(0)
q = torch.rand(5, 10, 20, generator=_toy_rng)
k = torch.rand(5, 10, 20, generator=_toy_rng)
v = torch.rand(5, 10, 20, generator=_toy_rng)

In [13]:
# attn(q, k, v)

In [14]:
class AttentionHead(nn.Module):
    """One attention head: learned Q/K/V projections followed by scaled dot-product attention.

    Args:
        d_model: feature size of the incoming queries, keys and values.
        d_feature: feature size each input is projected to before attention.
        dropout: dropout probability passed to ``ScaledDotProductAttention``.
    """
    level = TensorLoggingLevels.attention_head

    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()

        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        """Project the inputs to ``d_feature`` features and attend.

        Args:
            queries, keys, values: inputs with ``d_model`` features in the last axis.
            mask: optional boolean mask, passed through to the attention layer.

        Returns:
            The attention output, with ``d_feature`` features per position.
        """
        Q = self.query_tfm(queries)
        K = self.key_tfm(keys)
        V = self.value_tfm(values)
        log_size(Q, "queries, keys, vals")
        # Pass the mask through; without it masked positions still receive weight.
        x = self.attn(Q, K, V, mask=mask)
        return x

In [16]:
# attn_head = AttentionHead(20, 20)
# attn_head(q, k, v)

In [17]:
# Raise the threshold one step: ScaledDotProductAttention (level 1) goes quiet,
# AttentionHead (level 2) and higher keep logging.
logger.setLevel(level=TensorLoggingLevels.attention_head)

In [21]:
class MultiHeadAttention(nn.Module):
    """Run ``n_heads`` independent attention heads and mix their outputs.

    Each head projects its inputs from ``d_model`` to ``d_feature`` features; the
    head outputs are concatenated along ``Dim.feature`` and a final linear layer
    maps them back to ``d_model`` features.

    Args:
        d_model: feature size of the inputs and of the output.
        d_feature: per-head feature size; ``d_model`` must equal ``d_feature * n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability forwarded to every head.
    """
    level = TensorLoggingLevels.multihead_attention_block

    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        assert d_model == d_feature * n_heads

        heads = [AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)]
        self.attn_heads = nn.ModuleList(heads)
        self.projection = nn.Linear(d_feature * n_heads, d_model)

    def forward(self, queries, keys, values, mask=None):
        """Apply every head to the same inputs, concatenate, and project.

        Args:
            queries, keys, values: inputs with ``d_model`` features in the last axis.
            mask: optional boolean mask handed to each head unchanged.

        Returns:
            Tensor with ``d_model`` features per position.
        """
        log_size(queries, "Input queries")
        head_outputs = []
        for head in self.attn_heads:
            head_outputs.append(head(queries, keys, values, mask=mask))
        log_size(head_outputs[0], "output of single head")

        concatenated = torch.cat(head_outputs, dim=Dim.feature)
        log_size(concatenated, "concatenated output")
        projected = self.projection(concatenated)
        log_size(projected, "projected output")
        return projected

In [24]:
# heads = MultiHeadAttention(20*8, 20, 8)
# heads(q.repeat(1,1,8), k.repeat(1,1,8), v.repeat(1,1,8))

In [25]:
# Only MultiHeadAttention (level 3) and higher log from here on; per-head and
# raw-attention messages are filtered out.
logger.setLevel(level=TensorLoggingLevels.multihead_attention_block)

In [27]:
class LayerNorm(nn.Module):
    """Normalise over the last axis with a learnable gain (``gamma``) and bias (``beta``).

    Computes ``gamma * (x - mean) / (std + eps) + beta``, with ``mean`` and ``std``
    taken over the last axis. ``std`` is ``Tensor.std``'s default (Bessel-corrected)
    estimator, and ``eps`` is added to the standard deviation rather than the
    variance. With a last axis of size 1 that estimator yields NaN.

    Args:
        d_model: size of the last axis of the inputs.
        eps: constant added to the standard deviation to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.gamma * centered / spread + self.beta

In [None]:
class EncoderBlock(nn.Module):
    """Encoder block; its tensor-shape logs use ``TensorLoggingLevels.enc_dec_block``."""
    level = TensorLoggingLevels.enc_dec_block

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        # Call super(), not the bare builtin: `super.__init__()` invokes the `super`
        # type's own __init__ with no argument and raises TypeError, so nn.Module
        # would never be initialised.
        super().__init__()
        