# Transformer implementation Using Pytorch

## Below is the Todo-list
Will be updated as we go about the code

In [2]:
#TODO: look into torchtext
#TODO: get the data
#TODO: pre-process the data - tokenize it as well
#TODO: implement the training loop
#TODO: Train the model

## Implementation Starts here

In [1]:
# Author: Rishabh Agarwal
# All the imports for the code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader

import torchtext


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
# Feed Forward Network
# Used in the Transformer Block
class FFN(nn.Module):
    def __init__(self, device, d_model=512):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model * 4, device=device)
        self.linear2 = nn.Linear(d_model * 4, d_model, device=device)
        self.gelu = nn.GELU()
        # One reason of choosing GELU - not in the initial implementation of the transformer in the paper
        """
        relu can suffer from "problems where significant amount of neuron in the network become zero and don’t practically do anything." 
        gelu is smoother near zero and "is differentiable in all ranges, and allows to have gradients(although small) in negative range" which helps with this problem.
        """

        self.device = device

    def forward(self, x):
        x = x.to(self.device)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        return x 

In [5]:
# Multiheaded attention class for the transformer
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_k, d_model, device, dropout=0.1):
        super().__init__()
        self.h = h
        self.d_k = d_k # d_k = d_v
        self.d_model = d_model
        self.dropout = dropout

        self.device = device


        assert d_model % h == 0  # Assert that the number of heads divides the model dimension
        assert d_k == d_model // h # Assert that the key and value dimensions are equal to d_model // h

        self.W_Q = nn.Linear(d_model, d_k, device = device)
        self.W_K = nn.Linear(d_model, d_k, device = device)
        self.W_V = nn.Linear(d_model, d_k, device = device)

        self.W_O = nn.Linear(h*d_k, d_model, device = device)


    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, seq_len, d_model)
        # mask: (batch_size, seq_len, seq_len)

        batch_size = Q.size(0)

        concatenated_heads = torch.tensor([], device=self.device)

        # The multiple heads being computed to be concatenated together and then applied to the output layer
        # can be done in parallel
        for _ in range(self.h):
            # Linearly project the queries
            computed_Q = self.W_Q(Q)
            # Linearly project the keys
            computed_K = self.W_K(K)
            # Linearly project the queries, keys, and values
            computed_V = self.W_V(V)

            # Calculate the attention scores
            head_i = F.scaled_dot_product_attention(computed_Q, computed_K, computed_V, mask, dropout_p=self.dropout)
            # print(head_i.shape)
            concatenated_heads = torch.cat((concatenated_heads, head_i), dim=-1)
            
        return self.W_O(concatenated_heads)

In [6]:
# Test the MultiHeadedAttention class
def test_multi_headed_attention(device):
    h = 8
    d_k = 64
    d_model = 512
    batch_size = 32

    mha = MultiHeadedAttention(h, d_k, d_model,device)

    Q = torch.randn(batch_size, 10, d_model, device=device)
    K = torch.randn(batch_size, 10, d_model, device=device)
    V = torch.randn(batch_size, 10, d_model, device=device)

    out = mha(Q, K, V)

    assert out.size() == (batch_size, 10, d_model)

    print("MultiHeadedAttention test passed")
test_multi_headed_attention(device)

MultiHeadedAttention test passed


  head_i = F.scaled_dot_product_attention(computed_Q, computed_K, computed_V, mask, dropout_p=self.dropout)


In [None]:
torch.__version__



In [7]:
# One of N Encoder Blocks in the Transformer
class Encoder(nn.Module):
    def __init__(self, h, d_k, d_model, device, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadedAttention(h, d_k, d_model, device, dropout)
        self.layernorm1 = nn.LayerNorm(d_model, device=device)

        self.ffn = FFN(d_model=d_model, device=device)
        self.layernorm2 = nn.LayerNorm(d_model, device =device)

        self.device = device
    def forward(self, input_emb):
        input_emb = input_emb.to(self.device)
        x = self.mha(input_emb, input_emb, input_emb)
        x = self.layernorm1(x + input_emb)
        ffn_out = self.ffn(x).to(self.device)
        x = self.layernorm2(x + ffn_out)
        return x

In [8]:
# Test the encoder block
def test_encoder_block(device):
    h = 8
    d_k = 64
    d_model = 512
    batch_size = 32

    encoder_block = Encoder(h, d_k, d_model, device)

    input_emb = torch.randn(batch_size, 10, d_model, device=device)

    out = encoder_block(input_emb)

    assert out.size() == (batch_size, 10, d_model)

    print("Encoder block test passed")
test_encoder_block(device)

Encoder block test passed


In [9]:
# One of N Decoders' Implementation in the Transformer
class Decoder(nn.Module):
    def __init__(self,h, d_k, d_model, device, dropout=0.1 ):
        super().__init__()
        self.masked_mha = MultiHeadedAttention(h, d_k, d_model, device, dropout)
        self.mha = MultiHeadedAttention(h, d_k, d_model, device, dropout)

        self.ffn = FFN(d_model =d_model, device=device)

        self.layernorm1 = nn.LayerNorm(d_model, device=device)
        self.layernorm2 = nn.LayerNorm(d_model, device=device)
        self.layernorm3 = nn.LayerNorm(d_model, device=device)

        self.device = device
        
    def mask(self, x):
        """
        Mask the output embeddings
        """
        batch, seq, emb = x.shape

        mask = torch.tril(torch.ones(batch,seq, seq), diagonal=1) == 0
        mask = mask.masked_fill(mask == 0, float('-inf'))
        mask = mask.to(self.device)
        return mask
    
    def forward(self, output_emb, residual):
        """
        Residual connection comes from encoder
        """

        assert output_emb.shape == residual.shape

        mask = self.mask(output_emb)
        x = self.masked_mha(output_emb, output_emb, output_emb, mask).to(self.device)
        x = self.layernorm1(output_emb + x).to(self.device)

        y = self.mha(residual, residual, x).to(self.device)
        y = self.layernorm2(x + y).to(self.device)

        ffn_out = self.ffn(y).to(self.device)
        output = self.layernorm3(y + ffn_out).to(self.device)

        return output

In [10]:
# Test the decoder block
def test_decoder_block(device):
    h = 8
    d_k = 64
    d_model = 512
    batch_size = 32

    decoder_block = Decoder(h, d_k, d_model, device)

    output_emb = torch.randn(batch_size, 10, d_model, device=device)
    residual = torch.randn(batch_size, 10, d_model, device=device)

    out = decoder_block(output_emb, residual)

    assert out.size() == (batch_size, 10, d_model)

    print("Decoder block test passed")
test_decoder_block(device)


Decoder block test passed


In [18]:
class Transformer(nn.Module):
    def __init__(self, n, h, d_k, d_model, device, dropout=0.1):
        super().__init__()
        self.n = n
        self.encoder = nn.ModuleList([Encoder(h, d_k, d_model, device, dropout) for _ in range(n)])
        self.decoder = nn.ModuleList([Decoder(h, d_k, d_model, device, dropout) for _ in range(n)])
        self.final = nn.Sequential(
            nn.Linear(d_model, d_model, device=device),
            nn.Softmax(dim =-1),
        )
    def forward(self, input_emb, output_emb):
        for i in range(self.n):
            input_emb = self.encoder[i](input_emb)
            output_emb = self.decoder[i](output_emb, input_emb)
        out = self.final(output_emb)
        return out  

In [19]:
# test transformer block
def test_transformer(device):
    n = 6
    h = 8
    d_k = 64
    d_model = d_k*h


    transformer = Transformer(n, h, d_k, d_model, device)

    sample_input = torch.randn(32, 10, d_model, device=device) # positional encoding not added yet 
    sample_output = torch.randn(32, 10, d_model, device=device) # positional encoding not added yet 

    out = transformer(sample_input, sample_output)

    assert out.size() == (32, 10, d_model)

    print("Transformer test passed")

test_transformer(device)

Transformer test passed
