In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def scaled_dot_product(q, k, v, mask = None) :
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

  Args:
    q, k, v: tensors whose last two dims are (seq_len, head_dim); leading
      dims (batch, heads, ...) must be compatible for ``torch.matmul``.
    mask: optional attention mask (typically only used in the decoder):
      - ``None``: no masking.
      - float ``torch.Tensor``: additive mask (e.g. 0 / -inf) broadcast onto
        the score matrix of shape (..., seq_q, seq_k).
      - bool ``torch.Tensor``: ``True`` marks positions that may be attended
        to; ``False`` positions get -inf before the softmax.
      - any other non-None value (e.g. ``True``): a causal mask hiding future
        positions (strict upper triangle) is built internally.

  Returns:
    (values, attention): ``attention`` is the softmax over the last dim,
    shape (..., seq_q, seq_k); ``values`` is ``attention @ v``.
  """
  d_k = q.size(-1)
  # transpose only the last two dims of k so leading batch/head dims are kept
  scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
  if mask is not None :
    if isinstance(mask, torch.Tensor) :
      if mask.dtype == torch.bool :
        scaled = scaled.masked_fill(~mask, float('-inf'))
      else :
        scaled = scaled + mask.to(dtype=scaled.dtype, device=scaled.device)
    else :
      # build the causal mask on the scores' device/dtype so GPU / half precision work
      causal = torch.full(scaled.shape[-2:], float('-inf'), dtype=scaled.dtype, device=scaled.device)
      scaled = scaled + torch.triu(causal, diagonal=1)
  # softmax over the key axis (last dim)
  attention = F.softmax(scaled, dim=-1)
  values = torch.matmul(attention, v)

  return values, attention


In [3]:
class Multi_Head_Attention(nn.Module) :
  """Multi-head self-attention with a single fused Q/K/V projection.

  Args:
    input_dim: size of the last dim of the input ``x``.
    d_model: total attention width, split evenly across the heads.
    num_heads: number of attention heads; must divide ``d_model``.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """
  def __init__(self, input_dim, d_model, num_heads) :
    super().__init__()
    if d_model % num_heads != 0 :
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.input_dim = input_dim
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    # one projection produces q, k and v for every head at once
    self.qkv_layer = nn.Linear(input_dim, 3*d_model)
    self.linear_layer = nn.Linear(d_model, d_model)

  def forward(self, x, mask = None) :
    """Apply self-attention to ``x`` of shape (batch, seq_len, input_dim).

    ``mask`` is passed through unchanged to ``scaled_dot_product``.
    Returns a tensor of shape (batch, seq_len, d_model).
    """
    batch_size, seq_len, _ = x.size()
    qkv = self.qkv_layer(x)                                        # (B, S, 3*d_model)
    qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3*self.head_dim)
    qkv = qkv.permute(0, 2, 1, 3)                                  # (B, H, S, 3*head_dim)
    q, k, v = qkv.chunk(3, dim=-1)                                 # each (B, H, S, head_dim)
    values, attention = scaled_dot_product(q, k, v, mask)          # values: (B, H, S, head_dim)
    # move the head axis back next to head_dim before merging heads; reshaping
    # (B, H, S, head_dim) directly would interleave heads with sequence positions
    values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
    out = self.linear_layer(values)
    return out


In [5]:
# Demo: one forward pass of the attention layer on random input.
torch.manual_seed(0)  # reproducible input and weight init across Restart & Run All

input_dim = 1024
d_model = 512
num_heads = 8
batch_size = 30
seq_len = 5

x = torch.randn( (batch_size, seq_len, input_dim) )

model = Multi_Head_Attention(input_dim, d_model, num_heads)
# call the module itself (not .forward) so registered nn.Module hooks run
out = model(x)
assert out.shape == (batch_size, seq_len, d_model)

In [6]:
# expected: torch.Size([batch_size, seq_len, d_model])
out.size()

torch.Size([30, 5, 512])