In [2]:
#Extending Single head attention to Multi-head attention

#in practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then
#combining their outputs

In [4]:
from torch import nn
import torch



In [25]:
class CausalAttention(nn.Module):
    def __init__(self, d_in,d_out,dropout,context_length, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)        
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
        
    def forward(self, X):
        b, num_tokens, d_in = X.shape
        keys = self.W_key(X)
        queries = self.W_query(X)
        values = self.W_value(X)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(self.mask.bool(),-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vect = attn_weights @ values
        return context_vect

In [26]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, dropout, context_length, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
        [CausalAttention(d_in,d_out,dropout,context_length,qkv_bias) for _ in range(num_heads)]
        )
        
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads],dim=-1)

In [27]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs),dim=0)
batch.shape

torch.Size([2, 6, 3])

In [40]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in,d_out,0.5,context_length,num_heads=2)

context_vecs = mha(batch)
print(context_vecs)
print("Context_vecs.shape: ",context_vecs.shape)

tensor([[[-0.4067, -0.0197,  0.4922,  0.3193],
         [-0.6766, -0.0975,  0.2342,  0.1723],
         [-0.6757, -0.0977,  0.4925,  0.3194],
         [-0.8168, -0.1481,  0.8956,  0.5905],
         [-0.5348, -0.0429,  0.2794,  0.2031],
         [-0.9019, -0.1643,  0.8488,  0.5948]],

        [[-0.7909, -0.1098,  0.2482,  0.2226],
         [-0.1392, -0.0496,  0.4834,  0.3947],
         [-0.2788, -0.1014,  0.6543,  0.4116],
         [-0.3692, -0.1182,  0.4781,  0.3907],
         [-0.9389, -0.1628,  0.7358,  0.4862],
         [-0.6290, -0.1848,  0.2498,  0.2241]]], grad_fn=<CatBackward0>)
Context_vecs.shape:  torch.Size([2, 6, 4])


In [41]:
multi_head(batch)

tensor([[[-0.3238,  0.1458, -0.2331,  0.3215, -0.0152, -0.1707,  0.1702,
           0.0619,  0.0916,  0.1980],
         [-0.3229,  0.1420, -0.2314,  0.3213, -0.0152, -0.1720,  0.1715,
           0.0584,  0.0893,  0.1920],
         [-0.3228,  0.1418, -0.2314,  0.3213, -0.0152, -0.1718,  0.1714,
           0.0585,  0.0893,  0.1919],
         [-0.3203,  0.1378, -0.2311,  0.3210, -0.0152, -0.1725,  0.1721,
           0.0568,  0.0899,  0.1917],
         [-0.3203,  0.1360, -0.2320,  0.3214, -0.0152, -0.1688,  0.1694,
           0.0604,  0.0895,  0.1901],
         [-0.3211,  0.1401, -0.2308,  0.3209, -0.0151, -0.1742,  0.1732,
           0.0555,  0.0900,  0.1928]],

        [[-0.3238,  0.1458, -0.2331,  0.3215, -0.0152, -0.1707,  0.1702,
           0.0619,  0.0916,  0.1980],
         [-0.3229,  0.1420, -0.2314,  0.3213, -0.0152, -0.1720,  0.1715,
           0.0584,  0.0893,  0.1920],
         [-0.3228,  0.1418, -0.2314,  0.3213, -0.0152, -0.1718,  0.1714,
           0.0585,  0.0893,  0.1919],