Import required libraries

In [1]:
import torch
import torch.nn as nn

Multi-head attention class

In [2]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    which are split into ``num_heads`` heads of width ``d_out // num_heads``.
    Each head computes scaled dot-product attention in which a token may
    attend only to itself and to earlier tokens. The heads are concatenated
    and passed through a final linear projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width (summed over all heads).
        context_length: Maximum number of tokens the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkvbias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self,d_in,d_out,context_length,dropout=0.5,num_heads=2 ,qkvbias = False):

        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order is kept (query, key, value, out_proj) so that seeded
        # weight initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkvbias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkvbias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkvbias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Non-zero entries mark positions that must be hidden. A causal mask
        # hides the *future*, i.e. the strict upper triangle (column > row).
        # Using the lower triangle here would hide the past and the current
        # token and leave the last rows fully masked (softmax -> NaN).
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the mask size given by
                ``context_length`` at construction time.
        """
        b, num_tokens, _ = x.shape
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length "
                f"{self.mask.shape[0]}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the projection width into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head axis forward: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Crop the mask to the actual sequence length; it broadcasts over the
        # batch and head dimensions.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # -inf scores become exactly zero weight after the softmax.
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before normalising over the key axis.
        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vector = (attn_weights @ values).transpose(1, 2)

        # Concatenate the heads back into a single d_out-wide vector per token.
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        context_vector = self.out_proj(context_vector)

        return context_vector



In [3]:
# Toy input: six tokens, one row per token, each a 3-dimensional embedding.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # token 0
        [0.55, 0.87, 0.66],  # token 1
        [0.57, 0.85, 0.64],  # token 2
        [0.22, 0.58, 0.33],  # token 3
        [0.77, 0.25, 0.10],  # token 4
        [0.05, 0.80, 0.55],  # token 5
    ]
)

In [4]:
# Settings for the demo cells that follow.
d_out = 4                # width of each output vector
dropout = 0.5            # dropout probability
d_in = inputs.size(1)    # embedding size, read from the last axis of `inputs`
x_2 = inputs[1]          # second row of `inputs`

In [5]:
# Stack two copies of `inputs` along a new leading (batch) axis.
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [6]:
# Seed before constructing the module so the randomly initialised weights
# (and the dropout mask drawn in the forward pass) are reproducible.
torch.manual_seed(123)
context_length = batch.size(1)  # tokens per sequence
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.5,
    num_heads=2,
    qkvbias=False,
)
context_vecs = mha(batch)
print(context_vecs.shape)  # Should be (batch_size, num_tokens, d_out)

torch.Size([2, 6, 4])
