In [1]:
import torch

#each row corresponds to each token
inputs=torch.tensor(
    [
        [0.72,0.45,0.31],
        [0.75,0.20,0.55],
        [0.30,0.80,0.40],
        [0.85,0.35,0.60],
        [0.55,0.15,0.75],
        [0.25,0.20,0.85]
    ]
)

words=['Dream','big','and','work','for','it']

In [2]:
import torch.nn as nn

class CasualAttention(nn.Module):

  def __init__(self,d_in,d_out,context_length,dropout=0.5,qkv_bias=False):
    super().__init__()

    self.d_out=d_out
    self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
    self.dropout=nn.Dropout(dropout)
    self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

  def forward(self,x):
    batch_size,num_tokens,d_in=x.shape

    keys=self.W_key(x)
    queries=self.W_query(x)
    values=self.W_value(x)

    attn_scores=queries@keys.transpose(1,2)
    attn_scores.masked_fill_(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
    attn_weights=torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)

    attn_weights=self.dropout(attn_weights)

    context_vector=attn_weights @ values
    return context_vector

In [3]:
d_in=inputs.shape[-1]
d_out=2

In [4]:
batch=torch.stack((inputs,inputs),dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [5]:
#Multi-Head Attention without weight splits - inefficient as each head requires separate weight,query,value matrix multiplications
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    self.heads=nn.ModuleList(
        [CasualAttention(d_in,d_out,context_length,dropout,qkv_bias) for _ in range(num_heads)]
    )

  def forward(self,x):
    #concatenating all context vectors from all heads along column
    return torch.cat([head(x) for head in self.heads],dim=-1)

In [6]:
torch.manual_seed(123)
context_length=batch.shape[1]
d_in,d_out=3,2
mha=MultiHeadAttentionWrapper(d_in,d_out,context_length,0.0,num_heads=2)

In [7]:
context_vecs=mha(batch)
print(context_vecs)
print(context_vecs.shape)
#here the shape last column 4 denotes that each input is 6x2 but because of 2 heads it is 2*2=4

tensor([[[-0.5762, -0.1627,  0.5569,  0.3635],
         [-0.5650, -0.0630,  0.5599,  0.3006],
         [-0.5472, -0.1226,  0.5285,  0.3435],
         [-0.5787, -0.0943,  0.5621,  0.3388],
         [-0.5593, -0.0436,  0.5509,  0.3046],
         [-0.5287, -0.0033,  0.5277,  0.2743]],

        [[-0.5762, -0.1627,  0.5569,  0.3635],
         [-0.5650, -0.0630,  0.5599,  0.3006],
         [-0.5472, -0.1226,  0.5285,  0.3435],
         [-0.5787, -0.0943,  0.5621,  0.3388],
         [-0.5593, -0.0436,  0.5509,  0.3046],
         [-0.5287, -0.0033,  0.5277,  0.2743]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 4])


In [10]:
x=torch.tensor([[
    [1.0,2.0,3.0,4.0,5.0,6.0],
    [6.0,5.0,4.0,3.0,2.0,1.0],
    [1.0,1.0,1.0,1.0,1.0,1.0]
]])

batch_size,num_tokens,d_in=x.shape

In [18]:
#random matrices for wq,wk,wv
torch.manual_seed(0)

Wq=torch.randn(d_in,d_in)
Wk=torch.randn(d_in,d_in)
Wv=torch.randn(d_in,d_in)

key=x@Wk
query=x@Wq
value=x@Wv


print("Query:\n",query)
print("Key:\n",key)
print("Value:\n",value)

print("x shape:",x.shape)
print("Wq shape:",Wq.shape)
print("Wk shape:",Wk.shape)
print("Wv shape:",Wv.shape)
print("query shape:",query.shape)
print("key shape:",key.shape)
print("value shape:",value.shape)

Query:
 tensor([[[ -9.0244, -11.7287,  15.5360,  -1.4474,  -4.5326,   9.4674],
         [ -8.0564, -13.2309,   8.2228,  -8.9680,   3.1995,   4.8321],
         [ -2.4401,  -3.5657,   3.3941,  -1.4879,  -0.1904,   2.0428]]])
Key:
 tensor([[[  8.2602,  14.1116,  -5.0345, -16.4865,  -2.9948,   8.3139],
         [ -6.1188,  -0.1587,  -5.0885, -14.3014,   4.9540,   5.6093],
         [  0.3059,   1.9933,  -1.4461,  -4.3983,   0.2799,   1.9890]]])
Value:
 tensor([[[ 0.5076, -3.4353,  1.8576,  2.8041,  8.9427, 13.1841],
         [-1.9113, -3.6934,  1.8502,  1.7622,  1.6981,  3.0978],
         [-0.2005, -1.0184,  0.5297,  0.6523,  1.5201,  2.3260]]])
x shape: torch.Size([1, 3, 6])
Wq shape: torch.Size([6, 6])
Wk shape: torch.Size([6, 6])
Wv shape: torch.Size([6, 6])
query shape: torch.Size([1, 3, 6])
key shape: torch.Size([1, 3, 6])
value shape: torch.Size([1, 3, 6])


In [19]:
num_heads=2
head_dim=3

query=query.view(1,3,num_heads,head_dim)
key=key.view(1,3,num_heads,head_dim)
value=value.view(1,3,num_heads,head_dim)

#(batch_size,no_of-tokens,no_of_heads,head_dim)
print("Query after unrolling:",query)
print("\nQuery shape after unrolling:",query.shape)
print("Key after unrolling:",key)
print("\nKey shape after unrolling:",key.shape)
print("Value after unrolling:",value)
print("\nValue shape after unrolling:",value.shape)
#here the values are arranged for each token with 2 heads of 3 dimensions

Query after unrolling: tensor([[[[ -9.0244, -11.7287,  15.5360],
          [ -1.4474,  -4.5326,   9.4674]],

         [[ -8.0564, -13.2309,   8.2228],
          [ -8.9680,   3.1995,   4.8321]],

         [[ -2.4401,  -3.5657,   3.3941],
          [ -1.4879,  -0.1904,   2.0428]]]])

Query shape after unrolling: torch.Size([1, 3, 2, 3])
Key after unrolling: tensor([[[[  8.2602,  14.1116,  -5.0345],
          [-16.4865,  -2.9948,   8.3139]],

         [[ -6.1188,  -0.1587,  -5.0885],
          [-14.3014,   4.9540,   5.6093]],

         [[  0.3059,   1.9933,  -1.4461],
          [ -4.3983,   0.2799,   1.9890]]]])

Key shape after unrolling: torch.Size([1, 3, 2, 3])
Value after unrolling: tensor([[[[ 0.5076, -3.4353,  1.8576],
          [ 2.8041,  8.9427, 13.1841]],

         [[-1.9113, -3.6934,  1.8502],
          [ 1.7622,  1.6981,  3.0978]],

         [[-0.2005, -1.0184,  0.5297],
          [ 0.6523,  1.5201,  2.3260]]]])

Value shape after unrolling: torch.Size([1, 3, 2, 3])


In [20]:
#grouping based on heads

#(batch_size,no_of-tokens,no_of_heads,head_dim)
#dim-1(no_of_tokens),dim-2(no_of_heads) before transposing
query=query.transpose(1,2)
value=value.transpose(1,2)
key=key.transpose(1,2)

#After transposing
#(batch_size,no_of_heads,no_of-tokens,head_dim)

print("Query after grouped by heads:",query)
print("\nQuery shape after grouped by heads:",query.shape)
print("Key after grouped by heads:",key)
print("\nKey shape after grouped by heads:",key.shape)
print("Value after grouped by heads:",value)
print("\nValue shape after grouped by heads:",value.shape)


Query after grouped by heads: tensor([[[[ -9.0244, -11.7287,  15.5360],
          [ -8.0564, -13.2309,   8.2228],
          [ -2.4401,  -3.5657,   3.3941]],

         [[ -1.4474,  -4.5326,   9.4674],
          [ -8.9680,   3.1995,   4.8321],
          [ -1.4879,  -0.1904,   2.0428]]]])

Query shape after grouped by heads: torch.Size([1, 2, 3, 3])
Key after grouped by heads: tensor([[[[  8.2602,  14.1116,  -5.0345],
          [ -6.1188,  -0.1587,  -5.0885],
          [  0.3059,   1.9933,  -1.4461]],

         [[-16.4865,  -2.9948,   8.3139],
          [-14.3014,   4.9540,   5.6093],
          [ -4.3983,   0.2799,   1.9890]]]])

Key shape after grouped by heads: torch.Size([1, 2, 3, 3])
Value after grouped by heads: tensor([[[[ 0.5076, -3.4353,  1.8576],
          [-1.9113, -3.6934,  1.8502],
          [-0.2005, -1.0184,  0.5297]],

         [[ 2.8041,  8.9427, 13.1841],
          [ 1.7622,  1.6981,  3.0978],
          [ 0.6523,  1.5201,  2.3260]]]])

Value shape after grouped by heads: 

In [22]:
#Finding attention scores

#transpose last 2 dimensions of key

#before transpose:#(batch_size,no_of_heads,no_of-tokens,head_dim)
key_t=key.transpose(2,3)
#after transpose:#(batch_size,no_of_heads,head_dim,no_of-tokens)

attn_scores=query@key_t
print("Attention scores:",attn_scores)
#(batch_size,no_of_heads,no_of-tokens,no_of-tokens)- tells how much each token related with all tokens for each head
print("Attention scores shape:",attn_scores.shape)

Attention scores: tensor([[[[-318.2692,  -21.9748,  -48.6063],
          [-294.6535,    9.5538,  -40.7285],
          [ -87.5604,   -1.7744,  -12.7621]],

         [[ 116.1476,   51.3506,   23.9283],
          [ 178.4425,  171.2106,   49.9505],
          [  42.0843,   31.7945,   10.5541]]]])
Attention scores shape: torch.Size([1, 2, 3, 3])


In [25]:
#apply mask
seq_len=x.shape[1]#no of tokens

mask=torch.triu(torch.ones(seq_len,seq_len),diagonal=1).bool()
print("Casual mask:\n",mask)

attn_scores.masked_fill_(mask,-torch.inf)
print("Attention scores after masking:",attn_scores)

Casual mask:
 tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
Attention scores after masking: tensor([[[[-318.2692,      -inf,      -inf],
          [-294.6535,    9.5538,      -inf],
          [ -87.5604,   -1.7744,  -12.7621]],

         [[ 116.1476,      -inf,      -inf],
          [ 178.4425,  171.2106,      -inf],
          [  42.0843,   31.7945,   10.5541]]]])


In [27]:
#scaling by square root of key dimension and apply softmax
torch.set_printoptions(precision=3,sci_mode=False)
head_dim=3 #head dimension for key
attn_weights=torch.softmax(attn_scores/head_dim**0.5,dim=-1)
print("Attention weights shape:",attn_weights.shape)
print("Attention weights :",attn_weights)

Attention weights shape: torch.Size([1, 2, 3, 3])
Attention weights : tensor([[[[1.000, 0.000, 0.000],
          [0.000, 1.000, 0.000],
          [0.000, 0.998, 0.002]],

         [[1.000, 0.000, 0.000],
          [0.985, 0.015, 0.000],
          [0.997, 0.003, 0.000]]]])


In [28]:
#Apply dropout
dropout=torch.nn.Dropout(0.1)
attn_weights=dropout(attn_weights)
print("Attention weight after dropout:",attn_weights)

Attention weight after dropout: tensor([[[[1.111, 0.000, 0.000],
          [0.000, 1.111, 0.000],
          [0.000, 1.109, 0.002]],

         [[1.111, 0.000, 0.000],
          [1.094, 0.017, 0.000],
          [1.108, 0.003, 0.000]]]])


In [29]:
#find context vector
context_vec=attn_weights@value
print("Context vector:",context_vec)
print("Context vector shape:",context_vec.shape)
#(batch_size,no_of_heads,no_of_tokens,head_dim)

Context vector: tensor([[[[ 0.564, -3.817,  2.064],
          [-2.124, -4.104,  2.056],
          [-2.120, -4.099,  2.053]],

         [[ 3.116,  9.936, 14.649],
          [ 3.098,  9.814, 14.479],
          [ 3.113,  9.915, 14.620]]]])
Context vector shape: torch.Size([1, 2, 3, 3])


In [31]:
#again swap dimension 1,2
#before swap:(batch_size,no_of_heads,no_of_tokens,head_dim)
context_vec=context_vec.transpose(1,2)
#after swap:(batch_size,no_of_tokens,no_of_heads,head_dim)
print("Context vector after swapping:",context_vec)
print("Context vector shape after swapping:",context_vec.shape)

Context vector after swapping: tensor([[[[ 0.564, -3.817,  2.064],
          [-2.124, -4.104,  2.056],
          [-2.120, -4.099,  2.053]],

         [[ 3.116,  9.936, 14.649],
          [ 3.098,  9.814, 14.479],
          [ 3.113,  9.915, 14.620]]]])
Context vector shape after swapping: torch.Size([1, 2, 3, 3])


In [32]:
#concatenating heads
context_vec=context_vec.reshape(batch_size,seq_len,num_heads*head_dim)
print("Context vector after concatenating heads:",context_vec)
print("Context vector shape after concatenating heads:",context_vec.shape)

Context vector after concatenating heads: tensor([[[ 0.564, -3.817,  2.064, -2.124, -4.104,  2.056],
         [-2.120, -4.099,  2.053,  3.116,  9.936, 14.649],
         [ 3.098,  9.814, 14.479,  3.113,  9.915, 14.620]]])
Context vector shape after concatenating heads: torch.Size([1, 3, 6])


In [43]:
 #Multi-Head Attention Class

class MultiHeadAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,num_heads,dropout,qkv_bias=False):
    super().__init__()

    assert (d_out%num_heads==0),"d_out must be divisible by num_heads"

    self.d_out=d_out
    self.num_heads=num_heads
    self.head_dim=d_out//num_heads

    self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)

    self.dropout=nn.Dropout(dropout)
    self.out_proj=nn.Linear(d_out,d_out)
    self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))

  def forward(self,x):
      batch_size,num_of_tokens,d_in=x.shape

      key=self.W_key(x)
      query=self.W_query(x)
      value=self.W_value(x)

      # (batch, tokens, d_out) → (batch, tokens, num_heads, head_dim)
      key = key.view(batch_size, num_of_tokens, self.num_heads, self.head_dim)
      query = query.view(batch_size, num_of_tokens, self.num_heads, self.head_dim)
      value = value.view(batch_size, num_of_tokens, self.num_heads, self.head_dim)

      #arrange key,query,value based on no of heads
      #(batch, tokens, num_heads, head_dim) → (batch, num_heads, tokens, head_dim)
      key=key.transpose(1,2)
      value=value.transpose(1,2)
      query=query.transpose(1,2)

      attn_scores=query@key.transpose(2,3)

      mask_bool=self.mask.bool()[:num_of_tokens,:num_of_tokens]

      attn_scores.masked_fill_(mask_bool,-torch.inf)

      attn_weights=torch.softmax(attn_scores/key.shape[-1]**0.5,dim=-1)

      attn_weights=self.dropout(attn_weights)

      context_vec=attn_weights@value

      #reshape and conactenating heads
      context_vec=context_vec.transpose(1,2)

      context_vec=context_vec.contiguous().view(batch_size,num_of_tokens,self.d_out)
      context_vec=self.out_proj(context_vec)

      return context_vec



In [44]:
torch.manual_seed(123)

inputs= torch.tensor([
    [0.12, 0.87, 0.45, 0.33, 0.91, 0.58],
    [0.76, 0.24, 0.69, 0.11, 0.54, 0.82],
    [0.39, 0.95, 0.18, 0.67, 0.44, 0.06]
])

batch=torch.stack((inputs,inputs),dim=0)
print(batch.shape)

batch_size,context_length,d_in=batch.shape
d_out=6
mha=MultiHeadAttention(d_in,d_out,context_length,2,0.0)
context_vecs=mha(batch)
print(context_vecs)
print(context_vecs.shape)


torch.Size([2, 3, 6])
tensor([[[ 0.079, -0.106,  0.050,  0.058, -0.425, -0.357],
         [ 0.067, -0.036, -0.002, -0.024, -0.330, -0.274],
         [ 0.065, -0.079,  0.022, -0.008, -0.360, -0.334]],

        [[ 0.079, -0.106,  0.050,  0.058, -0.425, -0.357],
         [ 0.067, -0.036, -0.002, -0.024, -0.330, -0.274],
         [ 0.065, -0.079,  0.022, -0.008, -0.360, -0.334]]],
       grad_fn=<ViewBackward0>)
torch.Size([2, 3, 6])
