In [None]:
# Importing packages/libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The input dimension of the model is called ``d_model`` here; PyTorch's own
  implementation calls it ``embed_dim``.

  Args:
    d_model: size of the input and output feature vectors; must be divisible
      by ``num_heads``.
    num_heads: number of attention heads. Each head works on
      ``head_dim = d_model // num_heads`` features.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    if d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads  # features handled by each head
    self.d_model = d_model

    # A single projection computes query, key and value together (one matrix
    # multiplication instead of three); forward() splits the result per head.
    self.linear_qkv = nn.Linear(d_model, 3 * d_model)
    self.linear_output = nn.Linear(d_model, d_model)

  def calculate_weights(self, q, k):
    """Return the scaled attention scores ``q @ k^T / sqrt(head_dim)``.

    The scale is the per-head key size (d_k), not the full model size. This
    keeps the softmax inputs at unit variance regardless of the number of heads.
    """
    att_weights = torch.matmul(q, k.transpose(-1, -2))
    return att_weights / math.sqrt(self.head_dim)

  def forward(self, x, mask=None):
    """Apply self-attention to ``x``.

    Args:
      x: tensor of shape (batch_size, seq_len, d_model).
      mask: optional additive mask broadcastable to
        (batch_size, num_heads, seq_len, seq_len). Use 0 for allowed positions
        and -inf for blocked ones (e.g. the output of ``mask_gen``).

    Returns:
      Tensor of shape (batch_size, seq_len, d_model).
    """
    batch_size, seq_len, _ = x.size()
    qkv = self.linear_qkv(x)  # (batch_size, seq_len, 3 * d_model)
    # -> (batch_size, num_heads, seq_len, head_dim, 3)
    qkv = qkv.view(batch_size, seq_len, self.num_heads, 3, self.head_dim).permute(0, 2, 1, 4, 3)
    q, k, v = qkv.unbind(dim=-1)  # each (batch_size, num_heads, seq_len, head_dim)

    weights = self.calculate_weights(q, k)  # (batch_size, num_heads, seq_len, seq_len)
    if mask is not None:
      # Out-of-place add so any broadcastable mask shape works.
      weights = weights + mask
    weights = F.softmax(weights, dim=-1)

    updated_values = torch.matmul(weights, v)  # (batch_size, num_heads, seq_len, head_dim)
    # Bring the head axis next to head_dim before merging. Reshaping directly
    # from (batch, heads, seq, head_dim) would interleave values belonging to
    # different sequence positions.
    updated_values = updated_values.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

    return self.linear_output(updated_values)
    
def mask_gen(qk):
  """Build a causal (look-ahead) mask shaped like ``qk``.

  In the last two dimensions, entries strictly above the main diagonal are
  -inf and all others are 0. Adding the mask to attention scores therefore
  blocks each position from attending to later positions. Only the shape of
  ``qk`` is used; dtype and device are torch's defaults, not copied from ``qk``.
  """
  blocked = torch.full(qk.shape, float('-inf'))
  return blocked.triu(diagonal=1)

In [None]:
def positional_encoding(d_model, max_seq_len=5000):
    """Return the sinusoidal positional-encoding table of shape (max_seq_len, d_model).

    Column 2i holds sin(pos / 10000**(2i / d_model)) and column 2i+1 holds the
    matching cos, so sin and cos values are interleaved. For an odd
    ``d_model`` the last (cos) column is dropped so the width is exactly
    ``d_model``.
    """
    even_idx = torch.arange(0, d_model, step=2).float()
    denominator = torch.pow(10000, even_idx / d_model)
    positions = torch.arange(0, max_seq_len).reshape(max_seq_len, 1).float()
    angles = positions / denominator  # (max_seq_len, ceil(d_model / 2))
    # Stacking on a new last axis then flattening interleaves sin/cos pairs.
    pe = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(start_dim=1)
    return pe[:, :d_model]

In [None]:
# This class normalizes only along the last dimension, i.e. the embedding dimension.
# It could be made more general by adding a parameter that also computes the mean across other dimensions, such as the batch.

class LayerNormalization(nn.Module):
  """Layer normalization over the last (embedding) dimension.

  Computes ``gammas * (x - mean) / sqrt(var + epsilon) + betas``. The mean and
  the biased (population) variance are taken over the last dimension only.

  Args:
    d_model: size of the last dimension of the input.
    epsilon: small constant added to the variance for numerical stability.
  """
  def __init__(self, d_model, epsilon = 1e-05):
    super().__init__()
    self.d_model = d_model
    self.epsilon = epsilon
    # Learnable scale, initialised to 1 (identity).
    self.gammas = nn.Parameter(torch.ones(d_model))
    # Learnable shift, initialised to 0 so a freshly built layer outputs
    # zero-mean, unit-variance features (same as torch.nn.LayerNorm).
    self.betas = nn.Parameter(torch.zeros(d_model))

  def forward(self, input_tensor):
    """Normalize ``input_tensor`` along its last dimension and apply scale/shift."""
    mean = input_tensor.mean(dim=-1, keepdim=True)
    centered = input_tensor - mean
    variance = (centered ** 2).mean(dim=-1, keepdim=True)
    normalized = centered / torch.sqrt(variance + self.epsilon)
    return self.gammas * normalized + self.betas


In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output feature size.
    hidden: size of the inner layer.
    dropout_prob: dropout probability applied after the ReLU.
  """
  def __init__(self, d_model, hidden, dropout_prob =0.1):
    super().__init__()
    self.d_model, self.hidden = d_model, hidden
    self.linear1 = nn.Linear(d_model, hidden)  # d_model -> hidden
    self.linear2 = nn.Linear(hidden, d_model)  # hidden -> d_model
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x):
    """Map features of size d_model through the hidden layer and back."""
    activated = self.relu(self.linear1(x))
    return self.linear2(self.dropout(activated))
  

In [None]:
class EncoderBlock(nn.Module):
  """One post-norm Transformer encoder layer.

  The computation is
  ``y = norm1(x + dropout1(attention(x)))`` followed by
  ``out = norm2(y + dropout2(feed_forward(y)))``.

  Args:
    d_model: feature size of the input and output.
    hidden: inner size of the feed-forward sub-layer.
    num_heads: number of attention heads.
    dropout_prob: dropout probability for each sub-layer output; also passed
      to the feed-forward network.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob):
    super().__init__()
    self.mul_head_att = MultiHeadAttention(d_model, num_heads)

    self.norm1 = LayerNormalization(d_model)
    self.norm2 = LayerNormalization(d_model)

    self.ff_layers = FeedForward(d_model, hidden, dropout_prob)

    self.dropout1 = nn.Dropout(p=dropout_prob)
    self.dropout2 = nn.Dropout(p=dropout_prob)

  def forward(self, x, mask=None):
    """Apply self-attention and feed-forward sub-layers with residuals.

    Args:
      x: input features; presumably (batch_size, seq_len, d_model) as
        expected by MultiHeadAttention.forward — TODO confirm with callers.
      mask: optional additive attention mask passed to the self-attention
        sub-layer (e.g. for padding); None (default) means no masking.
    """
    # The residual inputs need no .clone(): the sub-layers build new tensors
    # and do not modify x in place.
    attended = self.mul_head_att(x, mask=mask)
    x = self.norm1(x + self.dropout1(attended))
    transformed = self.ff_layers(x)
    x = self.norm2(x + self.dropout2(transformed))
    return x



In [None]:
class Encoder(nn.Module):
  """A stack of ``num_layers`` EncoderBlocks applied one after another.

  Args:
    d_model: feature size of the input and output.
    hidden: inner size of each block's feed-forward sub-layer.
    num_heads: number of attention heads per block.
    dropout_prob: dropout probability used inside each block.
    num_layers: number of stacked EncoderBlocks.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob, num_layers):
    super().__init__()
    blocks = [EncoderBlock(d_model, hidden, num_heads, dropout_prob)
              for _ in range(num_layers)]
    self.layers = nn.ModuleList(blocks)

  def forward(self, x):
    """Feed ``x`` through every block in order and return the final output."""
    hidden_state = x
    for block in self.layers:
      hidden_state = block(hidden_state)
    return hidden_state
