In [2]:
import torch 
import torch.nn as nn

In [23]:
class MultiHeadAttn(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_head`` heads of width ``d_out // num_head``, combined with
    scaled dot-product attention under an upper-triangular (causal) mask, and
    finally passed through an output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output; must be divisible by
            ``num_head``.
        num_head: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        context_length: Maximum number of tokens per sequence; determines the
            size of the precomputed causal mask.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_head``.
    """

    def __init__(self, d_in, d_out, num_head, dropout, context_length):
        super().__init__()
        if d_out % num_head != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_head ({num_head})"
            )

        self.d_out = d_out
        self.w_query = nn.Linear(d_in, d_out, bias=False)
        self.w_key = nn.Linear(d_in, d_out, bias=False)
        self.w_value = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        self.num_head = num_head

        self.dropout = nn.Dropout(p=dropout)

        self.head_dim = d_out // num_head
        # Ones strictly above the diagonal mark the "future" positions a token
        # must not attend to. Registered as a buffer so it follows the module
        # across devices without being a trainable parameter.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_token, d_in)`` with
                ``num_token <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_token, d_out)``.

        Raises:
            ValueError: If ``num_token`` exceeds the ``context_length`` the
                module was built with.
        """
        batch, num_token, _ = x.shape
        context_length = self.mask.shape[0]
        if num_token > context_length:
            raise ValueError(
                f"sequence length {num_token} exceeds context_length {context_length}"
            )

        key = self.w_key(x)
        query = self.w_query(x)
        value = self.w_value(x)

        # (batch, num_token, d_out) -> (batch, num_head, num_token, head_dim)
        key = key.view(batch, num_token, self.num_head, self.head_dim).transpose(1, 2)
        query = query.view(batch, num_token, self.num_head, self.head_dim).transpose(1, 2)
        value = value.view(batch, num_token, self.num_head, self.head_dim).transpose(1, 2)

        # (batch, num_head, num_token, num_token)
        attn_scores = query @ key.transpose(2, 3)

        # Crop the precomputed mask to the actual sequence length so inputs
        # shorter than context_length work; it broadcasts over batch and heads.
        causal_mask = self.mask.bool()[:num_token, :num_token]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(head_dim); masked (-inf) entries get zero weight.
        attn_weight = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)

        attn_weight_dropout = self.dropout(attn_weight)

        # Back to (batch, num_token, num_head, head_dim), then merge the heads.
        context_vec = (attn_weight_dropout @ value).transpose(1, 2)
        context_vec = context_vec.contiguous().view(batch, num_token, self.d_out)

        context_vec = self.out_proj(context_vec)
        return context_vec

        



In [24]:
# Seed PyTorch's RNG so the random input and the layer's initial weights are
# the same on every Restart & Run All.
torch.manual_seed(123)

# Toy batch passed to torch.rand as (2, 3, 5); the layer below is built with
# d_in=5 to match the last dimension and context_length=3 to match the middle one.
x = torch.rand(2, 3, 5)
mha = MultiHeadAttn(d_in=5, d_out=6, num_head=2, dropout=0.5, context_length=3)

In [25]:
# Run the sample batch through the attention layer; the trailing expression
# displays the shape of the result.
y = mha(x)
y.size()

torch.Size([2, 3, 6])