In [20]:
d_model = 512 
num_heads  = 8
drop_prob = 0.1 
batch_size =  30 
max_sequence_length = 200 
ffn_hidden = 2048
num_layers = 5 



In [21]:
import torch 
from torch import nn 
import torch.nn.functional as F 
import numpy as np 
import pandas as pd
import math

In [22]:
x = torch.randn((batch_size,max_sequence_length,d_model))

y = torch.randn((batch_size,max_sequence_length,d_model))

mask = torch.full([max_sequence_length,max_sequence_length],float('-inf'))

mask = torch.triu(mask,diagonal=1)



In [101]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    print(f"Scaled.size(): {scaled.size()}")
    if mask is not None:
        print(f"-- Adding Mask --")
        scaled += mask 
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [102]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        print(f"x after linear layer: {x.size()}")
        x = self.relu(x)
        print(f"x after relu layer: {x.size()}")
        x = self.dropout(x)
        x = self.linear2(x)
        return x

In [103]:
class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps 
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape)) 

    def forward(self, inputs):
        dim = [-(i+1) for i in range(len(self.parameters_shape))]
        print(f"dims: {dim}")
        mean = inputs.mean(dim=dim, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dim, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta 
        return out

In [104]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model 
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads 
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.qkv_model = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, sequence_length, d_model = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_model(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, max_sequence_length, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q.size(): {q.size()}")
        print(f"k.size(): {k.size()}")
        print(f"v.size(): {v.size()}")
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        print(f"values.size(): {values.size()}")
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out, attention

In [105]:
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.kv_layer = nn.Linear(d_model, 2 * d_model)
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, y, mask=None):
        batch_size, sequence_length, d_model = x.size()
        print(f"x.size(): {x.size()}")
        kv = self.kv_layer(x)
        print(f"kv.size: {kv.size()}")
        q = self.q_layer(y)
        print(f"q.size(): {q.size()}")
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)  # (batch_size, num_heads, max_sequence_length, 2 * head_dim)
        q = q.permute(0, 2, 1, 3)  # (batch_size, num_heads, max_sequence_length, head_dim)
        k, v = kv.chunk(2, dim=-1)
        print(f"k.size(): {k.size()}")
        print(f"v.size(): {v.size()}")
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        print(f"values.size(): {values.size()}")
        values = values.reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        return out, attention


In [106]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, ffn_hidden, drop_prob=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, decoder_mask):
        _y = y  # 30x 200x 512
        print("Masked Self Attention")
        y, _ = self.self_attn(y, mask=decoder_mask)
        y = self.dropout1(y)
        print("Adding normalization layer")
        y = self.norm(y + _y)

        _y = y  # 30x 200x 512 
        print("Cross Attention")
        y, _ = self.encoder_decoder_attention(y, x, mask=decoder_mask)
        y = self.dropout2(y)
        print("Adding normalization layer")
        y = self.norm2(y + _y)

        _y = y  # 30x 200x 512 
        print("Feed Forward Network")
        y = self.ffn(y)
        y = self.dropout3(y)
        print("Adding normalization layer")
        y = self.norm3(y + _y)
        return y

class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, decoder_mask = inputs
        for module in self._modules.values():
            y = module(x, y, decoder_mask)
        return y


In [107]:
class Decoder(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers=1):
        super().__init__()
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, num_heads, ffn_hidden, drop_prob)
                                          for _ in range(num_layers)])
    
    def forward(self, x, y, mask):
        y = self.layers(x, y, mask)
        return y

In [108]:
decoder = Decoder(d_model,ffn_hidden,num_heads,drop_prob,num_layers)

In [109]:
decoder

Decoder(
  (layers): SequentialDecoder(
    (0): DecoderLayer(
      (self_attn): MultiHeadAttention(
        (qkv_model): Linear(in_features=512, out_features=1536, bias=True)
        (linear_layer): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm): LayerNormalization()
      (dropout1): Dropout(p=0.1, inplace=False)
      (encoder_decoder_attention): MultiHeadCrossAttention(
        (kv_layer): Linear(in_features=512, out_features=1024, bias=True)
        (q_layer): Linear(in_features=512, out_features=512, bias=True)
        (linear_layer): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm2): LayerNormalization()
      (dropout2): Dropout(p=0.1, inplace=False)
      (ffn): PositionwiseFeedForward(
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm3): L

In [110]:
out = decoder(x,y,mask)

Masked Self Attention
x.size(): torch.Size([30, 200, 512])
qkv.size(): torch.Size([30, 200, 1536])
qkv.size(): torch.Size([30, 200, 8, 192])
q.size(): torch.Size([30, 8, 200, 64])
k.size(): torch.Size([30, 8, 200, 64])
v.size(): torch.Size([30, 8, 200, 64])
Scaled.size(): torch.Size([30, 8, 200, 200])
-- Adding Mask --
values.size(): torch.Size([30, 8, 200, 64])
Adding normalization layer
dims: [-1]


Cross Attention
x.size(): torch.Size([30, 200, 512])
kv.size: torch.Size([30, 200, 1024])
q.size(): torch.Size([30, 200, 512])
k.size(): torch.Size([30, 8, 200, 64])
v.size(): torch.Size([30, 8, 200, 64])
Scaled.size(): torch.Size([30, 8, 200, 200])
-- Adding Mask --
values.size(): torch.Size([30, 8, 200, 64])
Adding normalization layer
dims: [-1]
Feed Forward Network
x after linear layer: torch.Size([30, 200, 2048])
x after relu layer: torch.Size([30, 200, 2048])
Adding normalization layer
dims: [-1]
Masked Self Attention
x.size(): torch.Size([30, 200, 512])
qkv.size(): torch.Size([30, 200, 1536])
qkv.size(): torch.Size([30, 200, 8, 192])
q.size(): torch.Size([30, 8, 200, 64])
k.size(): torch.Size([30, 8, 200, 64])
v.size(): torch.Size([30, 8, 200, 64])
Scaled.size(): torch.Size([30, 8, 200, 200])
-- Adding Mask --
values.size(): torch.Size([30, 8, 200, 64])
Adding normalization layer
dims: [-1]
Cross Attention
x.size(): torch.Size([30, 200, 512])
kv.size: torch.Size([30, 200, 1024])
