In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [None]:
# Seed once, up front, so every torch.randn call and layer initialisation below
# reproduces the same numbers on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
torch.set_printoptions(precision=2)

In [None]:
# A small random 3x3 matrix to demonstrate causal masking on.
t = torch.randn(3, 3)
t

tensor([[ 0.68, -0.97,  0.98],
        [ 0.75, -1.10,  0.63],
        [ 0.23,  2.03,  1.22]])

In [None]:
# Lower-triangular matrix of ones: 1 where a position may attend, 0 above the diagonal.
mask = torch.ones_like(t).tril()
mask

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [None]:
# Entries above the diagonal become -inf, so softmax will give them zero weight.
t.masked_fill(mask.eq(0), float('-inf'))

tensor([[ 0.68,  -inf,  -inf],
        [ 0.75, -1.10,  -inf],
        [ 0.23,  2.03,  1.22]])

In [None]:
from math import sqrt

In [None]:
def scaled_dot_product(query: Tensor, key: Tensor, value: Tensor, mask: bool = False) -> Tensor:
    """Scaled dot-product attention: ``softmax(query @ key^T / sqrt(d_k)) @ value``.

    The matrix products act on the last two dimensions, so leading (batch-like)
    dimensions broadcast through unchanged.

    Args:
        query: query matrix; its last dim must match key's last dim.
        key: key matrix; its last dim ``d_k`` sets the ``1/sqrt(d_k)`` scale.
        value: value matrix; its second-to-last dim must match key's.
        mask: when True, scores above the diagonal are set to -inf before the
            softmax (causal mask), so row i only weights value rows 0..i.

    Returns:
        The softmax-weighted combination of ``value`` rows.
    """
    scale = sqrt(key.shape[-1])
    logits = query @ key.transpose(-1, -2) / scale

    if mask:
        keep = torch.ones_like(logits).tril()
        logits = logits.masked_fill(keep == 0, float('-inf'))

    weights = F.softmax(logits, dim=-1)
    return weights @ value

In [None]:
# Three independent random 3x3 matrices, drawn in the order key, query, value.
key, query, value = (torch.randn(3, 3) for _ in range(3))

In [None]:
# Arguments follow the function's signature order: (query, key, value).
scaled_dot_product(query, key, value)

tensor([[-0.68, -0.87,  1.06],
        [-0.17, -0.43,  0.60],
        [-0.76, -0.92,  1.10]])

In [None]:
# Same inputs with the causal mask; arguments in signature order (query, key, value).
scaled_dot_product(query, key, value, mask=True)

tensor([[ 0.69,  0.42, -0.51],
        [ 0.09, -0.13,  0.14],
        [-0.76, -0.92,  1.10]])

In [None]:
class AttentionHead(nn.Module):
    """A single self-attention head whose query/key/value width equals ``channels``.

    One bias-free ``nn.Linear(channels, 3 * channels)`` produces query, key and
    value together; they are split into three equal chunks along the last dim and
    fed to ``scaled_dot_product``.

    Args:
        channels: feature width of the input and of each of q/k/v.
        mask: forwarded to ``scaled_dot_product``; True applies the causal mask.
    """
    def __init__(self, channels: int, mask: bool = False) -> None:
        super().__init__()
        self.linear = nn.Linear(channels, channels*3, bias=False)
        self.mask = mask

    def forward(self, x: Tensor) -> Tensor:
        q, k, v = self.linear(x).chunk(3, dim=-1)
        # This is the attention *output* (weights applied to v), not the raw scores.
        return scaled_dot_product(q, k, v, self.mask)

In [None]:
# Fresh random input for the attention-head demos.
t = torch.randn(3, 3)

In [None]:
unmasked_head = AttentionHead(channels=3, mask=False)
unmasked_head(t)

tensor([[-0.48, -0.11,  0.01],
        [-0.57, -0.20,  0.02],
        [-0.66, -0.29,  0.03]], grad_fn=<MmBackward0>)

In [None]:
causal_head = AttentionHead(channels=3, mask=True)
causal_head(t)

tensor([[ 0.87, -0.60, -0.61],
        [ 0.45, -0.36, -0.30],
        [ 0.45, -0.36, -0.30]], grad_fn=<MmBackward0>)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run ``n_heads`` independent ``AttentionHead``s and mix their outputs.

    Each head outputs ``in_channels`` features; the heads are concatenated along
    the last dim (``in_channels * n_heads`` wide) and projected to
    ``out_channels`` by a biased linear layer.

    Args:
        in_channels: input feature width (also the width of every head).
        out_channels: width of the projected output.
        n_heads: number of parallel heads.
        mask: passed to every head; True applies the causal mask.
    """
    def __init__(self, in_channels: int, out_channels: int, n_heads: int, mask: bool = False) -> None:
        super().__init__()
        heads = [AttentionHead(in_channels, mask) for _ in range(n_heads)]
        self.attention_heads = nn.ModuleList(heads)

        self.linear = nn.Linear(in_channels*n_heads, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        per_head = [head(x) for head in self.attention_heads]
        return self.linear(torch.cat(per_head, dim=-1))

In [None]:
# Fresh random input for the multi-head demos.
t = torch.randn(3, 3)

In [None]:
mha = MultiHeadAttention(in_channels=3, out_channels=3, n_heads=3)
mha(t)

tensor([[ 0.22,  0.14, -0.29],
        [ 0.24,  0.14, -0.28],
        [ 0.24,  0.12, -0.28]], grad_fn=<AddmmBackward0>)

In [None]:
causal_mha = MultiHeadAttention(in_channels=3, out_channels=3, n_heads=3, mask=True)
causal_mha(t)

tensor([[ 0.68, -0.04,  0.03],
        [ 0.31,  0.17, -0.09],
        [ 0.27,  0.10, -0.13]], grad_fn=<AddmmBackward0>)

In [None]:
from torch.nn import Module

In [None]:
def linear(in_channels: int, out_channels: int, Activation: "type[Module] | None" = None, init: bool = True) -> nn.Sequential:
    """Build ``nn.Linear(in_channels, out_channels)``, optionally followed by an activation.

    Args:
        in_channels: input feature size of the linear layer.
        out_channels: output feature size of the linear layer.
        Activation: an activation *class* (e.g. ``nn.ReLU``), not an instance; it is
            instantiated with no arguments and appended after the linear layer.
            ``None`` means no activation.
        init: currently unused; accepted only so existing callers keep working.
            No weight initialisation is performed here.

    Returns:
        An ``nn.Sequential`` holding the linear layer and, if given, the activation.
    """
    layers = [nn.Linear(in_channels, out_channels)]

    if Activation is not None:
        layers.append(Activation())

    return nn.Sequential(*layers)

In [None]:
# Fresh random input for the `linear` demo.
t = torch.randn(3, 3)

In [None]:
relu_layer = linear(3, 3, nn.ReLU)
relu_layer(t)

tensor([[0.00, 0.00, 0.00],
        [0.00, 0.39, 0.00],
        [0.00, 0.00, 0.13]], grad_fn=<ReluBackward0>)

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP built from ``linear`` blocks, with dropout on the output.

    Pipeline: ``in_layer`` (Linear + Activation) -> ``n_hidden`` x (Linear +
    Activation) -> ``out_layer`` (Linear, no activation) -> Dropout.

    Args:
        in_channels: input feature width.
        hidden_channels: width of every hidden layer.
        out_channels: output feature width.
        n_hidden: number of extra hidden Linear+Activation blocks.
        p_dropout: dropout probability applied to the final output.
        Activation: activation class used after every layer except the last.
    """
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int,
                 n_hidden: int = 0,
                 p_dropout: float = 0.,
                 Activation=nn.ReLU) -> None:
        super().__init__()
        self.in_layer = linear(in_channels, hidden_channels, Activation)

        hidden_blocks = (linear(hidden_channels, hidden_channels, Activation) for _ in range(n_hidden))
        self.hidden = nn.Sequential(*hidden_blocks)

        # The projection back to out_channels is left linear (no activation).
        self.out_layer = linear(hidden_channels, out_channels, Activation=None)

        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x: Tensor) -> Tensor:
        for stage in (self.in_layer, self.hidden, self.out_layer, self.dropout):
            x = stage(x)
        return x

In [None]:
# Fresh random input for the feed-forward demos.
t = torch.randn(3, 3)

In [None]:
ff_default_init = FeedForward(in_channels=3, hidden_channels=10, out_channels=3, n_hidden=2)
ff_default_init(t)

tensor([[ 0.26, -0.19, -0.10],
        [ 0.20, -0.17, -0.07],
        [ 0.17, -0.19, -0.03]], grad_fn=<AddmmBackward0>)

In [None]:
def apply_relu_init(module):
    """Kaiming-normal (He) initialise the weight of an ``nn.Linear``; ignore other modules.

    Meant to be passed to ``nn.Module.apply``. Only ``nonlinearity='relu'`` is
    overridden (other ``kaiming_normal_`` arguments keep their defaults), and the
    bias is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.kaiming_normal_(module.weight, nonlinearity='relu')

In [None]:
feed_forward = FeedForward(in_channels=3, hidden_channels=10, out_channels=3, n_hidden=2)
feed_forward.apply(apply_relu_init);

In [None]:
ff_out = feed_forward(t)
ff_out

tensor([[ 0.82, -1.62, -0.80],
        [-0.54, -2.96, -1.40],
        [-0.66, -1.98, -0.33]], grad_fn=<AddmmBackward0>)

In [None]:
class TransformerLayer(nn.Module):
    """Pre-norm transformer block.

    ``h = attention(LN1(x)) + skip1(x)`` followed by
    ``y = feed_forward(LN2(h)) + skip2(h)``.

    Args:
        in_channels: feature width of the input (its last dimension).
        hidden_channels: width of the attention output and of the feed-forward
            hidden layers.
        out_channels: width of the block's output.
        n_heads: number of attention heads.
        n_hidden_layers: extra hidden layers inside the feed-forward MLP.
        p_dropout: dropout probability at the end of the feed-forward MLP.
        mask: True applies the causal mask inside attention.

    The skip paths are identities when the widths match, so no parameters are
    added in that case. When ``in_channels != hidden_channels`` or
    ``hidden_channels != out_channels``, a bias-free linear projection is used
    instead, so the residual sums stay well-defined.
    """
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int,
                 n_heads: int,
                 n_hidden_layers: int = 0,
                 p_dropout: float = 0.,
                 mask: bool = False) -> None:
        super().__init__()
        # LayerNorm must normalise over the feature dimension only. A 2-D
        # normalized_shape such as (in_channels, hidden_channels) also normalises
        # across the sequence axis: it only runs when seq_len happens to equal
        # the channel count, and it mixes information between positions, which
        # defeats the causal mask.
        self.norm1 = nn.LayerNorm(in_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)

        self.attention = MultiHeadAttention(in_channels, hidden_channels, n_heads, mask)
        self.feed_forward = FeedForward(hidden_channels, hidden_channels, out_channels, n_hidden_layers, p_dropout)

        # Registered after the main sub-layers so the existing module order
        # (e.g. as seen by .modules()) is unchanged.
        self.skip1 = self._skip(in_channels, hidden_channels)
        self.skip2 = self._skip(hidden_channels, out_channels)

    @staticmethod
    def _skip(in_features: int, out_features: int) -> nn.Module:
        """Residual path: identity when widths match, otherwise a bias-free projection."""
        if in_features == out_features:
            return nn.Identity()
        return nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        h = self.attention(self.norm1(x)) + self.skip1(x)
        return self.feed_forward(self.norm2(h)) + self.skip2(h)

In [None]:
# Fresh random input for the transformer-layer demos.
t = torch.randn(3, 3)

In [None]:
# Causal (masked) block: 3-wide in/hidden/out, 3 heads, no extra feed-forward hidden layers.
decoder = TransformerLayer(
    in_channels=3, hidden_channels=3, out_channels=3,
    n_heads=3, n_hidden_layers=0,
    p_dropout=0.1, mask=True,
)

In [None]:
# The module is still in training mode here, so the p_dropout=0.1 dropout at the
# end of the feed-forward is active and this output is stochastic.
decoder_out = decoder(t)
decoder_out

tensor([[ 0.64, -0.69, -1.81],
        [ 0.29,  0.60,  0.26],
        [ 0.67, -1.11, -0.89]], grad_fn=<AddBackward0>)

In [None]:
from functools import partial
from typing import Callable

In [None]:
class Hook:
    """Attach ``func`` to ``module`` as a forward hook, passing this Hook as the first argument.

    ``func`` is wrapped as ``partial(func, self)`` before registration, so with
    PyTorch's forward-hook calling convention ``(module, input, output)`` it is
    invoked as ``func(hook, module, input, output)``. It can therefore store
    results on the Hook instance.

    Usable as a context manager: the hook is removed on ``__exit__``. It is also
    removed when the Hook object is finalised.
    """
    def __init__(self, module: Module, func: Callable) -> None:
        self.hook = module.register_forward_hook(partial(func, self))

    def __del__(self):
        # __del__ also runs on a partially constructed Hook (e.g. when
        # register_forward_hook raised in __init__); remove() tolerates that.
        self.remove()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()

    def remove(self):
        """Detach the forward hook; harmless if called repeatedly or before ``hook`` exists."""
        handle = getattr(self, "hook", None)
        if handle is not None:
            handle.remove()

In [None]:
# The transformer block class defined above is `TransformerLayer`; no `DecoderBlock`
# exists before this cell, so the old name raised NameError on Restart & Run All.
decoder = TransformerLayer(in_channels=3, hidden_channels=3, out_channels=3, n_heads=3, n_hidden_layers=0)

In [None]:
# Pair every ReLU in the decoder with the module visited immediately before it
# in decoder.modules() traversal order.
modules = []
previous = None
for idx, current in enumerate(decoder.modules()):
    if idx != 0 and isinstance(current, nn.ReLU):
        modules.append((current, previous))
    previous = current

In [None]:
# Re-initialise the module feeding each ReLU; apply_relu_init skips anything that is not nn.Linear.
for relu, preceding in modules:
    apply_relu_init(preceding)

In [None]:
reinit_out = decoder(t)
reinit_out

tensor([[-0.84,  1.38, -3.20],
        [ 0.82,  0.02, -0.82],
        [ 0.59, -1.21, -0.89]], grad_fn=<AddBackward0>)

In [None]:
class Embeddings(nn.Module):
    """Token embedding: map integer token ids to learned ``hidden_channels``-wide vectors.

    Args:
        vocab_sz: number of distinct token ids (valid ids are ``0 .. vocab_sz - 1``).
        hidden_channels: width of each embedding vector.

    TODO(review): no positional information is added here; if positional
    embeddings are intended for this block they still need to be implemented.
    """
    def __init__(self, vocab_sz: int, hidden_channels: int) -> None:
        # nn.Module.__init__ must run before any submodule is assigned and before
        # the module can be called.
        super().__init__()
        self.embedding = nn.Embedding(vocab_sz, hidden_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Look up embeddings for integer ids ``x``; the result has shape ``x.shape + (hidden_channels,)``."""
        return self.embedding(x)