In [5]:
import torch.nn as nn
import torch

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq

        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized attention weights

        # no mask
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
        )

        # # causal mask
        seq_len = attn_weights.shape[0] # seq_len x seq_len
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
        attn_weights = torch.softmax(masked / d_out_kq**0.5, dim=1)


        context_vec = attn_weights @ values
        return context_vec

In [6]:
# sentence = 'Life is short, eat dessert first'
# dc = {s:i for i,s
#       in enumerate(sorted(sentence.replace(',', '').split()))}

# print(dc)
# sentence_int = torch.tensor(
#     [dc[s] for s in sentence.replace(',', '').split()]
# )
# print(sentence_int)
# vocab_size = 50_000

# torch.manual_seed(123)
# embed = torch.nn.Embedding(vocab_size, 3)
# embedded_sentence = embed(sentence_int).detach()

# print(embedded_sentence)
# print(embedded_sentence.shape)

In [7]:
embedded_sentence = torch.randn((6,3))

# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence).shape)


torch.Size([6, 4])


In [8]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [SelfAttention(d_in, d_out_kq//num_heads, d_out_v//num_heads)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [9]:
torch.manual_seed(123)

block_size = embedded_sentence.shape[1]
mha = MultiHeadAttentionWrapper(
    d_in, d_out_kq, d_out_v, num_heads=2
)

context_vecs = mha(embedded_sentence)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


tensor([[-0.4987, -1.2924, -1.0457, -1.4010],
        [-0.3102, -0.9339, -0.7622, -0.8023],
        [-0.1753, -0.5300, -0.4057, -0.4976],
        [-0.4339, -0.9364, -0.8478, -1.1181],
        [-0.2924, -0.6129, -0.6661, -0.9779],
        [ 0.3994,  0.6958,  1.1481,  2.2647]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])
