In [1]:
import torch
import torch.nn as nn

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention layer.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, attended
    with a causal (upper-triangular) mask, merged back to width ``d_out`` and
    passed through a final linear output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Maximum sequence length the causal mask is built for.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out,
                 context_length, num_heads,
                 dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Fixed: the layer class is `nn.Linear` (`nn.linear` does not exist).
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Ones strictly above the diagonal mark the positions to hide.
        # Fixed: a stray `()` previously *called* the tensor returned by
        # `torch.triu`, raising a TypeError at construction time.
        causal_mask = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("mask", causal_mask)

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, x):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, d_in)``.

        Returns:
            Tensor of shape ``(B, T, d_out)``.

        Raises:
            ValueError: If ``T`` exceeds the ``context_length`` the mask was
                built for.
        """
        B, T, _ = x.shape

        max_len = self.mask.shape[0]
        if T > max_len:
            # Without this check the truncated mask fails to broadcast and
            # surfaces as an opaque shape error.
            raise ValueError(
                f"sequence length {T} exceeds context_length {max_len}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # (B, T, d_out) -> (B, T, num_heads, head_dim) -> (B, num_heads, T, head_dim)
        queries = queries.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores per head: (B, num_heads, T, T).
        att_scores = queries @ keys.transpose(2, 3)
        att_scores = att_scores / keys.shape[-1] ** 0.5

        # Masked positions become -inf so softmax assigns them zero weight.
        att_scores.masked_fill_(
            self.mask.bool()[:T, :T], -torch.inf
        )

        att_weights = torch.softmax(att_scores, dim=-1)
        att_weights = self.dropout(att_weights)

        # (B, num_heads, T, head_dim) -> (B, T, num_heads, head_dim)
        context_vec = (att_weights @ values).transpose(1, 2)

        # transpose() leaves a non-contiguous tensor; view() needs contiguous
        # memory to merge the heads back into d_out.
        context_vec = context_vec.contiguous().view(B, T, self.d_out)

        return self.out_proj(context_vec)
        
        
        