In [2]:
'''=====================================
Multi-Head Attention Module (PyTorch)
====================================='''
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single fused linear layer produces queries, keys and values for all
    heads; each head attends over the sequence independently, then the heads
    are concatenated and passed through an output projection.

    Args:
        embed_dim: size of the last dimension of the input and output.
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Fused projection: one matmul yields Q, K and V for every head.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, embed_dim).
            mask: optional tensor broadcastable to the (B, num_heads, T, T)
                score tensor, e.g. (T, T) for a shared mask or (B, 1, T, T)
                for a per-sample mask. Entries equal to 0 block attention from
                the query position (second-to-last dim) to the key position
                (last dim).

        Returns:
            Tensor of shape (B, T, embed_dim). A query row whose keys are all
            masked gets zero attention weights instead of NaN, so its output
            is just the ``out_proj`` bias.
        """
        B, T, C = x.shape

        qkv = self.qkv_proj(x)  # (B, T, 3*embed_dim)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        # (3, B, heads, T, head_dim) -> three (B, heads, T, head_dim) tensors
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention scores: (B, heads, T, T)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        blocked = None
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        if blocked is not None:
            # softmax over a row that is entirely -inf yields NaN, which would
            # poison every downstream value; give such rows zero weight.
            attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)

        out = attn @ v  # (B, heads, T, head_dim)

        # Merge heads back into the embedding dimension.
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)


In [3]:
'''=====================================
Positional Encoding Module (PyTorch)
====================================='''
class PositionalEncoding(nn.Module):
    """Learned absolute positional embedding added to a (B, T, C) input.

    Position ``t`` receives row ``t`` of a trainable
    ``nn.Embedding(max_len, embed_dim)`` table; the same row is added for
    every sample in the batch.

    Args:
        max_len: number of positions in the table; longer inputs are rejected.
        embed_dim: embedding size; must equal the input's last dimension.
    """

    def __init__(self, max_len, embed_dim):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        """Return ``x`` plus the embedding of positions ``0..T-1``.

        Raises:
            IndexError: if the sequence length exceeds ``max_len``. Checked
                explicitly because the embedding lookup would otherwise fail
                with an opaque message (or a device-side assert on GPU).
        """
        _, seq_len, _ = x.shape  # also enforces a 3-D (B, T, C) input
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_len} of PositionalEncoding"
            )
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, T)
        return x + self.pos_embedding(positions)  # (1, T, C) broadcasts over B


In [4]:
# Smoke test on fake data: batch_size=2, seq_len=5, vocab_size=100.
torch.manual_seed(0)  # reproducible token ids and weight initialisation

batch_size, seq_len, vocab_size = 2, 5, 100
embed_dim = 32
num_heads = 4
max_len = 10  # must be >= seq_len for PositionalEncoding

x = torch.randint(0, vocab_size, (batch_size, seq_len))  # token ids

# Token embedding (for testing) followed by positional encoding
token_emb = nn.Embedding(vocab_size, embed_dim)
pos_enc = PositionalEncoding(max_len, embed_dim)
x_emb = pos_enc(token_emb(x))

# Multi-head attention, unmasked and with a causal (lower-triangular) mask;
# mask entries equal to 0 are blocked.
mha = MultiHeadAttention(embed_dim, num_heads)
out = mha(x_emb)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
out_causal = mha(x_emb, mask=causal_mask)

assert out.shape == out_causal.shape == (batch_size, seq_len, embed_dim)
print(out.shape)


torch.Size([2, 5, 32])
