<a href="https://colab.research.google.com/github/Dipak22/Case-Studies/blob/master/Attention_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy input shared by the next cells: 4 sequences x 64 tokens x 128 features.
batch_size, seq_len, embed_dim = 4, 64, 128

x = torch.randn((batch_size, seq_len, embed_dim))
print("X shape: ", x.shape)

X shape:  torch.Size([4, 64, 128])


In [13]:
# Temporary implementation: parameter-free attention where the input itself
# plays the role of queries, keys and values.

# Dot product between every pair of tokens -> (batch, seq, seq).
similarity = torch.matmul(x, x.transpose(1, 2))
print("similarilty shape: ", similarity.shape)
print("Pre normalized variance:", similarity.var())

# Scaling the logits by 1/sqrt(embed_dim) divides their variance by embed_dim.
similarity_norm = similarity / embed_dim ** 0.5
print("Normalized variance: ", similarity_norm.var())

# Row-wise softmax over the key axis, then a weighted sum of the tokens.
attn = torch.softmax(similarity_norm, dim=-1)
context_vectors = torch.matmul(attn, x)
print("output shape: ", context_vectors.shape)

similarilty shape:  torch.Size([4, 64, 64])
Pre normalized variance: tensor(380.4652)
Normalized variance:  tensor(2.9724)
output shape:  torch.Size([4, 64, 128])


In [14]:
# Self-attention (single head)
class Attention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Projects the input into queries, keys and values with three separate
  linear layers and returns ``softmax(Q K^T / sqrt(embed_dim)) V``.

  Args:
    embed_dim: size of the last input dimension; also the width of the
      query/key/value projections and of the output.
  """

  def __init__(self, embed_dim):
    super().__init__()

    self.embed_dim = embed_dim
    self.query = nn.Linear(self.embed_dim, self.embed_dim)
    self.key = nn.Linear(self.embed_dim, self.embed_dim)
    self.value = nn.Linear(self.embed_dim, self.embed_dim)

  def forward(self, x):
    """Attend over the sequence axis of ``x``.

    Args:
      x: tensor of shape ``(..., seq_len, embed_dim)``. Any number of leading
        batch dimensions, including none, is accepted.

    Returns:
      Tensor with the same shape as ``x``.
    """
    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    # Swap the last two axes (rather than axes 1 and 2) so unbatched
    # (seq, dim) inputs and extra leading batch dimensions also work.
    similarity = (q @ k.transpose(-2, -1)) / (self.embed_dim ** 0.5)
    attention = similarity.softmax(dim=-1)
    return attention @ v

# Smoke test: run the single-head layer on the batch ``x`` from the first cell.
attention = Attention(128)
output = attention(x)
print(output.shape)


torch.Size([4, 64, 128])


In [15]:
# Multi-head attention
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention with an output projection.

  Args:
    embed_dim: model width; must be divisible by ``n_heads``.
    n_heads: number of attention heads; each head works on
      ``embed_dim // n_heads`` features.
    attn_p: dropout probability applied to the attention weights.
    proj_p: dropout probability applied after the output projection.

  Raises:
    ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, embed_dim, n_heads, attn_p=0, proj_p=0):
    super().__init__()
    if embed_dim % n_heads != 0:
      # forward() splits embed_dim into n_heads equal slices; fail early with
      # a clear message instead of an opaque reshape error later.
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")
    self.embed_dim = embed_dim
    self.n_heads = n_heads
    self.head_dim = self.embed_dim // self.n_heads
    self.query = nn.Linear(self.embed_dim, self.embed_dim)
    self.key = nn.Linear(self.embed_dim, self.embed_dim)
    self.value = nn.Linear(self.embed_dim, self.embed_dim)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(embed_dim, embed_dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x):
    """Apply multi-head self-attention.

    Args:
      x: tensor of shape ``(batch, seq_len, embed_dim)``.

    Returns:
      Tensor of shape ``(batch, seq_len, embed_dim)``.
    """
    batch_size, seq_len, _ = x.shape
    # (batch, seq, embed) -> (batch, heads, seq, head_dim)
    q = self.query(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    k = self.key(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    v = self.value(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()

    # Scaled dot-product attention, computed independently per head.
    similarity = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
    attention = similarity.softmax(dim=-1)
    attention = self.attn_drop(attention)
    x = attention @ v
    # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed).
    x = x.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

# Shape check: two heads over a width-64 model; the output keeps the input shape.
model = MultiHeadAttention(embed_dim=64, n_heads=2)
x = torch.rand((4, 10, 64))
output = model(x)
for label, tensor in (("input shape: ", x), ("output shape: ", output)):
  print(label, tensor.shape)

input shape:  torch.Size([4, 10, 64])
output shape:  torch.Size([4, 10, 64])


In [17]:
# Attention mask: key positions 4 and 5 are padding. Their scores become -inf,
# so a softmax applied afterwards would give them zero weight.
rand_attn = torch.rand(1, 6, 6)
attention_mask = torch.tensor([True, True, True, True, False, False])[None, None, :]
rand_attn.masked_fill_(attention_mask.logical_not(), float("-inf"))
rand_attn

tensor([[[0.6870, 0.6528, 0.0015, 0.9785,   -inf,   -inf],
         [0.6008, 0.3150, 0.7785, 0.2160,   -inf,   -inf],
         [0.5098, 0.6347, 0.9880, 0.2996,   -inf,   -inf],
         [0.1662, 0.1758, 0.0752, 0.1740,   -inf,   -inf],
         [0.7759, 0.3496, 0.9653, 0.2836,   -inf,   -inf],
         [0.5883, 0.2423, 0.0327, 0.0624,   -inf,   -inf]]])

In [19]:
# Attention mask with multiple heads: the same key-padding mask, shaped
# (batch, heads, query, key) = (1, 1, 1, 6) so it broadcasts over both heads.
rand_attn = torch.rand(1, 2, 6, 6)  # 2 heads
attention_mask = torch.tensor([True] * 4 + [False] * 2).reshape(1, 1, 1, 6)
rand_attn.masked_fill_(attention_mask.logical_not(), float('-inf'))
rand_attn

tensor([[[[0.6221, 0.7785, 0.0073, 0.2319,   -inf,   -inf],
          [0.4934, 0.8452, 0.5697, 0.3594,   -inf,   -inf],
          [0.5603, 0.9668, 0.7311, 0.3309,   -inf,   -inf],
          [0.2627, 0.3846, 0.4629, 0.7572,   -inf,   -inf],
          [0.4675, 0.1218, 0.2573, 0.0620,   -inf,   -inf],
          [0.6462, 0.5877, 0.7196, 0.2832,   -inf,   -inf]],

         [[0.0208, 0.5180, 0.7303, 0.0207,   -inf,   -inf],
          [0.9449, 0.7401, 0.2446, 0.3210,   -inf,   -inf],
          [0.3413, 0.3449, 0.3984, 0.3886,   -inf,   -inf],
          [0.6146, 0.9930, 0.6880, 0.9037,   -inf,   -inf],
          [0.2537, 0.0947, 0.4935, 0.3183,   -inf,   -inf],
          [0.4476, 0.2766, 0.8942, 0.6780,   -inf,   -inf]]]])

In [27]:
# Multi-head self-attention with an optional key-padding mask
class SelfAttention(nn.Module):
  """Multi-head self-attention that can ignore padded key positions.

  Args:
    embed_dim: model width; must be divisible by ``n_heads``.
    n_heads: number of attention heads; each uses ``embed_dim // n_heads``
      features.
    attn_p: dropout probability applied to the attention weights.
    proj_p: dropout probability applied after the output projection.

  Raises:
    ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, embed_dim, n_heads, attn_p=0, proj_p=0):
    super().__init__()
    if embed_dim % n_heads != 0:
      # forward() splits embed_dim into n_heads equal slices.
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")
    self.embed_dim = embed_dim
    self.n_heads = n_heads
    self.head_dim = self.embed_dim // self.n_heads

    self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x, attention_mask=None):
    """Attend over the sequence axis of ``x``.

    For notebook demonstration, prints the input shape and the attention
    weights (after softmax and dropout).

    Args:
      x: tensor of shape ``(batch, seq_len, embed_dim)``.
      attention_mask: optional ``(batch, seq_len)`` mask. Truthy entries mark
        real tokens, falsy entries mark padding that no query may attend to.
        Bool or 0/1 integer tensors are accepted.

    Returns:
      Tensor of shape ``(batch, seq_len, embed_dim)``.
    """
    batch_size, seq_len, _ = x.shape
    print(x.shape)
    # (batch, seq, embed) -> (batch, heads, seq, head_dim)
    q = self.q_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    k = self.k_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    v = self.v_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()

    attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5

    if attention_mask is not None:
      # (batch, seq) -> (batch, 1, 1, seq): broadcasts over heads and query
      # rows, so no explicit repeat is needed. .bool() makes `~` a logical
      # NOT even for 0/1 integer masks.
      key_mask = attention_mask.bool()[:, None, None, :]
      attn = attn.masked_fill(~key_mask, float("-inf"))
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    print("after attention mask")
    print(attn)
    x = attn @ v

    # Merge heads using this module's own width, not any outer variable.
    x = x.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

# Demo: three sequences of lengths 3, 5 and 4, zero-padded to length 5.
seq_lens = [3, 5, 4]
embed_dim, n_heads = 9, 3
model = SelfAttention(embed_dim=embed_dim, n_heads=n_heads)
rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)
# Row i is True for its first seq_lens[i] positions and False for the padding.
masks = torch.arange(max(seq_lens)) < torch.tensor(seq_lens)[:, None]
print("attention masks")
print(masks)

output = model(rand, attention_mask=masks)
print("output")
output


attention masks
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
torch.Size([3, 5, 9])
after attention mask
tensor([[[[0.4269, 0.1933, 0.3798, 0.0000, 0.0000],
          [0.3186, 0.3029, 0.3785, 0.0000, 0.0000],
          [0.3290, 0.2973, 0.3737, 0.0000, 0.0000],
          [0.2805, 0.4264, 0.2931, 0.0000, 0.0000],
          [0.2617, 0.4004, 0.3379, 0.0000, 0.0000]],

         [[0.4584, 0.3048, 0.2368, 0.0000, 0.0000],
          [0.4060, 0.2677, 0.3263, 0.0000, 0.0000],
          [0.2874, 0.2983, 0.4144, 0.0000, 0.0000],
          [0.3014, 0.3234, 0.3752, 0.0000, 0.0000],
          [0.1804, 0.3904, 0.4292, 0.0000, 0.0000]],

         [[0.3517, 0.3001, 0.3482, 0.0000, 0.0000],
          [0.3506, 0.4296, 0.2197, 0.0000, 0.0000],
          [0.2882, 0.4222, 0.2897, 0.0000, 0.0000],
          [0.2676, 0.3445, 0.3879, 0.0000, 0.0000],
          [0.2867, 0.3108, 0.4025, 0.0000, 0.0000]]],


        [[[0.1903

tensor([[[-0.1722,  0.0553, -0.0037, -0.1829,  0.3531, -0.3346, -0.2964,
          -0.1518,  0.4501],
         [-0.1951,  0.0696, -0.0111, -0.1379,  0.3413, -0.3316, -0.3649,
          -0.1019,  0.4501],
         [-0.2399,  0.0629, -0.0296, -0.0949,  0.3154, -0.3835, -0.3497,
          -0.0767,  0.4696],
         [-0.1884,  0.0990, -0.0041, -0.0467,  0.3144, -0.3690, -0.3527,
          -0.0743,  0.4343],
         [-0.2473,  0.1059, -0.0093, -0.0032,  0.2889, -0.3937, -0.3526,
          -0.0518,  0.4668]],

        [[-0.2290,  0.0654, -0.0831,  0.1258,  0.2146, -0.1069, -0.3625,
          -0.0582,  0.2011],
         [-0.2255,  0.0356, -0.1181,  0.1294,  0.2026, -0.1413, -0.3250,
          -0.0618,  0.1754],
         [-0.2160,  0.0779, -0.0677,  0.1187,  0.2340, -0.0836, -0.3726,
          -0.0511,  0.1991],
         [-0.2069,  0.0738, -0.0673,  0.1254,  0.2103, -0.1016, -0.3911,
          -0.0597,  0.2029],
         [-0.2136,  0.0705, -0.0810,  0.1441,  0.2083, -0.1018, -0.3682,
       

In [29]:
# Causal masking: query row i may only see key columns 0..i (lower triangle).
seq_len = 8
ones = torch.ones((seq_len, seq_len))
causal_mask = ones.tril().bool()

# Apply a padding mask too: the last four positions are padding, so their
# columns are switched off for every query row.
padding_mask = torch.tensor([True] * 4 + [False] * 4)
padding_mask = padding_mask[None, :].repeat(seq_len, 1)

causal_mask = causal_mask & padding_mask
causal_mask

tensor([[ True, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False]])

In [31]:
# Multi-head self-attention with optional causal and key-padding masks
class SelfAttention(nn.Module):
  """Multi-head self-attention supporting causal and key-padding masks.

  Args:
    embed_dim: model width; must be divisible by ``n_heads``.
    n_heads: number of attention heads; each uses ``embed_dim // n_heads``
      features.
    attn_p: dropout probability applied to the attention weights.
    proj_p: dropout probability applied after the output projection.
    causal: if True, query position ``i`` may only attend to key positions
      ``<= i``.

  Raises:
    ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, embed_dim, n_heads, attn_p=0, proj_p=0, causal=False):
    super().__init__()
    if embed_dim % n_heads != 0:
      # forward() splits embed_dim into n_heads equal slices.
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")
    self.embed_dim = embed_dim
    self.n_heads = n_heads
    self.head_dim = self.embed_dim // self.n_heads
    self.causal = causal

    self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x, attention_mask=None):
    """Attend over the sequence axis of ``x``.

    For notebook demonstration, prints the input shape and the attention
    weights (after softmax and dropout).

    Args:
      x: tensor of shape ``(batch, seq_len, embed_dim)``.
      attention_mask: optional ``(batch, seq_len)`` mask. Truthy entries mark
        real tokens, falsy entries mark padding keys. It is applied whether
        or not the layer is causal.

    Returns:
      Tensor of shape ``(batch, seq_len, embed_dim)``.

    Note:
      A query row whose keys are all masked (e.g. an all-padding sequence
      without causal masking) softmaxes over only ``-inf`` and yields NaN.
    """
    batch_size, seq_len, _ = x.shape
    print(x.shape)
    # (batch, seq, embed) -> (batch, heads, seq, head_dim)
    q = self.q_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    k = self.k_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    v = self.v_proj(x).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()

    attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5

    # Boolean "may attend" mask built by broadcasting; None means no masking.
    allowed = None
    if self.causal:
      # (1, 1, seq, seq) lower triangle: query i sees keys 0..i.
      allowed = torch.ones(seq_len, seq_len, device=attn.device).tril().bool()[None, None]
    if attention_mask is not None:
      # (batch, seq) -> (batch, 1, 1, seq); broadcasts over heads and queries.
      key_mask = attention_mask.bool()[:, None, None, :]
      allowed = key_mask if allowed is None else allowed & key_mask
    if allowed is not None:
      attn = attn.masked_fill(~allowed, float("-inf"))
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    print("after attention mask")
    print(attn)
    x = attn @ v

    # Merge heads using this module's own width, not any outer variable.
    x = x.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

# Demo: the same padded batch (lengths 3, 5, 4), now through the causal layer.
seq_lens = [3, 5, 4]
embed_dim, n_heads = 9, 3
model = SelfAttention(embed_dim=embed_dim, n_heads=n_heads, causal=True)
rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)
# Row i is True for its first seq_lens[i] positions and False for the padding.
masks = torch.arange(max(seq_lens)) < torch.tensor(seq_lens)[:, None]
print("attention masks")
print(masks)

output = model(rand, attention_mask=masks)
print("output")
output


attention masks
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
torch.Size([3, 5, 9])
after attention mask
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5347, 0.4653, 0.0000, 0.0000, 0.0000],
          [0.2565, 0.3754, 0.3682, 0.0000, 0.0000],
          [0.1956, 0.5322, 0.2722, 0.0000, 0.0000],
          [0.4415, 0.2914, 0.2671, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3635, 0.6365, 0.0000, 0.0000, 0.0000],
          [0.3391, 0.3604, 0.3004, 0.0000, 0.0000],
          [0.4478, 0.4288, 0.1234, 0.0000, 0.0000],
          [0.1710, 0.3492, 0.4798, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6304, 0.3696, 0.0000, 0.0000, 0.0000],
          [0.2523, 0.4797, 0.2680, 0.0000, 0.0000],
          [0.3931, 0.3325, 0.2743, 0.0000, 0.0000],
          [0.3258, 0.2925, 0.3817, 0.0000, 0.0000]]],


        [[[1.0000

tensor([[[ 0.1671, -0.0286,  0.1037, -0.0251,  0.4422,  0.0877,  0.2845,
          -0.0369, -0.6447],
         [ 0.1617, -0.1590, -0.0444,  0.1780,  0.3622, -0.2162,  0.0505,
          -0.0452, -0.3431],
         [ 0.1507, -0.3053, -0.0203,  0.1416,  0.2164, -0.2129,  0.1077,
          -0.0880, -0.3099],
         [ 0.1185, -0.2628,  0.0477,  0.1022,  0.3259, -0.1692,  0.1274,
          -0.1175, -0.3682],
         [ 0.1535, -0.3351,  0.0430,  0.0048,  0.2281, -0.0921,  0.2625,
          -0.1350, -0.4605]],

        [[ 0.0056,  0.0772, -0.1156,  0.2442, -0.3029, -0.2755, -0.5676,
           0.2253,  0.2744],
         [ 0.3288, -0.0360, -0.0312,  0.2432, -0.0740, -0.0835, -0.0926,
           0.1240, -0.1185],
         [ 0.4110, -0.1339,  0.0999,  0.0430, -0.0184,  0.1371,  0.2247,
          -0.0389, -0.4166],
         [ 0.4597, -0.2973,  0.0865, -0.0749,  0.0102,  0.1314,  0.4172,
          -0.2244, -0.6192],
         [ 0.3527, -0.3146, -0.0435,  0.1315, -0.0435, -0.1370,  0.1150,
       

In [36]:
# Cross attention
# Multi-head cross-attention: queries come from tgt, keys/values from src.
class CrossAttention(nn.Module):
  """Multi-head cross-attention with an optional source padding mask.

  Args:
    embed_dim: model width of both ``src`` and ``tgt``; must be divisible
      by ``n_heads``.
    n_heads: number of attention heads; each uses ``embed_dim // n_heads``
      features.
    attn_p: dropout probability applied to the attention weights.
    proj_p: dropout probability applied after the output projection.

  Raises:
    ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, embed_dim, n_heads, attn_p=0, proj_p=0):
    super().__init__()
    if embed_dim % n_heads != 0:
      # forward() splits embed_dim into n_heads equal slices.
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")
    self.embed_dim = embed_dim
    self.n_heads = n_heads
    self.head_dim = self.embed_dim // self.n_heads

    self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(self.embed_dim, self.embed_dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, src, tgt, attention_mask=None):
    """Let every ``tgt`` position attend over the ``src`` sequence.

    For notebook demonstration, prints the attention weights (after softmax
    and dropout).

    Args:
      src: ``(batch, src_seq_len, embed_dim)``; provides keys and values.
      tgt: ``(batch, tgt_seq_len, embed_dim)``; provides queries.
      attention_mask: optional ``(batch, src_seq_len)`` mask over ``src``.
        Truthy entries mark real tokens, falsy entries mark padding.
        Bool or 0/1 integer tensors are accepted.

    Returns:
      Tensor of shape ``(batch, tgt_seq_len, embed_dim)``.
    """
    batch_size, src_seq_len, _ = src.shape
    _, tgt_seq_len, _ = tgt.shape
    # (batch, seq, embed) -> (batch, heads, seq, head_dim)
    q = self.q_proj(tgt).reshape(batch_size, tgt_seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    k = self.k_proj(src).reshape(batch_size, src_seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()
    v = self.v_proj(src).reshape(batch_size, src_seq_len, self.n_heads, self.head_dim).transpose(1, 2).contiguous()

    # (batch, heads, tgt_seq, src_seq)
    attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5

    if attention_mask is not None:
      # (batch, src_seq) -> (batch, 1, 1, src_seq): broadcasts over heads and
      # every target query row.
      key_mask = attention_mask.bool()[:, None, None, :]
      attn = attn.masked_fill(~key_mask, float("-inf"))
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    print("after attention mask")
    print(attn)
    x = attn @ v

    # Merge heads using this module's own width, not any outer variable.
    x = x.transpose(1, 2).reshape(batch_size, tgt_seq_len, self.embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

# Demo: French (tgt) queries attend over padded English (src) sequences.
english_seq_lens = [3, 5, 4]
french_seq_lens = [7, 6, 2]

embed_dim, num_heads = 18, 3
a = CrossAttention(embed_dim, num_heads)

### Random (batch x max_seq_len x embed_dim) inputs for English and French ###
rand_english = torch.randn(len(english_seq_lens), max(english_seq_lens), embed_dim)
rand_french = torch.randn(len(french_seq_lens), max(french_seq_lens), embed_dim)

### Padding masks from the seq_lens (shorter sequences padded to the longest) ###
### Row i is True for its first seq_lens[i] positions, False for padding.    ###
english_masks = torch.arange(max(english_seq_lens)) < torch.tensor(english_seq_lens)[:, None]
french_masks = torch.arange(max(french_seq_lens)) < torch.tensor(french_seq_lens)[:, None]

for title, mask in (("English Attention Mask:", english_masks),
                    ("French Attention Mask:", french_masks)):
  print(title)
  print(mask)

### Pass through multi-head cross-attention ###
output = a(src=rand_english, tgt=rand_french, attention_mask=english_masks)
print("Final Output:", output.shape)


English Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
French Attention Mask:
tensor([[ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True, False, False, False, False, False]])
after attention mask
tensor([[[[0.4222, 0.2891, 0.2887, 0.0000, 0.0000],
          [0.4217, 0.2769, 0.3015, 0.0000, 0.0000],
          [0.2570, 0.4412, 0.3018, 0.0000, 0.0000],
          [0.2927, 0.3663, 0.3410, 0.0000, 0.0000],
          [0.1743, 0.5946, 0.2311, 0.0000, 0.0000],
          [0.2537, 0.4107, 0.3356, 0.0000, 0.0000],
          [0.4089, 0.2630, 0.3281, 0.0000, 0.0000]],

         [[0.2534, 0.4615, 0.2850, 0.0000, 0.0000],
          [0.4160, 0.2547, 0.3293, 0.0000, 0.0000],
          [0.3078, 0.3459, 0.3463, 0.0000, 0.0000],
          [0.3992, 0.2319, 0.3689, 0.0000, 0.0000],
          [0.3385, 0.3016, 0.3599, 0.0000, 0