<a href="https://colab.research.google.com/github/MMaggieZhou/llm/blob/main/Multi_Head_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [87]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [88]:
'''
TODO: make parameters consistent with the paper
'''
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, k_dim, num_heads):
      super().__init__() #?
      assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
      self.num_heads = num_heads
      self.embed_dim = embed_dim
      self.k_dim = k_dim
      self.v_dim = self.embed_dim // self.num_heads

      self.query = nn.Linear(embed_dim, k_dim * num_heads)
      self.key = nn.Linear(embed_dim, k_dim * num_heads)
      self.value = nn.Linear(embed_dim, embed_dim)

    # break q, k, v into heads
    def _reshape(self, t):
        new_shape = t.size()[:-1] + (self.num_heads, t.size()[-1] // self.num_heads) # TODO: size() or shape?
        t = t.view(new_shape) # (batch_size, sequence_l, num_heads, k_dim or v_dim)
        return t.permute(0,2,1,3) # (batch_size, num_heads, sequence_l, k_dim or v_dim)

    def forward(
        self,
        x,
        attention_mask=None, # all encoders share of a batch share the same mask and same applies to decoders
      ):

      Q = self._reshape(self.query(x)) # (batch_size, num_heads, sequence_l, k_dim)
      K = self._reshape(self.key(x)) # (batch_size, num_heads, sequence_l, k_dim)
      V = self._reshape(self.value(x)) # (batch_size, num_heads, sequence_l, v_dim)

      # softmax(QK/v_dim-2)V
      scores = torch.matmul(Q, K.permute(0,1,3,2)) / math.sqrt(self.k_dim) #(batch_size, num_heads, sequence_l, sequence_l)
      if attention_mask != None:  # (batch_size, 1, sequence_l, sequence_l)
          print(scores.size())
          print(attention_mask.size())
          scores = scores.masked_fill(attention_mask == 0, float('-inf'))
      probs = F.softmax(scores, dim=-1)
      output = torch.matmul(probs, V) # (batch_size, num_heads, sequence_l, v_dim)

      # concat
      output = output.permute(0,2,1,3).contiguous() #? error view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
      new_shape = output.size()[:-2] + (self.embed_dim,)
      return output.view(new_shape), scores # (batch_size, sequene_l, embed_dim)

In [85]:
# to prevent attending to invalid tokens, aka padding tokens or future token;
def generate_attention_mask(embedding_mask, is_decoder):
    batch_size, seq_length = embedding_mask.size()
    mask = embedding_mask.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, seq_length, seq_length)

    if is_decoder:
        tril = torch.tril(torch.ones(seq_length, seq_length))  # 0s where j > i, 1s elsewhere
        mask = mask.masked_fill(tril==0, 0)

    return mask

In [86]:
batch_size = 2
seq_length = 5
embed_dim = 64
num_heads = 8
k_dim = 10
is_decoder = True

multihead_attn = MultiHeadSelfAttention(embed_dim, k_dim, num_heads)
x = torch.rand(batch_size, seq_length, embed_dim)
# the first sentence has 5 tokens and second sentence has 3 tokens
embedding_mask = torch.tensor([[1,1,1,1,1],[1,1,1,0,0]])
attention_mask = generate_attention_mask(embedding_mask, is_decoder)
output, scores = multihead_attn(x, attention_mask)
print(output.size())
print(scores)

torch.Size([2, 8, 5, 5])
torch.Size([2, 1, 5, 5])
torch.Size([2, 5, 64])
tensor([[[[ 1.4343e-02,        -inf,        -inf,        -inf,        -inf],
          [-1.2957e-02,  1.2561e-02,        -inf,        -inf,        -inf],
          [-3.0356e-03,  7.3240e-02, -1.8429e-02,        -inf,        -inf],
          [ 6.8252e-02,  1.3929e-01,  1.2465e-01,  2.2523e-01,        -inf],
          [ 7.6309e-02,  1.2280e-01,  9.6493e-02,  1.7525e-01,  7.9725e-02]],

         [[-2.4399e-03,        -inf,        -inf,        -inf,        -inf],
          [-1.2792e-03, -1.2840e-03,        -inf,        -inf,        -inf],
          [-4.1264e-02, -3.5341e-02,  7.2489e-02,        -inf,        -inf],
          [-1.2637e-01, -8.6460e-02, -3.6553e-02,  6.5681e-02,        -inf],
          [-3.1637e-02, -1.7299e-02,  4.9774e-02,  8.0568e-02, -3.4710e-02]],

         [[ 5.3687e-02,        -inf,        -inf,        -inf,        -inf],
          [ 9.9212e-02, -1.0597e-02,        -inf,        -inf,        -inf],