In [None]:
# reimplementation_pytorch.ipynb

# -----------------------------
# 1. Imports & Config
# -----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# -----------------------------
# 2. Positional Encoding (Sinusoidal)
# -----------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Column ``2i`` of the table holds ``sin(pos / 10000**(2i/dim))`` and
    column ``2i + 1`` holds the matching cosine. The table is precomputed
    for ``max_len`` positions and registered as a buffer named ``pe``
    (shape ``(1, max_len, dim)``), so it is not a trainable parameter.

    Args:
        dim: Embedding width (size of the input's last axis). May be odd.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, dim, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # div[i] = 10000 ** (-2i / dim): one frequency per sin/cos column pair.
        div = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        angles = pos * div  # (max_len, ceil(dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd `dim` there is one fewer odd column than even column,
        # so the last frequency has no cosine partner and is dropped here.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, dim)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose axis 1 is the sequence axis and whose last axis
                has size ``dim``, e.g. (B, T, dim).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]

# -----------------------------
# 3. Multi-Head Attention (no FlashAttention for now)
# -----------------------------
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Model width; the input's last axis must have this size.
        heads: Number of attention heads. ``dim`` must be divisible by it.
        causal: If True, the query at position ``t`` may only attend to keys
            at positions ``<= t`` (autoregressive masking). Defaults to
            False, i.e. every position attends to the whole sequence.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, heads, causal=False):
        super().__init__()
        # Checked up front: otherwise a bad split only surfaces later as an
        # obscure `view` shape error (or ZeroDivisionError for heads == 0).
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        self.causal = causal

        # A single projection produces queries, keys and values side by side.
        self.qkv = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention over the sequence axis of ``x``.

        Args:
            x: Tensor of shape (B, T, dim).

        Returns:
            Tensor of shape (B, T, dim).
        """
        B, T, C = x.size()
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # each (B, T, C)

        # (B, T, C) -> (B, heads, T, head_dim)
        q, k, v = (
            t.view(B, T, self.heads, C // self.heads).transpose(1, 2)
            for t in (q, k, v)
        )

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, T, T)
        if self.causal:
            # True strictly above the diagonal, i.e. keys in the query's future.
            # The diagonal stays unmasked, so no row is entirely -inf.
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
            attn_scores = attn_scores.masked_fill(future, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        out = attn_weights @ v  # (B, heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.out_proj(out)

# -----------------------------
# 4. Transformer Block
# -----------------------------
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention, then a ReLU MLP.

    Each sub-layer reads a LayerNorm-ed view of the residual stream. Its
    output passes through dropout and is added back onto the stream.

    Args:
        dim: Model width.
        heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout rate applied to each sub-layer's output.
    """

    def __init__(self, dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        # Attention sub-layer.
        self.ln1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, heads)
        # Feed-forward sub-layer: expand to ff_dim, ReLU, project back to dim.
        self.ln2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim))
        # One Dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run attention then the MLP, each as a residual update of ``x``."""
        attn_update = self.attn(self.ln1(x))
        h = x + self.dropout(attn_update)
        ff_update = self.ff(self.ln2(h))
        return h + self.dropout(ff_update)

# -----------------------------
# 5. Transformer Language Model
# -----------------------------
class MiniLM(nn.Module):
    """Small transformer mapping token ids to per-position vocabulary logits.

    The model applies these stages in order:

    1. Token embedding.
    2. Sinusoidal positional encoding.
    3. ``depth`` stacked ``TransformerBlock`` layers.
    4. A final LayerNorm.
    5. A linear head projecting onto the vocabulary.

    Args:
        vocab_size: Number of distinct token ids; also the logits' last size.
        dim: Embedding / model width.
        depth: Number of transformer blocks.
        heads: Attention heads per block (``dim`` must be divisible by it).
        ff_dim: Hidden width of each block's feed-forward MLP.
        max_len: Longest sequence covered by the positional-encoding table.
        dropout: Residual dropout rate passed to every block. Defaults to
            0.1, the same value ``TransformerBlock`` uses by default.

    NOTE(review): the blocks build their attention as
    ``MultiHeadSelfAttention(dim, heads)`` with no mask, so each position can
    attend to later tokens. Confirm this is intended if the model is meant
    for next-token prediction.
    """

    def __init__(self, vocab_size, dim=512, depth=6, heads=8, ff_dim=2048, max_len=512,
                 dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_enc = PositionalEncoding(dim, max_len=max_len)
        self.blocks = nn.Sequential(*[
            TransformerBlock(dim, heads, ff_dim, dropout=dropout) for _ in range(depth)
        ])
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, idx):
        """Compute vocabulary logits for every position.

        Args:
            idx: Integer token ids of shape (B, T), with ``T <= max_len``.

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        x = self.token_emb(idx)  # (B, T, dim)
        x = self.pos_enc(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

# -----------------------------
# 6. Inference / Sample Usage
# -----------------------------
vocab_size = 50257
model = MiniLM(vocab_size).to(device)
# Inference: put dropout layers in eval mode.
model.eval()

sample_input = torch.randint(0, vocab_size, (1, 32), device=device)
# Skip recording an autograd graph, since nothing is back-propagated here.
with torch.no_grad():
    logits = model(sample_input)
print("Output shape:", logits.shape)  # (1, 32, vocab_size)
