In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")


## Implementation of multi-head attention

In [2]:
## defining our input variable: a batch of 1 sequence with 3 tokens,
## each token being a d_in = 6 dimensional embedding.
## Seed torch's global RNG first so this tensor (and every later torch.randn
## draw) is identical on each Restart & Run All.
torch.manual_seed(42)
x = torch.randn(1,3,6)
print(x)

tensor([[[ 1.2851,  0.0689,  0.5441, -0.0941, -0.7290,  0.4415],
         [-1.6659,  0.3111, -1.2321,  1.6866, -2.0077,  0.5211],
         [ 0.9060,  1.1238, -0.5217, -0.3128,  1.2894,  1.1534]]])


In [3]:
print(x.size())  # (batch, num_tokens, d_in); .size() returns the same torch.Size as .shape


torch.Size([1, 3, 6])


* Our input embedding tensor has shape [1, 3, 6]: a batch size of 1, three tokens, and an embedding dimension d_in of 6.
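* Next, let's create random query, key, and value weight matrices, each of shape d_in × d_out = 6 × 6 (the per-head size head_dim = d_out / num_heads is computed later inside the `MultiHeadAttention` class).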

In [4]:
## random (untrained) projection matrices, each mapping d_in=6 -> d_out=6;
## drawn in the order query, key, value
w_q, w_k, w_v = (torch.randn(6, 6) for _ in range(3))
print(w_q, w_k, w_v, sep="\n")

tensor([[-0.8967,  0.3720,  0.7535, -2.1120, -0.6774, -0.0399],
        [-1.9073,  1.5266,  0.1816,  0.0970, -0.4176,  0.8444],
        [-0.4699, -1.1347,  0.7681, -0.0721,  0.8910,  0.9507],
        [ 0.3501, -2.5955,  0.7957,  0.8863,  1.1494, -0.7801],
        [-0.7152, -0.1136, -0.8741, -0.4199,  0.6200, -0.0035],
        [-0.0411, -0.0320, -1.9584, -0.4949,  0.2590, -1.3881]])
tensor([[ 0.8658,  0.4797,  1.6069,  0.7115,  0.4459, -0.1898],
        [ 0.6261,  0.5471,  0.4684,  0.5704,  1.2109,  2.0190],
        [ 0.3319,  1.2252, -0.3378, -0.4697, -0.6574, -0.4810],
        [-1.8914, -2.3508,  0.4883,  1.0043,  0.9003, -0.0177],
        [ 0.4347,  0.3509, -0.2219,  1.2907, -0.8218, -2.1762],
        [ 1.0719,  0.1443,  1.2829, -0.3455, -0.6673,  0.4915]])
tensor([[ 1.1864, -0.8665,  0.4948,  1.0562, -0.3337,  0.8361],
        [ 0.4017, -0.4543,  1.1050, -1.5086,  0.6530, -2.5984],
        [ 1.1178, -0.0788, -0.3701, -0.1889, -0.1921,  2.5239],
        [-0.2591,  1.7034, -0.5846,  0

In [5]:
## project the input embeddings into query, key and value spaces
## (x: (1, 3, 6) @ W: (6, 6) -> (1, 3, 6) for each projection)
q, k, v = (x @ weight for weight in (w_q, w_k, w_v))

print("query vector:\n",q)
print("key vector:\n",k)
print("value vector:\n",v)


query vector:
 tensor([[[-1.0691,  0.2787,  1.0963, -2.7425, -0.8602, -0.0129],
         [ 3.4841, -2.9129, -0.0688,  5.7173,  0.7295, -2.8742],
         [-3.7898,  3.2731, -3.1487, -3.1563, -0.8091, -0.9448]]])
key vector:
 tensor([[[ 1.6707e+00,  1.3500e+00,  2.5958e+00, -4.8993e-01,  5.1850e-01,
           1.4385e+00],
         [-5.1605e+00, -6.7326e+00, -1.7743e-01, -1.5066e+00,  3.2645e+00,
           6.1325e+00],
         [ 3.7034e+00,  1.7645e+00,  3.1995e+00,  2.4821e+00, -3.2199e-03,
           1.1429e-01]]])
value vector:
 tensor([[[ 0.1827, -1.8770, -0.7480,  1.4792, -0.4581,  2.1782],
         [-6.8700,  4.0457, -4.2274, -0.7914, -0.1641, -6.5143],
         [-1.8634, -4.2756,  3.2738, -0.3711,  2.0484, -1.5570]]])


In [7]:
class MultiHeadAttention(nn.Module):
  """Causal (masked) multi-head self-attention.

  Projects the input into queries, keys and values, splits the last dimension
  into `num_heads` heads of size `head_dim = d_out // num_heads`, runs scaled
  dot-product attention with a causal mask (each token attends only to itself
  and earlier tokens), then concatenates the heads and applies a final linear
  projection.

  Args:
    d_in: size of the last dimension of the input.
    d_out: size of the output embedding; must be divisible by `num_heads`.
    context_length: maximum sequence length supported by the causal mask.
    dropout: dropout probability applied to the attention weights.
    num_heads: number of attention heads.
    qkv_bias: whether the query/key/value projections use a bias term.
  """

  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), \
       "d_out must be divisible by num_heads"
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out)  # linear layer to combine head outputs
    self.dropout = nn.Dropout(dropout)
    # Causal mask: ones strictly ABOVE the diagonal mark "future" key positions
    # that must be hidden. diagonal=1 keeps the diagonal at 0 so every token can
    # still attend to itself.
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1)
    )

  def forward(self,x):
    """Apply causal multi-head self-attention.

    Args:
      x: tensor of shape (b, num_tokens, d_in). num_tokens must not exceed the
        `context_length` given at construction, because the mask is sliced to
        (num_tokens, num_tokens).

    Returns:
      Tensor of shape (b, num_tokens, d_out).
    """
    b, num_tokens, _ = x.shape
    keys = self.W_key(x)       # (b, num_tokens, d_out)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Split the matrix by adding a `num_heads` dimension (unroll last dim):
    # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    # Group by head:
    # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # Raw attention scores per head: (b, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2, 3)

    # Original mask truncated to the number of tokens and converted to boolean
    # (True = future position to hide).
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

    # Hide future positions; out-of-place so attn_scores is not mutated.
    attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

    # Scale by sqrt(head_dim) and normalise over the key axis.
    attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # Weighted sum of values per head, then move heads back beside head_dim:
    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)

    # Combine heads, where self.d_out = self.num_heads * self.head_dim.
    # contiguous() is needed because transpose() returns a non-contiguous view.
    context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
    context_vec = self.out_proj(context_vec)  # optional projection

    return context_vec