In [None]:
# Importing packages/libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  A single linear layer projects the input to query, key and value for all
  heads at once (one matrix multiplication instead of three). The projection
  is split per head, attention is computed independently in each head, and
  the per-head results are concatenated and passed through an output
  projection.

  Args:
    input_dim: size of the last dimension of the input ``x``.
    d_model: total size of the query/key/value projections (summed over all
      heads) and of the output. Must be divisible by ``num_heads``.
    num_heads: number of attention heads.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, input_dim, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      # Otherwise the per-head reshape in forward() cannot succeed.
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.input_dim = input_dim
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads  # per-head size of q, k and v
    self.d_model = d_model  # total size of q, k and v across all heads

    # One layer computes q, k and v for every head; they are split afterwards.
    self.linear_qkv = nn.Linear(self.input_dim, 3 * d_model)
    self.linear_output = nn.Linear(d_model, d_model)

  def calculate_weights(self, q, k):
    """Return scaled attention scores ``q @ k^T / sqrt(head_dim)``.

    The scale uses the per-head dimension, because that is the length of the
    vectors being dotted together (``d_k`` in "Attention Is All You Need").
    """
    scores = torch.matmul(q, k.transpose(-1, -2))
    return scores / math.sqrt(self.head_dim)

  def forward(self, x, mask=None):
    """Apply multi-head self-attention to ``x``.

    Args:
      x: tensor of shape (batch_size, seq_len, input_dim).
      mask: optional additive mask, broadcastable to
        (batch_size, num_heads, seq_len, seq_len). Use ``-inf`` where
        attention is forbidden and 0 elsewhere.

    Returns:
      Tensor of shape (batch_size, seq_len, d_model).
    """
    batch_size, seq_len, _ = x.size()
    qkv = self.linear_qkv(x)  # (batch, seq, 3 * d_model)
    qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
    qkv = qkv.permute(0, 2, 1, 3)  # (batch, heads, seq, 3 * head_dim)
    q, k, v = torch.split(qkv, self.head_dim, dim=-1)
    weights = self.calculate_weights(q, k)  # (batch, heads, seq, seq)
    if mask is not None:
      # Out-of-place add, so a mask that broadcasts to a larger shape still works.
      weights = weights + mask
    weights = F.softmax(weights, dim=-1)
    updated_values = torch.matmul(weights, v)  # (batch, heads, seq, head_dim)
    # Bring heads back next to head_dim before merging them. Each position's
    # row then holds the concatenation of that position's per-head outputs.
    updated_values = updated_values.permute(0, 2, 1, 3).reshape(
        batch_size, seq_len, self.d_model)
    return self.linear_output(updated_values)
    
def mask_gen(qk):
  """Build an additive causal (look-ahead) mask with the same shape as ``qk``.

  Entries strictly above the diagonal of the last two dimensions are ``-inf``
  and all other entries are 0. Adding the mask to attention scores before the
  softmax stops each position from attending to later positions.

  The mask is created on ``qk``'s device. It uses ``qk``'s dtype when that is
  floating point, and the default float dtype otherwise. This lets it be added
  to ``qk`` without a device error or dtype promotion.
  """
  dtype = qk.dtype if qk.is_floating_point() else torch.get_default_dtype()
  mask = torch.full(qk.size(), float('-inf'), dtype=dtype, device=qk.device)
  return torch.triu(mask, diagonal=1)

In [None]:
def positional_encoding(d_model, max_seq_len=5000):
    """Sinusoidal positional encodings ("Attention Is All You Need").

    Column ``2i`` holds ``sin(pos / 10000**(2i / d_model))`` and column
    ``2i + 1`` holds ``cos(pos / 10000**(2i / d_model))``.

    Args:
        d_model: number of encoding features per position. Odd values are
            supported; the final (cosine) column is dropped.
        max_seq_len: number of positions (rows) to generate.

    Returns:
        Float tensor of shape (max_seq_len, d_model).
    """
    even_idx = torch.arange(0, d_model, step=2).float()  # the 2i values
    denominator = torch.pow(10000, even_idx / d_model)
    positions = torch.arange(0, max_seq_len).float().reshape(max_seq_len, 1)
    angles = positions / denominator  # (max_seq_len, ceil(d_model / 2))
    # Interleave sin/cos: stack on a trailing axis, then flatten each pair.
    pe = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(start_dim=1)
    # For odd d_model the interleave produces one column too many; trim it.
    return pe[:, :d_model]

# Small example table: encodings for 10 positions with 6 features each.
pe = positional_encoding(max_seq_len=10, d_model=6)


