<a href="https://colab.research.google.com/github/AkHiLdEvGoD/DeepLearning-Algorithms/blob/main/MultiHeadAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [9]:
# Toy hyper-parameters for the demo below:
#   embd_dims  - feature size of each token vector (last dim of the input)
#   batch_size - number of sequences in the random input batch
#   seq_len    - number of tokens per sequence
embd_dims, batch_size, seq_len = 32, 4, 8

In [25]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three
    ``embd_dims -> embd_dims`` linear layers and split into ``attn_heads``
    heads of size ``embd_dims // attn_heads``. Scaled dot-product attention
    is applied independently per head. The heads are then concatenated back
    to ``embd_dims`` features and passed through a final linear projection.

    Args:
        embd_dims: Size of the last (feature) dimension of the input and
            output. Must be divisible by ``attn_heads``.
        attn_heads: Number of attention heads; must be a positive integer.

    Raises:
        ValueError: If ``attn_heads`` is not positive or does not evenly
            divide ``embd_dims``.
    """

    def __init__(self, embd_dims, attn_heads):
        super().__init__()

        # Fail fast with a clear message. Previously a bad combination was
        # accepted here and only surfaced later in forward(): as an opaque
        # ``view`` shape error, or as a ZeroDivisionError for zero heads.
        if attn_heads <= 0:
            raise ValueError(f"attn_heads must be positive, got {attn_heads}")
        if embd_dims % attn_heads != 0:
            raise ValueError(
                f"embd_dims ({embd_dims}) must be divisible by "
                f"attn_heads ({attn_heads})"
            )

        self.H = attn_heads

        # Attribute names and registration order (Wk, Wq, Wv, out_proj) are
        # kept so that state_dict keys match existing checkpoints.
        self.Wk = nn.Linear(embd_dims, embd_dims)
        self.Wq = nn.Linear(embd_dims, embd_dims)
        self.Wv = nn.Linear(embd_dims, embd_dims)

        self.out_proj = nn.Linear(embd_dims, embd_dims)

    def forward(self, X):
        """Apply multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape ``(B, S, embd_dims)``.

        Returns:
            Tensor of shape ``(B, S, embd_dims)``.
        """
        B, S, D = X.shape
        H = self.H
        Dh = D // H  # per-head feature size; exact thanks to the __init__ check

        def split_heads(t):
            # (B, S, D) -> (B, S, H, Dh) -> (B, H, S, Dh)
            return t.view(B, S, H, Dh).transpose(1, 2)

        Q = split_heads(self.Wq(X))
        K = split_heads(self.Wk(X))
        V = split_heads(self.Wv(X))

        # (B, H, S, S) similarity scores. Dividing by sqrt(Dh) keeps the
        # softmax inputs from growing with the head size.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, H, S, Dh)

        # Merge heads: (B, H, S, Dh) -> (B, S, H, Dh) -> (B, S, D).
        # transpose() yields a non-contiguous tensor, so view() needs
        # contiguous() first.
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, D)

        return self.out_proj(attn_output)

In [26]:
# Smoke test: push a random (batch_size, seq_len, embd_dims) batch through a
# 4-head attention layer and show that the output keeps the input's shape.
X = torch.randn(batch_size, seq_len, embd_dims)
mha = MultiHeadAttention(embd_dims, 4)
outs = mha(X)
print(X.shape, outs.shape, sep="\n")

torch.Size([4, 8, 32])
torch.Size([4, 8, 32])
