In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Multi-Head Attendion
## insted of one set of WQ,WK,WV we use multiple sets of WQ,WK,WV(one set per head) then concatenate them

In [1]:
import numpy as np

In [4]:
#no. of tokens in sequence
sequence_len = 4

#embedding dimension
d_model = 8

#no. of heads
num_heads = 2

d_k = d_model // num_heads

x = np.random.rand(sequence_len, d_model)
print(x)
print("Shape : ",x.shape)

[[0.25956222 0.98277522 0.17384214 0.69557968 0.14521894 0.76165535
  0.36051313 0.01817976]
 [0.62865045 0.11134285 0.81020522 0.63884798 0.23628508 0.54954327
  0.89199386 0.94937864]
 [0.79526722 0.22907666 0.72682005 0.0978701  0.12811503 0.41897402
  0.16488601 0.43543915]
 [0.23793443 0.74085503 0.62276851 0.97132178 0.58277127 0.2088883
  0.99396971 0.70441528]]
Shape :  (4, 8)


## Create a single Q,K,V Matrix for all heads

In [9]:
W_Q = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)

#output weight matrix
W_O = np.random.rand(d_model, d_model)

print(W_Q.shape, W_O.shape)

(8, 8) (8, 8)


In [8]:
Q = x @ W_Q
K = x @ W_K
V = x @ W_V

print("Shape : ",Q.shape)

Shape :  (4, 8)


## Splitting into heads

Before transpose :   
Token-wise grouping:
[  
  Token1 → [Head1, Head2]  
  Token2 → [Head1, Head2]  
  Token3 → [Head1, Head2]  
  Token4 → [Head1, Head2]  
]

after transpose :   
Head-wise grouping:
[  
  Head1 → [Token1, Token2, Token3, Token4]  
  Head2 → [Token1, Token2, Token3, Token4]
]

In [12]:
def split_head(x,num_heads):
    seq_len, d_model = x.shape
    d_k = d_model // num_heads

    return x.reshape(seq_len, num_heads, d_k).transpose(1,0,2)   #2 heads, 4 tokens, 4 features

In [11]:
Q = split_head(Q, num_heads)
K = split_head(K, num_heads)
V = split_head(V, num_heads)

print("Shape : ",Q.shape)

Shape :  (2, 4, 4)


## Scaled dot product attention

In [19]:
def attention(Q,K,V):
    scores = Q @ K.transpose(0,2,1)  # (seq_len, d_k) . (d_k, seq_len)
    scaled_scores = scores/np.sqrt(d_k)

    exp_score = np.exp(scaled_scores - np.max(scaled_scores, axis=1, keepdims=True))
    weights = exp_score / np.sum(exp_score, axis=1, keepdims=True)

    return weights @ V

In [27]:
attention_output = attention(Q,K,V)
print("Attention output : \n", attention_output)
print("Shape : ",attention_output.shape)

Attention output : 
 [[[0.19643224 0.16312995 0.15375897 0.26000008]
  [3.71027252 2.80866658 2.94323139 4.8826525 ]
  [0.21987734 0.18066746 0.17116901 0.2915263 ]
  [4.46956723 3.41398518 3.57380655 5.86625016]]

 [[0.03972046 0.12436724 0.09192507 0.10893019]
  [1.41702041 4.03827238 2.97508027 3.66383149]
  [0.05124659 0.15832554 0.11787337 0.13909673]
  [2.34771594 6.67622039 4.87923675 6.07384234]]]
Shape :  (2, 4, 4)


## Concatenate heads

In [23]:
def combine_heads(x):
    num_heads, seq_len, d_k = x.shape
    return x.transpose(1,0,2).reshape(seq_len, num_heads*d_k)

In [28]:
combined = combine_heads(attention_output)
print("Combined : \n", combined)
print("Shape : ", combined.shape)

Combined : 
 [[0.19643224 0.16312995 0.15375897 0.26000008 0.03972046 0.12436724
  0.09192507 0.10893019]
 [3.71027252 2.80866658 2.94323139 4.8826525  1.41702041 4.03827238
  2.97508027 3.66383149]
 [0.21987734 0.18066746 0.17116901 0.2915263  0.05124659 0.15832554
  0.11787337 0.13909673]
 [4.46956723 3.41398518 3.57380655 5.86625016 2.34771594 6.67622039
  4.87923675 6.07384234]]
Shape :  (4, 8)


## Final output Projection

In [26]:
output = combined @ W_O
print("Output : \n",output)
print("Shape : ",output.shape)

Output : 
 [[ 0.70011774  0.76679881  0.68603718  0.6784066   0.60919351  0.46505364
   0.54045798  0.65922593]
 [16.543151   17.97109289 15.89802806 15.72613439 13.2438327  10.47243305
  13.02129305 14.63928872]
 [ 0.82066264  0.89785018  0.80082514  0.79287402  0.70118878  0.53984805
   0.63610647  0.76229677]
 [23.46416715 25.42645276 22.33298768 22.19191217 17.99286795 14.41894328
  18.62665567 20.13867653]]
Shape :  (4, 8)


## Full Function for complete process

In [38]:
def multi_head_attention(X, num_heads):
    seq_len, d_model = X.shape
    d_k = d_model // num_heads
    
    W_Q = np.random.rand(d_model, d_model)
    W_K = np.random.rand(d_model, d_model)
    W_V = np.random.rand(d_model, d_model)
    W_O = np.random.rand(d_model, d_model)
    print("Shape of WQ,WK,WV,WO : ", W_Q.shape)
    
    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V
    print("\nShape of Q,K,V : ",Q.shape)
    
    Q = split_head(Q, num_heads)
    K = split_head(K, num_heads)
    V = split_head(V, num_heads)
    print("\nShape of Q,K,V after splitting heads : ",Q.shape)
    
    attention_output = attention(Q, K, V)
    print("\nAttention Output : \n", attention_output)
    print("Shape of Attention output : ", attention_output.shape)
    
    combined = combine_heads(attention_output)
    print("\nCombined output : \n",combined)
    print("Shape  : ",combined.shape)
    
    return combined @ W_O

In [39]:
output = multi_head_attention(x, num_heads)
print("\nOutput : \n", output)
print("Shape : ",output.shape)

Shape of WQ,WK,WV,WO :  (8, 8)

Shape of Q,K,V :  (4, 8)

Shape of Q,K,V after splitting heads :  (2, 4, 4)

Attention Output : 
 [[[0.34340241 0.17350628 0.31006168 0.30325394]
  [6.14475502 3.03911896 5.48529832 5.38118869]
  [0.15726633 0.08049152 0.14221146 0.13916164]
  [5.17095368 2.5563167  4.62136053 4.53183084]]

 [[0.07575317 0.08867292 0.11741894 0.09710465]
  [5.03603849 5.95313578 7.51012616 5.80497877]
  [0.17809756 0.20787699 0.27220411 0.222202  ]
  [1.94520672 2.2917526  2.92319917 2.29732966]]]
Shape of Attention output :  (2, 4, 4)

Combined output : 
 [[0.34340241 0.17350628 0.31006168 0.30325394 0.07575317 0.08867292
  0.11741894 0.09710465]
 [6.14475502 3.03911896 5.48529832 5.38118869 5.03603849 5.95313578
  7.51012616 5.80497877]
 [0.15726633 0.08049152 0.14221146 0.13916164 0.17809756 0.20787699
  0.27220411 0.222202  ]
 [5.17095368 2.5563167  4.62136053 4.53183084 1.94520672 2.2917526
  2.92319917 2.29732966]]
Shape  :  (4, 8)

Output : 
 [[ 0.80086817  0.7633