<a href="https://colab.research.google.com/github/CuriousCaliBoi/AgniKai/blob/main/Transformer_Circuits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# === Mini Transformer with Attribution Graphs (Phase 1) ===
# Phase 1: Minimal Transformer Block (decoder-only) in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniTransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a GELU MLP.

    Both sub-layers are wrapped in a residual connection, with LayerNorm
    applied to the sub-layer *input* (pre-norm).

    Args:
        d_model: Width of the residual stream (size of the last input dim).
        n_heads: Number of attention heads, forwarded to
            ``nn.MultiheadAttention``.
        causal: When True (the default), a position may attend only to itself
            and to earlier positions, which is what a decoder-only model
            needs. Pass False for unmasked (bidirectional) attention.
    """

    def __init__(self, d_model, n_heads, *, causal=True):
        super().__init__()
        self.causal = causal
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        """Apply the block to ``x``.

        Args:
            x: Input whose last two dims are (seq_len, d_model); the attention
                layer is configured with ``batch_first=True``.

        Returns:
            A ``(output, attn_weights)`` tuple. ``output`` has the same shape
            as ``x``. ``attn_weights`` is whatever ``nn.MultiheadAttention``
            returns with its default settings (weights averaged over heads).
        """
        attn_mask = None
        if self.causal:
            # Boolean mask: True marks (query, key) pairs that are NOT allowed
            # to attend, i.e. every key strictly to the right of the query.
            seq_len = x.size(-2)
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )

        # Self-attention block
        attn_input = self.ln1(x)
        attn_output, attn_weights = self.attn(
            attn_input, attn_input, attn_input, attn_mask=attn_mask
        )
        x = x + attn_output

        # Feedforward block
        ff_input = self.ln2(x)
        x = x + self.ff(ff_input)

        return x, attn_weights


class MiniTransformer(nn.Module):
    """Small transformer: token + learned position embeddings, a stack of
    ``MiniTransformerBlock`` layers, a final LayerNorm and a vocabulary head.

    Args:
        vocab_size: Number of distinct token IDs; also the logit dimension.
        d_model: Width of the residual stream.
        n_heads: Attention heads per block.
        n_layers: Number of stacked blocks.
        max_len: Longest sequence the learned position table can cover.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=2, max_len=64):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.blocks = nn.ModuleList([
            MiniTransformerBlock(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        """Compute next-token logits for a batch of token IDs.

        Args:
            idx: Integer tensor of shape (batch, seq_len).

        Returns:
            A ``(logits, all_attn)`` tuple: ``logits`` has shape
            (batch, seq_len, vocab_size) and ``all_attn`` is a list with the
            attention weights returned by each block, in layer order.

        Raises:
            ValueError: If ``idx`` is not 2-D, or if ``seq_len`` exceeds the
                size of the position-embedding table.
        """
        if idx.dim() != 2:
            raise ValueError(
                f"idx must have shape (batch, seq_len), got {tuple(idx.shape)}"
            )
        seq_len = idx.shape[1]
        max_len = self.pos_embed.shape[1]
        if seq_len > max_len:
            # Slicing pos_embed past its end would silently truncate and then
            # fail with an opaque broadcasting error; fail early and clearly.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        x = self.token_embed(idx) + self.pos_embed[:, :seq_len, :]
        all_attn = []

        for block in self.blocks:
            x, attn = block(x)
            all_attn.append(attn)

        x = self.ln_f(x)
        logits = self.output_proj(x)
        return logits, all_attn


# === Example Usage ===
# Instantiate the model with its default hyperparameters on a toy vocabulary.
vocab_size = 100
model = MiniTransformer(vocab_size)

# Dummy batch of token IDs: 2 sequences, 10 tokens each
x = torch.randint(low=0, high=vocab_size, size=(2, 10))
logits, attn_weights = model(x)

# One attention-weight tensor is returned per layer; report each one's shape.
attn_shapes = [layer_attn.shape for layer_attn in attn_weights]
print("Logits shape:", logits.shape)
print("Attention weights shape:", attn_shapes)