In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ---- 1) Rotary Positional Embedding ----
def build_rope_frequencies(seq_len, head_dim, base=10000):
    """Build the sine and cosine lookup tables used by rotary embeddings.

    Args:
        seq_len: Number of positions to tabulate (rows of each table).
        head_dim: Per-head feature size; one frequency is produced for
            every second index in ``range(head_dim)``.
        base: Base of the geometric progression of frequencies.

    Returns:
        Tuple ``(sin, cos)``. For even ``head_dim`` each table has shape
        ``(seq_len, head_dim / 2)`` and entry ``[p, i]`` is the sine/cosine
        of ``p / base ** (2 * i / head_dim)``.
    """
    positions = torch.arange(seq_len)
    exponents = torch.arange(0, head_dim, 2) / head_dim
    inv_freq = 1.0 / (base ** exponents)
    # Broadcasted outer product: row p holds p * inv_freq.
    theta = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    return theta.sin(), theta.cos()

def apply_rope(x, sin, cos):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    The last dimension of ``x`` is treated as ``D/2`` interleaved pairs
    ``(x[..., 2i], x[..., 2i + 1])``. Each pair is rotated as a 2-D vector
    using the matching entries of ``sin`` and ``cos``.

    Args:
        x: Tensor of shape ``(B, T, H, D)`` where ``D`` is even.
        sin: Sine table of shape ``(T, D/2)``.
        cos: Cosine table of shape ``(T, D/2)``.

    Returns:
        Tensor of shape ``(B, T, H, D)`` with every pair rotated, in the
        same interleaved layout as the input.

    Raises:
        ValueError: If ``x`` is not 4-dimensional or its last dimension
            is odd.
    """
    if x.dim() != 4:
        raise ValueError(
            f"apply_rope expects x of shape (B, T, H, D), got {tuple(x.shape)}"
        )
    if x.shape[-1] % 2 != 0:
        # Without this check the even/odd halves have different widths and
        # the failure surfaces later as an opaque broadcasting error.
        raise ValueError(
            f"apply_rope requires an even last dimension, got D={x.shape[-1]}"
        )

    even = x[..., ::2]   # (B, T, H, D/2)
    odd = x[..., 1::2]   # (B, T, H, D/2)

    # Insert batch and head axes so the tables broadcast against x.
    sin = sin[None, :, None, :]  # (1, T, 1, D/2)
    cos = cos[None, :, None, :]

    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos
    # Stack pairs on a new trailing axis, then flatten to re-interleave them.
    return torch.stack([rot_even, rot_odd], dim=-1).flatten(-2)

# ---- 2) Multi-Head Attention ----
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings applied to Q and K."""

    def __init__(self, d_model, n_heads):
        """Create the Q/K/V/output projections.

        Args:
            d_model: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
        """
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, sin, cos, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Input of shape ``(B, T, d_model)``.
            sin, cos: RoPE tables; only the first ``T`` rows are used.
            mask: Optional tensor broadcastable to ``(B, n_heads, T, T)``;
                positions where ``mask == 0`` receive ``-inf`` scores.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, _ = x.shape
        head_shape = (B, T, self.n_heads, self.d_head)
        rope_sin, rope_cos = sin[:T], cos[:T]

        # Project, split into heads, and rotate queries/keys by position.
        q = apply_rope(self.Wq(x).view(head_shape), rope_sin, rope_cos)
        k = apply_rope(self.Wk(x).view(head_shape), rope_sin, rope_cos)
        v = self.Wv(x).view(head_shape)

        # Scaled dot-product scores, laid out as (B, H, T_query, T_key).
        scores = torch.einsum('bthd,bshd->bhts', q, k) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back into d_model.
        merged = torch.einsum('bhts,bshd->bthd', weights, v).reshape(B, T, self.d_model)
        return self.Wo(merged)

# ---- 3) SwiGLU Feed-Forward ----
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``W_out(silu(W1 x) * W2 x)``."""

    def __init__(self, d_model, d_ff):
        """Create the two input projections and the output projection.

        Args:
            d_model: Input and output feature size.
            d_ff: Hidden (gated) feature size.
        """
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff)     # branch passed through SiLU
        self.W2 = nn.Linear(d_model, d_ff)     # linear branch
        self.W_out = nn.Linear(d_ff, d_model)  # projection back to d_model

    def forward(self, x):
        """Apply the gated feed-forward transform along the last axis of ``x``."""
        gate = F.silu(self.W1(x))
        value = self.W2(x)
        return self.W_out(gate * value)

# ---- 4) Transformer Block ----
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by a SwiGLU MLP."""

    def __init__(self, d_model, n_heads, d_ff):
        """Create the two LayerNorms, the attention module and the MLP.

        Args:
            d_model: Model width.
            n_heads: Number of attention heads.
            d_ff: Hidden size of the SwiGLU feed-forward layer.
        """
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = SwiGLU(d_model, d_ff)

    def forward(self, x, sin, cos, mask=None):
        """Run both residual sub-layers and return a tensor shaped like ``x``.

        ``sin``, ``cos`` and ``mask`` are forwarded unchanged to the
        attention module.
        """
        # Each sub-layer sees a normalized input; its output is added back
        # onto the un-normalized residual stream.
        attn_out = self.attn(self.ln1(x), sin, cos, mask)
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln2(hidden))
        return hidden + mlp_out

# ---- 5) Full tiny LM ----
class TinyTransformerLM(nn.Module):
    """Small causal language model: embedding, transformer blocks, tied unembedding.

    The RoPE sine/cosine tables are registered as non-persistent buffers, so
    they follow the module through ``.to()`` / ``.cuda()`` but are not added
    to ``state_dict``.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, d_ff=256, n_layers=2, max_len=128):
        """Build the model.

        Args:
            vocab_size: Number of token ids (embedding rows / logit columns).
            d_model: Model width.
            n_heads: Attention heads per block.
            d_ff: Hidden size of each block's SwiGLU layer.
            n_layers: Number of transformer blocks.
            max_len: Longest sequence the precomputed RoPE tables cover.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)
        self.unembed.weight = self.embed.weight  # weight tying

        # Buffers (rather than plain attributes) move with the module, so the
        # forward pass needs no per-block device transfer. persistent=False
        # keeps the state_dict keys unchanged.
        sin, cos = build_rope_frequencies(max_len, d_model // n_heads)
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: Integer token ids of shape ``(B, T)`` with ``T <= max_len``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``T`` exceeds the length of the RoPE tables.
        """
        _, T = idx.shape
        max_len = self.sin.shape[0]
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_len={max_len} of the RoPE tables"
            )

        x = self.embed(idx)

        # Causal mask (lower triangular), broadcast over batch and heads.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)

        for block in self.blocks:
            x = block(x, self.sin, self.cos, mask)

        x = self.ln_f(x)
        return self.unembed(x)

# ---- Test ----
# Smoke test: build a model with default hyperparameters and run one forward pass.
vocab_size = 100
model = TinyTransformerLM(vocab_size)

# Random token ids drawn from [0, vocab_size).
idx = torch.randint(0, vocab_size, (2, 10))  # (batch=2, seq=10)
logits = model(idx)
# One logit vector per input position is expected.
print("Logits shape:", logits.shape)  # (2, 10, vocab_size)



Logits shape: torch.Size([2, 10, 100])
