**stacking multiple single-head attention layers**

stacking multiple casual attention together to make a multi-head attention

In [1]:
import torch
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # your        (x^1)
    [0.55, 0.87, 0.66],  # journey     (x^2)
    [0.57, 0.85, 0.64],  # starts     (x^3)
    [0.22, 0.58, 0.33],  # with       (x^4)
    [0.77, 0.25, 0.10],  # one        (x^5)
    [0.05, 0.80, 0.55]  # step        (X^6)
])



In [2]:
import torch.nn as nn

class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
                    
                            )
        

    def forward(self, x):
        b, num_tokens, d_in=x.shape
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)

        attn_score=queries@keys.transpose(1,2)
        attn_score.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights=torch.softmax(attn_score/keys.shape[-1]**0.5, dim=-1)
        attn_weights=self.dropout(attn_weights)
        context_vec=attn_weights@values
        return context_vec





In [3]:
batch=torch.stack((inputs, inputs), dim=0)
print(batch)
print(batch.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [4]:
torch.manual_seed(123)
context_length=batch.shape[1]
d_in=batch.shape[2]
d_out=2
ca=CasualAttention(d_in=d_in, d_out=d_out, context_length=context_length,dropout=0.2)



In [5]:
context_vec=ca(batch)
print(context_vec)

tensor([[[-0.5649,  0.2770],
         [-0.7343,  0.0072],
         [-0.7875, -0.0790],
         [-0.7093, -0.1053],
         [-0.6907, -0.1226],
         [-0.4311, -0.0610]],

        [[ 0.0000,  0.0000],
         [-0.7343,  0.0072],
         [-0.4833,  0.0046],
         [-0.7093, -0.1053],
         [-0.6907, -0.1226],
         [-0.5738, -0.1035]]], grad_fn=<UnsafeViewBackward0>)


In [7]:
print(context_vec.shape)

torch.Size([2, 6, 2])


In [6]:
# torh.cat : helps to concate a given sequences of tensors along side a specific dimension
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])


cat_dim0=torch.cat([t1,t1],dim=0)
print(cat_dim0)

cat_dim1=torch.cat([t1,t2], dim=1)
print(cat_dim1)

cat_dim_1 = torch.cat([t1, t2], dim=-1)
print(cat_dim_1)

tensor([[1, 2],
        [3, 4],
        [1, 2],
        [3, 4]])
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])


*calling casual attention multiple times*

In [8]:
# num_heads : total number of casual attention layers we use to 
# make multihead attention

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, 
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads=nn.ModuleList(
            [CasualAttention(d_in=d_in, d_out=d_out, context_length=context_length, dropout=dropout, qkv_bias=qkv_bias)
             for _ in range(num_heads)]
            
        )

    def forward(self,x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
    



*for this multihead attention, we used 2 causal head (num_heads=2), each casual head returns context vector of shape 2 (d_out=2), so for this case, final embedding dimension is 4 (2 d_out or d_outxnum_heads)*

In [9]:
torch.manual_seed(123)
context_length=batch.shape[1]
d_in, d_out=3,2
mhaw=MultiHeadAttentionWrapper(d_in=d_in, d_out=d_out,context_length=context_length,dropout=0.0,num_heads=2)
context_vec=mhaw(batch)


In [10]:
print(context_vec)
print(context_vec.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 4])


**Implementing multi-head attention with weight splits**

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
      

        self.d_out=d_out
        self.num_heads=num_heads
        self.head_dim=d_out//num_heads # reduce projection dim to match desired output dim
        self.W_query=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value=nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj=nn.Linear(d_out, d_in) # use linear layer to combine head output
        self.dropout=nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )


    def forward(self,x):
        b,num_tokens,d_in=x.shape
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)

        # from (b, num_tokens, d_out) to (b, num_token, num_heads, head_dim) unrolling
        keys=keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries=queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values=values.view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)

        # keys (b, num_heads, num_tokens, head_dim) to (b, num_heads, head_dim, num_tokens)
        attn_scores=queries@keys.transpose(2,3)
        mask_bool=self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights=self.dropout(attn_weights)
        context_vec=(attn_weights@values).transpose(1,2)

        context_vec=context_vec.contiguous().view(b,num_tokens, self.d_out) # combining heads: self.d_out=self.num_heads*self.head_dim
        context_vec=self.out_proj(context_vec)
        return context_vec





In [12]:
print(batch.shape)

torch.Size([2, 6, 3])


In [13]:
torch.manual_seed(123)
batch_size, context_length, d_in=batch.shape
d_out=2
mha=MultiHeadAttention(d_in=d_in, d_out=d_out,context_length=context_length,dropout=0.0,num_heads=2)
context_vec=mha(batch)
print(context_vec)
print(context_vec.shape)


tensor([[[-0.1933,  0.0272, -0.2507],
         [-0.2179, -0.0689, -0.4201],
         [-0.2267, -0.0993, -0.4760],
         [-0.2430, -0.0712, -0.4813],
         [-0.2484, -0.0658, -0.4875],
         [-0.2548, -0.0558, -0.4908]],

        [[-0.1933,  0.0272, -0.2507],
         [-0.2179, -0.0689, -0.4201],
         [-0.2267, -0.0993, -0.4760],
         [-0.2430, -0.0712, -0.4813],
         [-0.2484, -0.0658, -0.4875],
         [-0.2548, -0.0558, -0.4908]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 3])
