In [4]:
import torch.nn as nn
import torch

In [7]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, combined
    with scaled dot-product attention under an upper-triangular (causal) mask,
    and finally passed through an output projection of width ``d_out``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the causal mask, i.e. the longest
            sequence ``forward`` accepts.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_biasing: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float,
                 num_heads: int, qkv_biasing: bool = False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of each individual head.
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_biasing)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_biasing)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_biasing)
        # Mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions a token
        # must not attend to. Stored as a buffer so it follows the module
        # across .to(device) calls without being a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            RuntimeError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape

        # Without this guard a too-long input either fails inside masked_fill_
        # with an opaque broadcast error or, when the mask is 1x1, broadcasts
        # silently and disables causal masking altogether. RuntimeError is what
        # torch itself raises for shape mismatches.
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise RuntimeError(
                f"sequence length ({num_tokens}) exceeds context_length "
                f"({context_length})"
            )

        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # -> (b, num_heads, num_tokens, head_dim) so matmul runs per head.
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # (b, num_heads, num_tokens, num_tokens) raw dot-product scores.
        attn_scores = queries @ keys.transpose(2, 3)

        # Crop the mask to the current sequence length and hide future tokens;
        # -inf becomes a weight of exactly zero after softmax.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before normalising over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vector = (attn_weights @ values).transpose(1, 2)

        # Concatenate the heads; contiguous() is needed before view() after a transpose.
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        context_vector = self.out_proj(context_vector)

        return context_vector





In [8]:
# Hyperparameters
batch_size = 2
seq_len = 6
d_in = 16
d_out = 32
num_heads = 4
dropout = 0.1

# Seed the global RNG so the random input, the layer's initial weights and the
# dropout mask are identical on every Restart & Run All.
torch.manual_seed(123)

# Dummy input (like token embeddings)
x = torch.randn(batch_size, seq_len, d_in)

# Create attention layer
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=seq_len,
    dropout=dropout,
    num_heads=num_heads
)

# Forward pass (a freshly built module is in training mode, so dropout is active)
output = mha(x)

# The layer should map (batch, seq, d_in) -> (batch, seq, d_out).
assert output.shape == (batch_size, seq_len, d_out), output.shape

print("Input shape:", x.shape)
print("Output shape:", output.shape)


Input shape: torch.Size([2, 6, 16])
Output shape: torch.Size([2, 6, 32])
