<a href="https://colab.research.google.com/github/ArpitKadam/ColabNotebooks/blob/main/Self_Attention_%26_Multi_Head_Attention_with_Trainable_Weights.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SELF ATTENTION**

In [None]:
import torch

## Sentence being embedded; words[i] labels row i of `input`.
words = ["Dream", "big", "and", "work", "for", "it"]

## One 3-dimensional toy embedding per word, in sentence order.
input = torch.tensor([
    [0.72, 0.45, 0.31],  ## Dream
    [0.75, 0.20, 0.55],  ## big
    [0.30, 0.80, 0.40],  ## and
    [0.85, 0.35, 0.60],  ## work
    [0.55, 0.15, 0.75],  ## for
    [0.20, 0.20, 0.85],  ## it
])

In [None]:
## Second row of `input` (labelled "big") and the width of every input row.
x_2, d_in = input[1], input.size(1)
print(x_2, d_in, sep="\n")

d_out = 2  ## Dimensionality of Context Vectors

tensor([0.7500, 0.2000, 0.5500])
3


In [None]:
torch.manual_seed(100)

## Three frozen (d_in, d_out) projection matrices. The generator is consumed
## left to right, so random values are drawn for query, then key, then value.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [None]:
## Show each projection matrix under its label, separated by a blank line.
## NOTE(review): the last label reads "W_out:" but the tensor printed is W_value.
print("W_query:", W_query, "", sep="\n")
print("W_key:", W_key, "", sep="\n")
print("W_out:", W_value, sep="\n")

W_query:
Parameter containing:
tensor([[ 0.3607, -0.2859],
        [-0.3938,  0.2429],
        [-1.3833, -2.3134]])

W_key:
Parameter containing:
tensor([[-0.3172, -0.8660],
        [ 1.7482, -0.2759],
        [-0.9755,  0.4790]])

W_out:
Parameter containing:
tensor([[-2.3652, -0.8047],
        [ 0.6587, -0.2586],
        [-0.2510,  0.4770]])


In [None]:
## Project the single token x_2 with each weight matrix.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))

In [None]:
## Print the three projections of x_2, one labelled section each.
print("query_2:", query_2, "", sep="\n")
print("key_2:", key_2, "", sep="\n")
print("value_2:", value_2, sep="\n")

query_2:
tensor([-0.5691, -1.4382])

key_2:
tensor([-0.4248, -0.4413])

value_2:
tensor([-1.7802, -0.3929])


### Calculating Q, K & V using x, Wq, Wk, Wv

In [None]:
## Project every row of `input` at once: one query/key/value row per token.
queries = input @ W_query
keys = input @ W_key
values = input @ W_value

In [None]:
## Each section: label, tensor, shape; sections are separated by a blank line.
print("keys:", keys, sep="\n")
print("Keys shape:", keys.shape, end="\n\n")

print("queries:", queries, sep="\n")
print("Queries shape:", queries.shape, end="\n\n")

print("values:", values, sep="\n")
print("Values shape:", values.shape)

keys:
tensor([[ 0.2559, -0.5992],
        [-0.4248, -0.4413],
        [ 0.9132, -0.2889],
        [-0.2430, -0.5453],
        [-0.6438, -0.1585],
        [-0.5430,  0.1788]])
Keys shape: torch.Size([6, 2])

queries:
tensor([[-0.3464, -0.8137],
        [-0.5691, -1.4382],
        [-0.7602, -0.8168],
        [-0.6613, -1.5461],
        [-0.8982, -1.8559],
        [-1.1825, -1.9750]])
Queries shape: torch.Size([6, 2])

values:
tensor([[-1.4843, -0.5479],
        [-1.7802, -0.3929],
        [-0.2829, -0.2575],
        [-1.9304, -0.4883],
        [-1.3903, -0.1236],
        [-0.5546,  0.1928]])
Values shape: torch.Size([6, 2])


### Key of the second token (x2) and the attention score of the second token with itself

In [None]:
## Dot product of the second token's query with its own key.
attn_score_22 = query_2 @ key_2

print("Keys corresponding to second token (x2):", key_2, "", sep="\n")
print("Attention of second token to itself:", attn_score_22, sep="\n")

Keys corresponding to second token (x2):
tensor([-0.4248, -0.4413])

Attention of second token to itself:
tensor(0.8764)


### All attention scores for query_2 (the query of the second token, "big")

In [None]:
## Score query_2 against every key: one dot product per row of `keys`.
attn_score_2 = query_2 @ keys.t()
print("Attention Scores for query_2 (meaning for 'big'):", attn_score_2, sep="\n")

Attention Scores for query_2 (meaning for 'big'):
tensor([ 0.7162,  0.8764, -0.1042,  0.9226,  0.5943,  0.0519])


### Complete Attention Score

In [None]:
## Entry [i, j] is the dot product of query i with key j.
attn_score = queries @ keys.T
print("Complete Attention Score:", attn_score, sep="\n")

Complete Attention Score:
tensor([[ 0.3989,  0.5062, -0.0812,  0.5279,  0.3520,  0.0426],
        [ 0.7162,  0.8764, -0.1042,  0.9226,  0.5943,  0.0519],
        [ 0.2949,  0.6833, -0.4582,  0.6302,  0.6189,  0.2667],
        [ 0.7572,  0.9631, -0.1572,  1.0038,  0.6708,  0.0827],
        [ 0.8822,  1.2005, -0.2841,  1.2303,  0.8724,  0.1559],
        [ 0.8808,  1.3738, -0.5092,  1.3644,  1.0743,  0.2890]])


### Scale by $1/\sqrt{d_k}$ and apply softmax

In [None]:
## d_k is the width of a key vector; scores are divided by sqrt(d_k) before softmax.
d_k = keys.size(-1)
attn_weights_2 = (attn_score_2 / d_k ** 0.5).softmax(dim=-1)

print("d value:", d_k, "", sep="\n")
print("Attention Weights for 'big':", attn_weights_2, sep="\n")

d value:
2

Attention Weights for 'big':
tensor([0.1859, 0.2082, 0.1041, 0.2151, 0.1705, 0.1162])


In [None]:
## Row-wise softmax of the scaled score matrix.
attn_weights = (attn_score / d_k ** 0.5).softmax(dim=-1)
print("Attention Weights:", attn_weights, sep="\n")

Attention Weights:
tensor([[0.1776, 0.1916, 0.1265, 0.1945, 0.1718, 0.1380],
        [0.1859, 0.2082, 0.1041, 0.2151, 0.1705, 0.1162],
        [0.1560, 0.2054, 0.0916, 0.1978, 0.1962, 0.1530],
        [0.1841, 0.2129, 0.0964, 0.2191, 0.1732, 0.1143],
        [0.1798, 0.2252, 0.0788, 0.2300, 0.1786, 0.1076],
        [0.1666, 0.2360, 0.0623, 0.2345, 0.1910, 0.1096]])


### Context Vector

In [None]:
## Weighted sum of the value rows, using the attention weights of "big".
context_vec_2 = attn_weights_2 @ values
print("Context Vector for 'big':", context_vec_2, sep="\n")

Context Vector for 'big':
tensor([-1.3927, -0.3141])


In [None]:
## One context vector per token: every row of weights applied to `values`.
context_vec = attn_weights @ values
print("Context Vector:", context_vec, sep="\n")

Context Vector:
tensor([[-1.3314, -0.2947],
        [-1.3927, -0.3141],
        [-1.3626, -0.2811],
        [-1.4067, -0.3157],
        [-1.4420, -0.3209],
        [-1.4640, -0.3170]])


In [None]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
  """Single-head scaled dot-product self-attention using raw weight matrices.

  The query, key and value projections are plain (d_in, d_out) matrices
  filled with `torch.randn` and registered with `requires_grad=False`.

  Args:
    d_in: size of the last dimension of the input.
    d_out: size of the last dimension of the returned context vectors.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    ## Creation order (query, key, value) decides which random draws each
    ## matrix receives under a given seed, so it must not be reordered.
    self.W_query = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
    self.W_key = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
    self.W_value = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

  def forward(self, x):
    """Compute context vectors for `x`.

    Args:
      x: tensor of shape (num_tokens, d_in) or (..., num_tokens, d_in).

    Returns:
      Tensor of shape (..., num_tokens, d_out).
    """
    keys = torch.matmul(x, self.W_key)
    queries = torch.matmul(x, self.W_query)
    values = torch.matmul(x, self.W_value)

    ## Swap only the last two axes. `keys.T` reverses every axis, which is
    ## only a transpose for 2-D input and breaks on batched tensors.
    attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

    ## Scale by sqrt(d_out) before the softmax over the key axis.
    attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** (0.5)), dim=-1)

    context_vec = torch.matmul(attn_weights, values)

    return context_vec

In [None]:
## Seed before construction: the module draws its weights in __init__.
torch.manual_seed(100)

sa_v1 = SelfAttention_v1(3, 2)
context_vec = sa_v1(input)
print("Context Vector:", context_vec, sep="\n")

Context Vector:
tensor([[-1.3314, -0.2947],
        [-1.3927, -0.3141],
        [-1.3626, -0.2811],
        [-1.4067, -0.3157],
        [-1.4420, -0.3209],
        [-1.4640, -0.3170]])


In [None]:
class SelfAttention_v2(nn.Module):
  """Single-head scaled dot-product self-attention built from `nn.Linear`.

  Args:
    d_in: size of the last dimension of the input.
    d_out: size of the last dimension of the returned context vectors.
    qkv_bias: whether the query/key/value linear layers carry a bias term.
  """

  def __init__(self, d_in, d_out, qkv_bias):
    super().__init__()
    ## Layers are created in the order query, key, value; that order fixes
    ## their random initialisation under a given seed.
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    """Compute context vectors for `x`.

    Args:
      x: tensor of shape (num_tokens, d_in) or (..., num_tokens, d_in).

    Returns:
      Tensor of shape (..., num_tokens, d_out).
    """
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    ## Swap only the last two axes. `keys.T` reverses every axis, which is
    ## only a transpose for 2-D input and breaks on batched tensors.
    attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

    ## Scale by sqrt(d_out) before the softmax over the key axis.
    attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** (0.5)), dim=-1)

    context_vec = torch.matmul(attn_weights, values)

    return context_vec

In [None]:
## Seed before construction: the linear layers are randomly initialised.
torch.manual_seed(100)

sa_v2 = SelfAttention_v2(3, 2, qkv_bias=True)
context_vec = sa_v2(input)
print("Context Vector:", context_vec, sep="\n")

Context Vector:
tensor([[0.3103, 0.5039],
        [0.3111, 0.5033],
        [0.3092, 0.5047],
        [0.3112, 0.5032],
        [0.3112, 0.5033],
        [0.3106, 0.5037]], grad_fn=<MmBackward0>)


# **MULTI-HEAD ATTENTION**

In [1]:
import torch

## Sentence being embedded; words[i] labels row i of `input`.
words = ["Dream", "big", "and", "work", "for", "it"]

## One 3-dimensional toy embedding per word, in sentence order.
input = torch.tensor([
    [0.72, 0.45, 0.31],  ## Dream
    [0.75, 0.20, 0.55],  ## big
    [0.30, 0.80, 0.40],  ## and
    [0.85, 0.35, 0.60],  ## work
    [0.55, 0.15, 0.75],  ## for
    [0.20, 0.20, 0.85],  ## it
])

In [12]:
from torch.autograd import forward_ad
import torch.nn as nn

class CasualAttention(nn.Module):
  """Masked self-attention over batched input, with dropout on the weights.

  Positions above the diagonal of the score matrix (a token looking at a
  later token) are filled with -inf before the softmax, so they receive
  zero attention weight.

  Args:
    d_in: size of the last dimension of the input.
    d_out: size of the last dimension of the returned context vectors.
    context_length: side length of the stored square mask.
    dropout: probability passed to `nn.Dropout`.
    qkv_bias: whether the query/key/value linear layers carry a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias):
    super().__init__()

    self.d_in, self.d_out = d_in, d_out
    self.context_length = context_length
    self.dropout = nn.Dropout(dropout)

    ## Query, key, value are created in this order.
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    ## Ones strictly above the diagonal mark the positions to hide.
    future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
    self.register_buffer("mask", future_positions)

  def forward(self, x):
    """Return context vectors for a 3-D input.

    Args:
      x: tensor whose shape unpacks as (batch, num_tokens, d_in).

    Returns:
      Tensor of shape (batch, num_tokens, d_out).
    """
    _, num_tokens, _ = x.shape  ## leading axis is the batch dimension

    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    ## Axes 1 and 2 are swapped because axis 0 is the batch dimension.
    attn_scores = queries @ keys.transpose(1, 2)

    ## Crop the stored mask to the actual sequence length, then hide future tokens.
    causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
    attn_scores.masked_fill_(causal_mask, -torch.inf)

    scale = keys.shape[-1] ** 0.5
    attn_weights = self.dropout(torch.softmax(attn_scores / scale, dim=-1))

    return attn_weights @ values

In [13]:
d_in, d_out = input.shape[-1], 2

## Stack the same sequence twice along a new leading axis.
batch = torch.stack([input, input])
print(batch, batch.shape, sep="\n")

tensor([[[0.7200, 0.4500, 0.3100],
         [0.7500, 0.2000, 0.5500],
         [0.3000, 0.8000, 0.4000],
         [0.8500, 0.3500, 0.6000],
         [0.5500, 0.1500, 0.7500],
         [0.2000, 0.2000, 0.8500]],

        [[0.7200, 0.4500, 0.3100],
         [0.7500, 0.2000, 0.5500],
         [0.3000, 0.8000, 0.4000],
         [0.8500, 0.3500, 0.6000],
         [0.5500, 0.1500, 0.7500],
         [0.2000, 0.2000, 0.8500]]])
torch.Size([2, 6, 3])


In [14]:
class MultiHeadAttentionWrapper(nn.Module):
  """Runs several independent `CasualAttention` heads and concatenates them.

  Every head is built with the same arguments, so the output's last
  dimension is `num_heads * d_out`.
  """

  def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias):
    super().__init__()
    self.heads = nn.ModuleList()
    for _ in range(num_heads):
      self.heads.append(CasualAttention(d_in, d_out, context_length, dropout, qkv_bias))

  def forward(self, x):
    """Apply each head to `x` in order and join the results on the last axis."""
    head_outputs = [head(x) for head in self.heads]
    return torch.cat(head_outputs, dim=-1)

In [15]:
## Seed first: the heads' linear layers are randomly initialised on construction.
torch.manual_seed(100)

batch = torch.stack([input, input])  ## Here batch size is 2
context_length, d_in = batch.shape[1:]
d_out = 2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length,
    num_heads=3, dropout=0.1, qkv_bias=True,
)

In [16]:
## Run the multi-head module and show the result followed by its shape.
context_vec = mha(batch)
print("Context Vector:", context_vec, sep="\n")
print("Context Vector Shape:", context_vec.shape, sep="\n")

Context Vector:
tensor([[[ 0.0000,  0.0000, -0.7784, -0.2237,  1.2173,  0.1742],
         [ 0.4949,  0.4700, -0.6811, -0.1936,  1.2268,  0.1584],
         [ 0.4578,  0.4371, -0.7688, -0.2008,  0.7948,  0.0513],
         [ 0.4724,  0.4533, -0.7448, -0.2051,  1.2416,  0.1103],
         [ 0.4199,  0.5063, -0.5164, -0.1443,  0.7687,  0.0858],
         [ 0.3511,  0.4237, -0.4388, -0.0960,  1.1969,  0.0722]],

        [[ 0.5688,  0.3726, -0.7784, -0.2237,  1.2173,  0.1742],
         [ 0.4949,  0.4700, -0.6811, -0.1936,  1.2268,  0.1584],
         [ 0.4578,  0.4371, -0.7688, -0.2008,  1.1989,  0.1091],
         [ 0.4724,  0.4533, -0.7448, -0.2051,  1.2416,  0.1103],
         [ 0.2964,  0.2521, -0.6005, -0.1651,  1.2282,  0.0973],
         [ 0.1917,  0.2989, -0.6771, -0.1637,  0.8107,  0.0428]]],
       grad_fn=<CatBackward0>)
Context Vector Shape:
torch.Size([2, 6, 6])


In [17]:
## Seed first: the heads' linear layers are randomly initialised on construction.
torch.manual_seed(100)

batch = torch.stack([input])  ## Here batch size is 1
context_length, d_in = batch.shape[1:]
d_out = 2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length,
    num_heads=3, dropout=0.1, qkv_bias=True,
)

In [18]:
## Run the multi-head module and show the result followed by its shape.
context_vec = mha(batch)
print("Context Vector:", context_vec, sep="\n")
print("Context Vector Shape:", context_vec.shape, sep="\n")

Context Vector:
tensor([[[ 0.0000,  0.0000, -0.7784, -0.2237,  1.2173,  0.1742],
         [ 0.4949,  0.4700, -0.6811, -0.1936,  1.2268,  0.1584],
         [ 0.4578,  0.4371, -0.7688, -0.2008,  1.1989,  0.1091],
         [ 0.4724,  0.4533, -0.7448, -0.2051,  1.2416,  0.1103],
         [ 0.4199,  0.5063, -0.4810, -0.1316,  1.0136,  0.0963],
         [ 0.3511,  0.4237, -0.3001, -0.0828,  0.7902,  0.0196]]],
       grad_fn=<CatBackward0>)
Context Vector Shape:
torch.Size([1, 6, 6])


In [19]:
## Seed first: the heads' linear layers are randomly initialised on construction.
torch.manual_seed(100)

batch = torch.stack([input])  ## Here batch size is 1
context_length, d_in = batch.shape[1:]
d_out = 2

## Ten heads of width d_out each, with a higher dropout probability.
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length,
    num_heads=10, dropout=0.3, qkv_bias=True,
)

In [20]:
## Run the multi-head module and show the result followed by its shape.
context_vec = mha(batch)
print("Context Vector:", context_vec, sep="\n")
print("Context Vector Shape:", context_vec.shape, sep="\n")

Context Vector:
tensor([[[ 0.0000,  0.0000, -1.0008, -0.2876,  0.0000,  0.0000, -0.4933,
          -0.5937,  0.0000,  0.0000, -0.0992,  0.9979,  0.7522, -0.9791,
           0.2575, -0.6159, -0.1867, -0.7596, -0.3448, -0.9653],
         [ 0.6363,  0.6043, -0.8757, -0.2490,  1.5773,  0.2036, -0.4279,
          -0.7763,  0.3010,  0.3950, -0.1434,  1.0314,  0.3667, -0.4773,
           0.0594, -0.8003, -0.2404, -0.6143, -0.1270, -0.6795],
         [ 0.4227,  0.4015, -0.9884, -0.2582,  0.5195,  0.0743, -0.2724,
          -0.4824,  0.3355,  0.2783, -0.2392,  1.0062,  0.6148, -0.6874,
           0.0867, -0.2074, -0.2970, -0.7485, -0.1171, -0.3277],
         [ 0.1248,  0.1208, -0.9575, -0.2637,  1.2608,  0.1403, -0.1082,
          -0.2462,  0.3563,  0.3711, -0.0497,  0.5579,  0.3758, -0.5190,
           0.1020, -0.3012, -0.2596, -0.4883, -0.3347, -1.0533],
         [ 0.4397,  0.5540, -0.4914, -0.1289,  1.2206,  0.0962, -0.3502,
          -0.7201,  0.2273,  0.2372, -0.2233,  0.8186,  0.6101, -0.