In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Toy input: batch_size=1 (one sentence), sequence_length=4 (4 words),
# embedding_dim=8 (each word is represented by 8 values).
SEED = 0
torch.manual_seed(SEED)  # seed the global torch RNG so the random input is reproducible across runs
x = torch.rand(1, 4, 8)  # (B, T, D) - batch size, time steps / sequence length / words, embedding dim

In [8]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with separate linear
    layers, splits each into ``num_heads`` heads of size
    ``embed_dim // num_heads``, computes softmax(Q K^T / sqrt(head_dim)) V per
    head, concatenates the heads and applies a final linear projection.

    Args:
        embed_dim: size of the last dimension of the input and of the output.
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # layers to learn Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # output projection, applied to the concatenated heads
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, B, T):
        """Reshape a (B, T, embed_dim) projection into (B, num_heads, T, head_dim)."""
        return t.reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x, attn_mask=None, is_causal=False):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, embed_dim).
            attn_mask: optional mask broadcastable to (B, num_heads, T, T),
                e.g. (T, T) or (B, 1, 1, T) for key padding. A bool mask marks
                positions that MAY be attended with True (same convention as
                ``F.scaled_dot_product_attention``). A float mask is added to
                the attention scores before the softmax.
            is_causal: if True, position i may only attend to positions <= i.

        Returns:
            Tensor of shape (B, T, embed_dim).

        Note:
            A query row whose keys are all masked out produces NaN, because
            softmax over a row of -inf is undefined.
        """
        B, T, D = x.shape

        q = self._split_heads(self.q_proj(x), B, T)
        k = self._split_heads(self.k_proj(x), B, T)
        v = self._split_heads(self.v_proj(x), B, T)

        # (B, H, T, T): every query position scored against every key position,
        # scaled by 1/sqrt(head_dim) as in scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5

        if is_causal:
            keep = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
            scores = scores.masked_fill(~keep, float("-inf"))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # (B, H, T, head_dim)

        # merge heads: (B, H, T, head_dim) -> (B, T, H, head_dim) -> (B, T, D)
        out = out.transpose(1, 2).contiguous().reshape(B, T, D)
        return self.out_proj(out)


In [9]:
attn = MultiHeadSelfAttention(embed_dim=8, num_heads=2)
out = attn(x)
# self-attention maps (B, T, D) -> (B, T, D); fail loudly if that ever breaks
assert out.shape == x.shape
out.shape


torch.Size([1, 4, 8])
