# Transformer Implementation Using PyTorch

In [None]:
#TODO: Add dropout and fix the mask code
#TODO: implement the model
#TODO: Train the model

In [38]:
# Author: Rishabh Agarwal
# All the imports for the code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader


In [5]:
# Feed Forward Network
# Used in the Transformer Block
class FFN(nn.Module):
    """Position-wise feed-forward network used inside the Transformer block.

    The last dimension of the input is expanded from ``d_model`` to
    ``4 * d_model``, passed through a ReLU, and projected back to ``d_model``.

    Args:
        d_model: Size of the input and output feature dimension.
    """

    def __init__(self, d_model=512):
        super().__init__()
        hidden_dim = 4 * d_model  # inner layer is four times the model width
        self.linear1 = nn.Linear(d_model, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply expand -> ReLU -> project to ``x`` and return the result."""
        return self.linear2(self.relu(self.linear1(x)))

In [42]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # bare expression so the notebook cell displays the chosen device

'cuda'

In [48]:
# Multiheaded attention class for the transformer
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected once to ``h * d_k`` features,
    split into ``h`` heads of width ``d_k``, attended independently in a single
    batched call, concatenated, and passed through the output projection.

    Every head therefore owns its own slice of the projection weights, so the
    heads are genuinely different (a single shared ``d_model -> d_k``
    projection would make all heads identical copies of one another).

    Args:
        h: Number of attention heads. Must divide ``d_model``.
        d_k: Per-head key/query/value width. Must equal ``d_model // h``.
        d_model: Model (embedding) dimension.
        device: Device on which the projection layers are created.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode. No dropout is applied in eval mode.

    Raises:
        AssertionError: If ``h`` does not divide ``d_model`` or
            ``d_k != d_model // h``.
    """

    def __init__(self, h, d_k, d_model, device, dropout=0.1):
        super().__init__()
        assert d_model % h == 0, "the number of heads must divide d_model"
        assert d_k == d_model // h, "d_k must equal d_model // h"

        self.h = h
        self.d_k = d_k  # d_k = d_v
        self.d_model = d_model
        self.device = device
        self.dropout_p = dropout

        # One fused projection per role; each head uses its own d_k-wide slice.
        self.W_Q = nn.Linear(d_model, h * d_k, device=device)
        self.W_K = nn.Linear(d_model, h * d_k, device=device)
        self.W_V = nn.Linear(d_model, h * d_k, device=device)

        self.W_O = nn.Linear(h * d_k, d_model, device=device)

    def _split_heads(self, x):
        """Reshape (batch, seq_len, h * d_k) into (batch, h, seq_len, d_k)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Compute multi-head attention.

        Args:
            Q: Queries of shape (batch_size, q_len, d_model).
            K: Keys of shape (batch_size, kv_len, d_model).
            V: Values of shape (batch_size, kv_len, d_model).
            mask: Optional attention mask, passed to
                ``F.scaled_dot_product_attention`` as ``attn_mask`` (boolean
                masks mark positions that may be attended to; float masks are
                added to the attention scores). A 3-D mask of shape
                (batch_size, q_len, kv_len) is broadcast across all heads.

        Returns:
            Tensor of shape (batch_size, q_len, d_model).
        """
        batch_size = Q.size(0)

        queries = self._split_heads(self.W_Q(Q))
        keys = self._split_heads(self.W_K(K))
        values = self._split_heads(self.W_V(V))

        # A per-batch mask has no head axis; insert one so it broadcasts
        # against the (batch, h, q_len, kv_len) attention scores.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        # All heads are evaluated in parallel by a single batched call.
        heads = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            dropout_p=self.dropout_p if self.training else 0.0,
        )

        # (batch, h, q_len, d_k) -> (batch, q_len, h * d_k): concatenate heads.
        concatenated_heads = heads.transpose(1, 2).reshape(
            batch_size, -1, self.h * self.d_k
        )
        return self.W_O(concatenated_heads)

In [49]:
# Test the MultiHeadedAttention class
def test_multi_headed_attention(device):
    """Smoke-test MultiHeadedAttention: the output must keep the input shape.

    Builds an 8-head layer with d_model=512, feeds it random query, key and
    value tensors created on ``device``, and asserts on the output size.
    """
    num_heads, head_dim, model_dim = 8, 64, 512
    n_batch, n_tokens = 32, 10

    attention = MultiHeadedAttention(num_heads, head_dim, model_dim, device)

    # Draw the three inputs in Q, K, V order.
    queries, keys, values = (
        torch.randn(n_batch, n_tokens, model_dim, device=device)
        for _ in range(3)
    )

    result = attention(queries, keys, values)

    assert result.size() == (n_batch, n_tokens, model_dim)

    print("MultiHeadedAttention test passed")


test_multi_headed_attention(device)

MultiHeadedAttention test passed
