# Transformer Stack

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    Args:
        d: Feature size of the query/key vectors; the attention logits are
            divided by ``sqrt(d)`` before the softmax.
    """

    def __init__(self, d):
        super().__init__()
        self.normalization_factor = d**0.5

    def forward(self, q, k, v, mask = None):
        """Attend from ``q`` over the keys ``k`` and values ``v``.

        Args:
            q: Query tensor ``(..., L_q, d)``.
            k: Key tensor ``(..., L_k, d)``.
            v: Value tensor ``(..., L_k, d_v)``.
            mask: Optional tensor broadcastable to ``(..., L_q, L_k)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            ``(attention_value, attention_weights)``: the attended values
            ``(..., L_q, d_v)`` and the post-softmax weights
            ``(..., L_q, L_k)``. Each unmasked row of weights sums to 1;
            a row whose keys are all masked is all zeros.
        """
        # Scale the logits *before* the softmax. Dividing the probabilities
        # afterwards (as the previous version did) breaks normalisation and
        # does nothing to keep large dot products from saturating the softmax.
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.normalization_factor
        if mask is not None:
            blocked = mask == 0
            # -inf logits get exactly zero probability, and the remaining
            # weights are renormalised by the softmax.
            scores = scores.masked_fill(blocked, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        if mask is not None:
            # A fully-masked row is softmax over all -inf, which gives NaN.
            # Force such rows (and every masked entry) to exactly zero.
            attention_weights = attention_weights.masked_fill(blocked, 0.0)
        attention_value = torch.matmul(attention_weights, v)
        return attention_value, attention_weights

In [3]:
class Attention(nn.Module):
    """A single attention head.

    Learned linear projections of the query, key and value inputs are fed
    into ``ScaledDotProductAttention``.

    Args:
        input_dim: Last-dimension size of the ``q``/``k``/``v`` inputs.
        embedding_dim: Size each input is projected to before attention.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()

        self.query_weights = nn.Linear(input_dim, embedding_dim)
        self.key_weights = nn.Linear(input_dim, embedding_dim)
        self.value_weights = nn.Linear(input_dim, embedding_dim)

        self.scaled_attention = ScaledDotProductAttention(embedding_dim)

    def forward(self, q, k, v, mask=None):
        """Project ``q``/``k``/``v`` and apply scaled dot-product attention.

        Args:
            q, k, v: Inputs whose last dimension is ``input_dim``.
            mask: Optional mask, forwarded unchanged to
                ``ScaledDotProductAttention``. The default ``None`` keeps the
                unmasked behaviour.

        Returns:
            The ``(attention, attention_weights)`` pair produced by
            ``ScaledDotProductAttention``.
        """
        projected_q = self.query_weights(q)
        projected_k = self.key_weights(k)
        projected_v = self.value_weights(v)
        return self.scaled_attention(projected_q, projected_k, projected_v, mask)

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``Attention`` heads.

    Each of the ``num_heads`` heads projects q/k/v to
    ``embedding_dim // num_heads`` features. The head outputs are
    concatenated along the last dimension (giving ``embedding_dim``
    features) and linearly projected to ``output_dim``.

    Args:
        input_dim: Last-dimension size of the ``q``/``k``/``v`` inputs.
        num_heads: Number of attention heads.
        embedding_dim: Total projection size across all heads; must be
            divisible by ``num_heads``.
        output_dim: Size of the final output projection.
    """

    def __init__(self, input_dim, num_heads, embedding_dim, output_dim):
        super().__init__()
        assert (embedding_dim % num_heads == 0), f"embedding dim {embedding_dim} must be divisible by num heads {num_heads}"

        self.input_dim = input_dim
        self.num_heads = num_heads
        # Stored as the *per-head* size, not the total passed in.
        self.embedding_dim = embedding_dim // num_heads
        # nn.ModuleList (not a plain Python list) registers every head as a
        # submodule. Otherwise the heads' parameters would be missing from
        # parameters()/state_dict() and ignored by optimizers and .to(device).
        self.multihead_attention = nn.ModuleList(
            [Attention(input_dim, self.embedding_dim) for _ in range(num_heads)]
        )

        self.output_projection = nn.Linear(embedding_dim, output_dim)

    def forward(self, q, k, v):
        """Run every head on ``(q, k, v)`` and merge the results.

        Returns:
            ``(attention, weights)``. ``attention`` is the projected
            concatenation of the head outputs. ``weights`` holds each head's
            attention weights, stacked along dim 1 (one slice per head).
        """
        head_results = [head(q, k, v) for head in self.multihead_attention]
        concatenated = torch.cat([out for out, _ in head_results], dim=-1)
        weights = torch.stack([w for _, w in head_results], dim=1)
        return self.output_projection(concatenated), weights

# Encoder

In [5]:
class Encoder(nn.Module):
    """One post-norm Transformer encoder layer.

    The layer has two sub-layers: multi-head self-attention, then a two-layer
    ReLU feed-forward network. Each is wrapped in a residual add followed by
    LayerNorm.

    The attention residual adds the raw input to a tensor of width
    ``output_dim``, so this layer assumes ``input_dim == output_dim``.
    """

    def __init__(self, input_dim, embedding_dim, output_dim, num_heads):
        super().__init__()
        # Submodules are created in a fixed order, which also fixes the order
        # of random parameter initialisation.
        self.multi_head_attn = MultiHeadAttention(
            input_dim, num_heads, embedding_dim, output_dim
        )
        self.layer_norm_inter = nn.LayerNorm(output_dim)
        self.fc = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        )
        self.layer_norm_final = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Encode ``x``; returns a tensor of width ``output_dim``."""
        attended, _ = self.multi_head_attn(x, x, x)
        hidden = self.layer_norm_inter(x + attended)
        return self.layer_norm_final(self.fc(hidden) + hidden)

In [6]:
# Shared sizes for the encoder/decoder smoke tests.
batch_size, seq_len = 8, 10

# The Encoder/Decoder residuals add the input to the attention output, so
# input_dim and output_dim are kept equal. embedding_dim must also be
# divisible by num_heads.
input_dim = embedding_dim = output_dim = 128
num_heads = 4

In [7]:
# Smoke test: push a random batch through a freshly initialised encoder
# and report the resulting shape.
x = torch.rand(batch_size, seq_len, input_dim)
enc = Encoder(input_dim, embedding_dim, output_dim, num_heads)
print("Encoder output dim : ", enc(x).shape)

Encoder output dim :  torch.Size([8, 10, 128])


# Decoder 

In [8]:
class Decoder(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer has three sub-layers:

    1. Self-attention over ``x``.
    2. Cross-attention from ``x`` to the encoder output.
    3. A two-layer ReLU feed-forward network.

    Each sub-layer is wrapped in a residual add followed by LayerNorm.

    NOTE(review): the self-attention is called without any mask, so every
    position can attend to later positions. Autoregressive decoding would
    need a causal mask, which this block does not supply.
    """

    def __init__(self, input_dim, embedding_dim, output_dim, num_heads):
        super().__init__()
        self.multi_head_attn_1 = MultiHeadAttention(input_dim, num_heads,
                                                  embedding_dim,
                                                  output_dim)
        self.layer_norm_1 = nn.LayerNorm(output_dim)
        # Cross-attention operates on output_dim-wide queries and keys/values.
        self.multi_head_attn_2 = MultiHeadAttention(output_dim, num_heads,
                                                    output_dim,
                                                    output_dim)
        self.layer_norm_2 = nn.LayerNorm(output_dim)
        self.fc = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        )
        self.layer_norm_final = nn.LayerNorm(output_dim)

    def forward(self, x, encoder):
        """Decode ``x`` while attending to ``encoder``.

        Args:
            x: Decoder input with last dimension ``input_dim``. The first
                residual add assumes ``input_dim == output_dim``.
            encoder: Encoder output used as keys/values for cross-attention.
                Its last dimension must be ``output_dim``.

        Returns:
            A tensor of width ``output_dim``.
        """
        attn, _ = self.multi_head_attn_1(x, x, x)
        x = self.layer_norm_1(x + attn)
        attn, _ = self.multi_head_attn_2(x, encoder, encoder)
        x = self.layer_norm_2(x + attn)
        # Residual around the feed-forward sub-layer, matching Encoder.forward.
        # The previous version normalised fc(x) alone and dropped x.
        return self.layer_norm_final(self.fc(x) + x)

In [9]:
# Smoke test: encode a random batch, then decode the same batch against it.
x = torch.rand(size=(batch_size, seq_len, input_dim))
enc = Encoder(input_dim, embedding_dim, output_dim, num_heads)
# Call the module itself, not .forward(), so nn.Module hooks run.
# `encoder` holds the encoder's output tensor (cross-attention memory),
# not the module.
encoder = enc(x)

dec = Decoder(input_dim, embedding_dim, output_dim, num_heads)
print("Decoder output dim : ", dec(x, encoder).shape)

Decoder output dim :  torch.Size([8, 10, 128])
