In [2]:
import torch
import torch.nn as nn

### Multi-head attention
- One way to implement multi-head attention is to use a nn.Modulelist container and hold multiple self attention modules in it
- or you can ensure the embedding dimension is divisible by num_heads

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.w_query = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_key = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_value = nn.Parameter(torch.randn(d_in, d_out_v))

    def forward(self, x):
        queries = x @ self.w_query
        keyes = x @ self.w_key
        values = x @ self.w_value

        attn_scores = queries @ keyes.T

        attn_weights = torch.softmax((attn_scores / self.d_out_kq ** 0.5), dim=-1)

        context_vec = attn_weights @ values

        return context_vec

In [9]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([m(x) for m in self.heads], dim=-1)

In [19]:
mha = MultiHeadAttentionWrapper(3, 2, 1, 4)

In [20]:
x = torch.randn((6, 3))

In [21]:
out = mha(x)

In [22]:
out.shape

torch.Size([6, 4])

In [102]:
class MHAAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.d_in = d_in
        self.d_out_kq = d_out_kq
        self.d_out_v = d_out_v
        self.num_heads = num_heads
        self.w_query = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_key = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_value = nn.Parameter(torch.randn(d_in, d_out_v))

    def forward(self, x):
        seq_length = x.shape[0]
        queries = x @ self.w_query
        keyes = x @ self.w_key
        values = x @ self.w_value

        queries = queries.view(seq_length, self.num_heads, self.d_out_kq // self.num_heads)
        keyes = keyes.view(seq_length, self.num_heads, self.d_out_kq // self.num_heads)
        values = values.view(seq_length, self.num_heads, self.d_out_v // self.num_heads)

        
        attn_scores = queries @ keyes.transpose(-2, -1)

        print(attn_scores.shape)

        attn_weights = torch.softmax((attn_scores / self.d_out_kq ** 0.5), dim=-1)

        context_vec = attn_weights @ values

        print(context_vec.shape)

        # return context_vec.view(seq_length, self.num_heads * self.d_in)

In [103]:
x = torch.randn(6, 16)

In [104]:
new_mha = MHAAttention(16, 16, 16, 2)

In [107]:
new_mha(x)

torch.Size([6, 2, 2])
torch.Size([6, 2, 8])
