In [2]:
#Extending Single head attention to Multi-head attention

#in practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then
#combining their outputs

In [1]:
from torch import nn
import torch



In [2]:
class CausalAttention(nn.Module):
    def __init__(self, d_in,d_out,dropout,context_length, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)        
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
        
    def forward(self, X):
        b, num_tokens, d_in = X.shape
        keys = self.W_key(X)
        queries = self.W_query(X)
        values = self.W_value(X)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(self.mask.bool(),-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vect = attn_weights @ values
        return context_vect

In [3]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, dropout, context_length, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
        [CausalAttention(d_in,d_out,dropout,context_length,qkv_bias) for _ in range(num_heads)]
        )
        
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads],dim=-1)

In [4]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs),dim=0)
batch.shape

torch.Size([2, 6, 3])

In [5]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in,d_out,0.5,context_length,num_heads=2)

context_vecs = mha(batch)
print(context_vecs)
print("Context_vecs.shape: ",context_vecs.shape)

tensor([[[-0.4067, -0.0197,  0.4922,  0.3193],
         [-0.6766, -0.0975,  0.2342,  0.1723],
         [-0.6757, -0.0977,  0.4925,  0.3194],
         [-0.8168, -0.1481,  0.8956,  0.5905],
         [-0.5348, -0.0429,  0.2794,  0.2031],
         [-0.9019, -0.1643,  0.8488,  0.5948]],

        [[-0.7909, -0.1098,  0.2482,  0.2226],
         [-0.1392, -0.0496,  0.4834,  0.3947],
         [-0.2788, -0.1014,  0.6543,  0.4116],
         [-0.3692, -0.1182,  0.4781,  0.3907],
         [-0.9389, -0.1628,  0.7358,  0.4862],
         [-0.6290, -0.1848,  0.2498,  0.2241]]], grad_fn=<CatBackward0>)
Context_vecs.shape:  torch.Size([2, 6, 4])


## Implementing Multi-head attention with weight splits

In [246]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert(d_out % num_heads == 0), \
        "d_out must be divisible by num_heads"
        
        self.d_out =d_out
        self.num_heads = num_heads
        self.head_dim = d_out//num_heads
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)        
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)        
        self.out_proj = nn.Linear(d_out, d_out) #Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length,context_length),diagonal=1)
            
        )
        
    def forward(self,x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        keys = keys.view(b,num_tokens,self.num_heads,self.head_dim)
        values = values.view(b,num_tokens,self.num_heads,self.head_dim)
        queries = queries.view(b,num_tokens,self.num_heads,self.head_dim)        
        
        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)
        
        attn_score = queries @ keys.transpose(2,3)
        
        attn_score = attn_score.masked_fill(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
        attn_weights = torch.softmax(attn_score/keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)
        print(attn_weights)
        context_vect = (attn_weights @ values).transpose(1,2)
        
        context_vect = context_vect.contiguous().view(b,num_tokens,self.d_out)
        context_vect = self.out_proj(context_vect)
        
        return context_vect

In [247]:
x = MultiHeadAttention(3,4,6,0.0,2)

In [248]:
batch.shape

torch.Size([2, 6, 3])

In [249]:
t = x(batch)

t

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4489, 0.5511, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2886, 0.3545, 0.3570, 0.0000, 0.0000, 0.0000],
          [0.2336, 0.2571, 0.2579, 0.2514, 0.0000, 0.0000],
          [0.1744, 0.2040, 0.2051, 0.1967, 0.2198, 0.0000],
          [0.1539, 0.1700, 0.1706, 0.1661, 0.1783, 0.1611]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5095, 0.4905, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3422, 0.3294, 0.3284, 0.0000, 0.0000, 0.0000],
          [0.2530, 0.2454, 0.2451, 0.2565, 0.0000, 0.0000],
          [0.2032, 0.1966, 0.1964, 0.2065, 0.1973, 0.0000],
          [0.1679, 0.1631, 0.1629, 0.1709, 0.1623, 0.1729]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4489, 0.5511, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2886, 0.3545, 0.3570, 0.0000, 0.0000, 0.0000],
          [0.2336, 0.2571, 0.2579, 0.2514, 0.0000, 0.0000],
          [0.1744, 0.2040, 0.2051,

tensor([[[-0.2568,  0.0783,  0.1045, -0.6221],
         [-0.4398,  0.0679,  0.0397, -0.5446],
         [-0.4944,  0.0663,  0.0228, -0.5250],
         [-0.5030,  0.0626, -0.0233, -0.5050],
         [-0.4923,  0.0476, -0.0123, -0.5231],
         [-0.5076,  0.0538, -0.0455, -0.4996]],

        [[-0.2568,  0.0783,  0.1045, -0.6221],
         [-0.4398,  0.0679,  0.0397, -0.5446],
         [-0.4944,  0.0663,  0.0228, -0.5250],
         [-0.5030,  0.0626, -0.0233, -0.5050],
         [-0.4923,  0.0476, -0.0123, -0.5231],
         [-0.5076,  0.0538, -0.0455, -0.4996]]], grad_fn=<ViewBackward0>)

In [241]:
x = nn.Linear(3,2)

In [60]:
x.weight

Parameter containing:
tensor([[-0.3724, -0.0359, -0.1879],
        [-0.5469, -0.5157, -0.4731]], requires_grad=True)