# Transformer Implementation Using PyTorch

In [None]:
# Author: Rishabh Agarwal
# All the imports for the code

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


In [None]:
# Feed Forward Network
# Used in the Transformer Block
class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Both linear layers act on the last dimension of the input, so the
    output has the same shape as the input.

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Feature size of the hidden layer between the two linear maps.
    """

    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        # Expansion layer: d_model -> d_ff.
        self.linear1 = nn.Linear(d_model, d_ff)
        # Contraction layer: d_ff -> d_model.
        self.linear2 = nn.Linear(d_ff, d_model)
        # Non-linearity applied between the two linear maps.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``linear2(relu(linear1(x)))``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

In [None]:
# Multiheaded attention class for the transformer
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_k, d_model):
        super().__init__()
        self.h = h
        self.d_k = d_k # d_k = d_v
        self.d_model = d_model

        assert d_model % h == 0  # Assert that the number of heads divides the model dimension
        assert d_k == d_model // h # Assert that the key and value dimensions are equal to d_model // h

        self.W_Q = nn.Linear(d_model, d_k)
        self.W_K = nn.Linear(d_model, d_k)
        self.W_V = nn.Linear(d_model, d_k)

        self.W_O = nn.Linear(h*d_k, d_model)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, seq_len, d_model)
        # mask: (batch_size, seq_len, seq_len)

        batch_size = Q.size(0)

        # Linearly project the queries, keys, and values
        Q = self.W_Q(Q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)