# Parallel Multi Head Self-Attention

In [1]:
import torch
import torch.nn as nn

In [67]:
class MHSAOptimized(nn.Module):
    def __init__(self, d_in, d_out, context_len, n_heads=2, causal=True, dropout=0.1, qkv_bias=False):
        super().__init__()
        assert (d_out % n_heads == 0), 'd_out must be a multiple of n_heads'
        # self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.causal = causal
        # self.dropout = dropout
        self.head_dim = d_out // n_heads
        
        # k, q, v
        self.query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # liner projection, not a nesseary projection, can be skipped
        self.out_proj = nn.Linear(d_out, d_out)
        
        # dropout
        self.dropout = nn.Dropout(dropout)
        
        if self.causal:
            self.register_buffer(
                "mask",
                torch.triu(torch.ones(context_len, context_len),
                        diagonal=1)
            )
            
    def forward(self, x):
        b, n_tokens, d_in = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # reshaping for n_heads attention, and transpose
        # (b, n_tokens, self.n_heads, self.head_dim) -> (b, self.n_heads, n_tokens, self.head_dim)
        q = q.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        
        att_score = q @ k.transpose(2, 3) # out -> (b, self.n_heads, n_tokens, n_tokens)
        
        # MASKING future attention score, replace with inf that will be changed to zero by softmax
        if self.causal:
            att_score.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens], float('-inf'))
        att_weights = torch.softmax(att_score/k.shape[-1]**0.5, dim=-1)
        
        # dropout
        att_weights = self.dropout(att_weights)
        
        # context vectors
        # -> (b, self.n_heads, n_tokens, self.head_dim) -> (b, n_tokens, self.n_heads, self.head_dim)
        context_vec = (att_score @ v).transpose(1, 2)
        # print(context_vec)   
        
        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)
        
        # liner projection
        context_vec = self.out_proj(context_vec)
        
        return context_vec

In [68]:
# input
torch.manual_seed(124)
inputs = torch.randn(3,2)
batch = torch.stack((inputs, inputs), dim=0)
batch, inputs.shape

(tensor([[[ 0.2922,  1.5814],
          [ 0.9303,  0.6592],
          [ 0.3796, -0.3670]],
 
         [[ 0.2922,  1.5814],
          [ 0.9303,  0.6592],
          [ 0.3796, -0.3670]]]),
 torch.Size([3, 2]))

In [73]:
context_len = batch.shape[1]
print('context_length', context_len)
d_in = inputs.shape[1]
d_out = 2

# causal will lead to nan and -inf with small dimension inputs
mhsa = MHSAOptimized(d_in=d_in, d_out=d_out, context_len=context_len, n_heads=2, 
                    causal=False, dropout=0.2, qkv_bias=False)
mhsa

context_length 3


MHSAOptimized(
  (query): Linear(in_features=2, out_features=2, bias=False)
  (key): Linear(in_features=2, out_features=2, bias=False)
  (value): Linear(in_features=2, out_features=2, bias=False)
  (out_proj): Linear(in_features=2, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [74]:
with torch.no_grad():
    context = mhsa.forward(batch)
print(f"Context Matrix after multi head Self-Attention:\n", context)

Context Matrix after multi head Self-Attention:
 tensor([[[ 0.9573, -0.0071],
         [ 0.9248, -0.0426],
         [ 0.4983, -0.5417]],

        [[ 0.9573, -0.0071],
         [ 0.9248, -0.0426],
         [ 0.4983, -0.5417]]])
