<a href="https://colab.research.google.com/github/MMaggieZhou/llm/blob/main/Multi_Head_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
'''
Implementation choices:
The projections for all heads are concatenated into a single dimension;
the computed Q, K, V are then reshaped so that the head index gets its own dimension
before the attention scores are calculated.
Number of flops ~
O(batch_size * sequence_l * dmodel^2) +
O(batch_size * dmodel * sequence_l^2)
'''

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The per-head projections are fused into single ``nn.Linear`` layers
    (``Q``/``K`` project to ``dk * h`` features, ``V`` to ``dmodel``) and the
    results are split into heads afterwards. The concatenated head outputs
    are returned directly; this module applies no output projection.

    Args:
        dmodel: Size of the input/output feature dimension. Must be
            divisible by ``h``; each head's value size is ``dmodel // h``.
        dk: Per-head query/key size.
        h: Number of attention heads.
    """

    def __init__(self, dmodel, dk, h):
        # nn.Module.__init__ has to run before sub-modules are assigned so
        # that the Linear layers below are registered as parameters.
        super().__init__()
        assert dmodel % h == 0, "Embedding dimension must be divisible by the number of heads"
        self.h = h
        self.dmodel = dmodel
        self.dk = dk
        self.dv = dmodel // h

        self.Q = nn.Linear(dmodel, dk * h)
        self.K = nn.Linear(dmodel, dk * h)
        self.V = nn.Linear(dmodel, dmodel)

    def _reshape(self, t):
        """Split the last dimension of ``t`` into ``h`` heads.

        Maps (batch_size, sequence_l, h * d) to (batch_size, h, sequence_l, d).
        """
        new_shape = t.size()[:-1] + (self.h, t.size(-1) // self.h)
        return t.view(new_shape).permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch_size, sequence_l, dmodel).
            attention_mask: Optional tensor broadcastable to
                (batch_size, h, sequence_l, sequence_l), e.g.
                (batch_size, 1, sequence_l, sequence_l) so all heads share
                one mask. Positions equal to 0 are not attended to.

        Returns:
            A tuple ``(output, scores)``. ``output`` has shape
            (batch_size, sequence_l, dmodel). ``scores`` holds the scaled,
            masked, pre-softmax logits of shape
            (batch_size, h, sequence_l, sequence_l), returned for debugging.
            A query row whose keys are all masked yields an all-zero output
            row instead of NaN.
        """
        # Projections: O(batch_size * sequence_l * dk * h * dmodel) for Q and K,
        # O(batch_size * sequence_l * dmodel^2) for V.
        Q = self._reshape(self.Q(x))  # (batch_size, h, sequence_l, dk)
        K = self._reshape(self.K(x))  # (batch_size, h, sequence_l, dk)
        V = self._reshape(self.V(x))  # (batch_size, h, sequence_l, dv)

        # softmax(Q K^T / sqrt(dk)) V; the score matrix costs
        # O(batch_size * h * dk * sequence_l^2).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)

        if attention_mask is not None:
            blocked = attention_mask == 0
            scores = scores.masked_fill(blocked, float("-inf"))
            # A row with every key masked is all -inf; softmax would return
            # NaN there (and poison gradients). Swap those rows for finite
            # logits before the softmax and zero their weights afterwards.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            logits = scores.masked_fill(dead_rows, 0.0)
            probs = F.softmax(logits, dim=-1).masked_fill(dead_rows, 0.0)
        else:
            probs = F.softmax(scores, dim=-1)

        # O(batch_size * dmodel * sequence_l^2)
        output = torch.matmul(probs, V)  # (batch_size, h, sequence_l, dv)

        # Concatenate heads. permute() yields a non-contiguous tensor, so it
        # must be made contiguous before view() can merge the last two dims.
        output = output.permute(0, 2, 1, 3).contiguous()
        new_shape = output.size()[:-2] + (self.dmodel,)
        return output.view(new_shape), scores

Position-wise Feed-Forward Networks:

O(batch_size * sequence_l * dmodel * dff)

Final output layer:
O(batch_size * sequence_l * dmodel * #tokens)

In [None]:
def generate_attention_mask(embedding_mask, is_decoder):
    """Build an attention mask that hides invalid key positions.

    Invalid positions are padding tokens and, for a decoder, future tokens.

    Args:
        embedding_mask: Tensor of shape (batch_size, seq_length) with 1 for
            real tokens and 0 for padding.
        is_decoder: When true, additionally apply a causal (lower-triangular)
            mask so position ``i`` cannot attend to any ``j > i``.

    Returns:
        Tensor of shape (batch_size, 1, seq_length, seq_length) with the
        dtype and device of ``embedding_mask``; 0 marks a key position that
        must not be attended to.
    """
    batch_size, seq_length = embedding_mask.size()
    # Every query row shares the same key-padding pattern for its sequence.
    mask = embedding_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length)

    if is_decoder:
        # Build the causal matrix on the mask's own device; creating it on
        # the default device breaks masked_fill for GPU inputs.
        allowed = torch.tril(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=embedding_mask.device)
        )
        mask = mask.masked_fill(~allowed, 0)

    return mask

In [None]:
# Demo: run a small decoder-style batch through the attention layer.
batch_size, seq_length = 2, 5
dmodel, h, dk = 64, 8, 10
is_decoder = True

# Construct the layer before sampling the input so the random draws happen
# in the same order as before.
multihead_attn = MultiHeadSelfAttention(dmodel=dmodel, dk=dk, h=h)
x = torch.rand(batch_size, seq_length, dmodel)

# Padding mask: the first sentence has 5 real tokens, the second has 3.
embedding_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ]
)
attention_mask = generate_attention_mask(embedding_mask, is_decoder)

output, scores = multihead_attn(x, attention_mask)
print(output.size())
print(scores)

torch.Size([2, 5, 64])
tensor([[[[ 0.1050,    -inf,    -inf,    -inf,    -inf],
          [ 0.1195,  0.0612,    -inf,    -inf,    -inf],
          [ 0.0789,  0.0508,  0.0496,    -inf,    -inf],
          [-0.0016,  0.0775,  0.0460,  0.0938,    -inf],
          [ 0.0679,  0.0545,  0.0227,  0.0369,  0.0350]],

         [[ 0.1365,    -inf,    -inf,    -inf,    -inf],
          [ 0.0953, -0.0316,    -inf,    -inf,    -inf],
          [ 0.1592,  0.0084,  0.0808,    -inf,    -inf],
          [ 0.0997, -0.0295,  0.0350,  0.0579,    -inf],
          [ 0.1123,  0.0090,  0.0596,  0.0739,  0.0521]],

         [[ 0.0861,    -inf,    -inf,    -inf,    -inf],
          [ 0.1755,  0.0604,    -inf,    -inf,    -inf],
          [ 0.1689,  0.1097,  0.1334,    -inf,    -inf],
          [ 0.1076,  0.0564,  0.0125,  0.0753,    -inf],
          [ 0.0549,  0.0259, -0.0416, -0.0207, -0.1313]],

         [[ 0.0301,    -inf,    -inf,    -inf,    -inf],
          [-0.0230, -0.2002,    -inf,    -inf,    -inf],
  