<a href="https://colab.research.google.com/github/SAIROHITHARETI/TinyLLM/blob/main/LLAMA_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# =====================================================================
# 1. RMSNorm
# =====================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by its root-mean-square
    (no mean subtraction, no bias) and then multiplied by a learned
    per-feature gain.

    Args:
        dim: Size of the last dimension of the input (number of features).
        eps: Small constant added to the mean square *inside* the square root
            so both the output and its gradient stay finite for all-zero
            inputs.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-feature gain

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned gain."""
        # eps must sit inside the root: d/dm sqrt(m) is infinite at m = 0, so
        # the previous ``sqrt(mean) + eps`` produced NaN gradients for zero
        # vectors. Statistics are computed in float32 for half-precision inputs.
        x_f = x.float()
        inv_rms = torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x_f * inv_rms).type_as(x)


# =====================================================================
# 2. RoPE Helpers
# =====================================================================
def apply_rope(x, cos, sin):
    """Rotate interleaved feature pairs of ``x`` by the given angles.

    Along the last dimension, features (0, 1), (2, 3), ... form 2-D pairs
    (even, odd). Each pair is rotated by the angle whose cosine and sine are
    supplied in ``cos`` / ``sin``. Only as many trailing entries of
    ``cos`` / ``sin`` as there are pairs are used; any extra entries are
    ignored. The result has the same shape and dtype as ``x``.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    n_pairs = evens.size(-1)
    c = cos[..., :n_pairs]
    s = sin[..., :n_pairs]

    rotated_even = evens * c - odds * s
    rotated_odd = evens * s + odds * c

    # Stacking on a fresh trailing axis and flattening it back restores the
    # even/odd interleaving of the input layout.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2).to(x.dtype)


def build_rope(freqs, T, device, start=0):
    """Build broadcastable RoPE cosine/sine tables for ``T`` consecutive positions.

    Args:
        freqs: 1-D tensor of per-pair rotation frequencies.
        T: Number of positions to generate.
        device: Device on which the position index is created.
        start: Absolute position of the first row (default 0). A non-zero
            value lets callers encode tokens that follow ``start`` cached ones.

    Returns:
        ``(cos, sin)``, each of shape ``(1, T, 1, len(freqs))``: angle
        ``pos * freqs[j]`` for ``pos`` in ``start .. start + T - 1``. The
        singleton axes broadcast over batch and heads.
    """
    # Positions are created in the frequency dtype (the old int64 index was
    # promoted to the same type by the outer product anyway).
    positions = torch.arange(start, start + T, device=device, dtype=freqs.dtype)
    angles = torch.outer(positions, freqs)
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    return cos, sin


# =====================================================================
# 3. GQA Attention with KV-Cache
# =====================================================================
class GQA(nn.Module):
    """Grouped-query causal self-attention with rotary embeddings and a KV-cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each K/V
    head is repeated ``n_heads // n_kv_heads`` times to line up with the
    queries.

    Args:
        dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.
    """

    def __init__(self, dim, n_heads, n_kv_heads):
        super().__init__()
        assert dim % n_heads == 0
        # Without this, repeat_interleave below yields the wrong head count and
        # the failure surfaces later as an obscure einsum/reshape error.
        assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, self.head_dim * n_kv_heads)
        self.v_proj = nn.Linear(dim, self.head_dim * n_kv_heads)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, x, cos, sin, kv_cache=None):
        """Attend from the ``T`` new tokens in ``x`` to all cached plus new tokens.

        Args:
            x: Input of shape ``(B, T, dim)``.
            cos, sin: RoPE tables for the new tokens' absolute positions,
                broadcastable against ``(B, T, heads, head_dim // 2)``
                (e.g. as produced by ``build_rope``).
            kv_cache: ``None`` or the ``(k, v)`` pair returned by a previous
                call. Each has shape ``(B, S_past, n_heads, head_dim)``;
                keys are already rotated and heads already expanded.

        Returns:
            ``(output of shape (B, T, dim), (k, v))``. The returned cache
            covers ``S_past + T`` positions.
        """
        B, T, D = x.shape

        # Project q/k/v and split into heads.
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Rotary position encoding on queries and keys only.
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Expand the shared K/V heads to one per query head.
        repeat = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(repeat, dim=2)
        v = v.repeat_interleave(repeat, dim=2)

        # Prepend cached keys/values along the sequence axis.
        if kv_cache is not None:
            prev_k, prev_v = kv_cache
            k = torch.cat([prev_k, k], dim=1)
            v = torch.cat([prev_v, v], dim=1)
        new_cache = (k, v)

        total_len = k.size(1)
        past_len = total_len - T

        att = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(self.head_dim)

        # New query t sits at absolute position past_len + t, so it may see keys
        # 0 .. past_len + t. Masking unconditionally (offset by past_len) keeps
        # multi-token chunks causal even when a cache is supplied. For T == 1
        # the mask is all True; with no cache it is the usual lower triangle.
        allowed = torch.ones(T, total_len, dtype=torch.bool, device=x.device).tril(diagonal=past_len)
        att = att.masked_fill(~allowed, float("-inf"))
        att = torch.softmax(att, dim=-1)

        out = torch.einsum("bhts,bshd->bthd", att, v).reshape(B, T, D)
        return self.o_proj(out), new_cache


# =====================================================================
# 4. Transformer Block
# =====================================================================
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer reads an RMS-normalized copy of the stream and adds its
    output back onto the residual stream.

    Args:
        dim: Model width.
        n_heads: Query heads passed to ``GQA``.
        n_kv_heads: Key/value heads passed to ``GQA``.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
    """

    def __init__(self, dim, n_heads, n_kv_heads, mlp_ratio=4):
        super().__init__()
        # Sub-modules are created in the order norm1, attn, norm2, mlp, so
        # parameter registration and random initialization keep that sequence.
        self.norm1 = RMSNorm(dim)
        self.attn = GQA(dim, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(dim)

        mlp_width = mlp_ratio * dim
        layers = [nn.Linear(dim, mlp_width), nn.GELU(), nn.Linear(mlp_width, dim)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, cos, sin, kv_cache=None):
        """Run both residual sub-layers; return the new stream and the attention cache."""
        attn_delta, cache = self.attn(self.norm1(x), cos, sin, kv_cache)
        h = x + attn_delta
        return h + self.mlp(self.norm2(h)), cache


# =====================================================================
# 5. MiniLLaMA Model with KV-cache Support
# =====================================================================
class MiniLLaMA(nn.Module):
    """Tiny decoder-only language model.

    The pipeline is: token embedding -> ``depth`` TransformerBlocks ->
    RMSNorm -> linear LM head.

    Args:
        vocab_size: Number of token ids.
        dim: Model width; must be divisible by ``n_heads``.
        depth: Number of TransformerBlocks.
        n_heads: Query heads per attention layer.
        n_kv_heads: Key/value heads per attention layer (grouped-query attention).
    """

    def __init__(self, vocab_size, dim=128, depth=3, n_heads=4, n_kv_heads=1):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        head_dim = dim // n_heads
        self.embed = nn.Embedding(vocab_size, dim)

        # RoPE rotates feature pairs *within each head*, so the frequency
        # ladder spans head_dim (one frequency per pair). Building it over the
        # full model width handed each head only the highest frequencies,
        # because apply_rope keeps just the first head_dim // 2 entries.
        # NOTE: this changes the size of the persistent ``freqs`` buffer, so
        # state_dicts saved before this change will not load.
        freqs = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("freqs", freqs)

        self.blocks = nn.ModuleList([
            TransformerBlock(dim, n_heads, n_kv_heads) for _ in range(depth)
        ])

        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, idx, kv_cache=None):
        """Compute next-token logits, optionally continuing from a KV-cache.

        Args:
            idx: Token ids of shape ``(B, T)``.
            kv_cache: ``None``, or the list returned by a previous call: one
                ``(k, v)`` pair per block, each of shape ``(B, S_past, heads, head_dim)``.

        Returns:
            ``(logits of shape (B, T, vocab_size), updated per-block cache list)``.
        """
        B, T = idx.shape
        x = self.embed(idx)

        # Cached tokens occupy positions 0 .. past_len - 1, so the new tokens
        # must be rotated at past_len .. past_len + T - 1. Previously positions
        # restarted at 0 on every call, so cached generation encoded every
        # token as position 0.
        past_len = kv_cache[0][0].size(1) if kv_cache else 0
        cos, sin = build_rope(self.freqs, past_len + T, x.device)
        cos, sin = cos[:, past_len:], sin[:, past_len:]

        new_cache = []
        for i, blk in enumerate(self.blocks):
            block_cache = kv_cache[i] if kv_cache is not None else None
            x, updated = blk(x, cos, sin, block_cache)
            new_cache.append(updated)

        logits = self.lm_head(self.norm(x))
        return logits, new_cache



In [13]:
# =====================================================================
# 6. Tiny Dataset (the alphabet "a".."z" repeated 5 times)
# =====================================================================
text = "abcdefghijklmnopqrstuvwxyz" * 5
chars = sorted(set(text))  # vocabulary: unique characters in sorted order

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> token id
itos = {i: ch for ch, i in stoi.items()}      # token id -> character

vocab_size = len(chars)
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)  # whole corpus as token ids


# =====================================================================
# 7. Mini Training Loop
# =====================================================================
# A freshly initialized model and its Adam optimizer; both are module-level
# names reused by the training and generation cells.
model = MiniLLaMA(vocab_size=vocab_size)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def get_batch(seq_len=8, batch_size=1, source=None):
    """Sample random (input, next-token target) windows from a 1-D token tensor.

    Args:
        seq_len: Window length.
        batch_size: Number of windows to sample (default 1, as before).
        source: 1-D token tensor to sample from; defaults to the module-level
            ``data`` corpus.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, seq_len)``, where ``y`` is
        ``x`` shifted one token ahead.

    Raises:
        ValueError: if ``source`` is too short for a window plus its target.
    """
    if source is None:
        source = data
    # Largest start index that still leaves room for the shifted target.
    # The previous upper bound excluded this index, which made
    # seq_len == len(source) - 1 impossible.
    max_start = len(source) - seq_len - 1
    if max_start < 0:
        raise ValueError(
            f"need at least seq_len + 1 = {seq_len + 1} tokens, got {len(source)}"
        )
    starts = torch.randint(0, max_start + 1, (batch_size,)).tolist()
    x = torch.stack([source[s:s + seq_len] for s in starts])
    y = torch.stack([source[s + 1:s + seq_len + 1] for s in starts])
    return x, y

print("Training...")
for step in range(1000):
    # Reset gradients up front; nothing accumulates until backward() runs.
    optimizer.zero_grad()

    x, y = get_batch()
    logits, _ = model(x)
    # Merge batch and time axes so every position is one classification.
    loss = nn.functional.cross_entropy(logits.flatten(0, 1), y.flatten())
    loss.backward()
    optimizer.step()

    if not step % 100:
        print(f"step {step}: loss={loss.item():.4f}")

Training...
step 0: loss=3.3483
step 100: loss=0.0135
step 200: loss=0.0063
step 300: loss=0.0036
step 400: loss=0.0025
step 500: loss=0.0021
step 600: loss=0.0015
step 700: loss=0.0012
step 800: loss=0.0008
step 900: loss=0.0007


In [18]:
# =====================================================================
# 8. Autoregressive Generation with KV-cache
# =====================================================================
def generate(model, start, length=25):
    """Greedily extend ``start`` by ``length`` characters, reusing the KV-cache.

    The whole prompt is fed once to fill the cache. After that, only the
    newest token is fed each step. Characters are mapped through the
    module-level ``stoi`` / ``itos`` tables.

    Args:
        model: Model whose call ``model(idx, kv_cache)`` returns
            ``(logits, kv_cache)``.
        start: Non-empty prompt string; every character must be in ``stoi``.
        length: Number of characters to generate.

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if ``start`` is empty.
    """
    if not start:
        raise ValueError("start must contain at least one character")
    model.eval()
    idx = torch.tensor([[stoi[ch] for ch in start]], dtype=torch.long)
    kv_cache = None
    # The first step consumes the full prompt. Previously only its last
    # character was fed, so the rest of the prompt was silently ignored.
    step_input = idx

    # Inference only: without no_grad an autograd graph grows through the
    # cached keys/values on every step.
    with torch.no_grad():
        for _ in range(length):
            logits, kv_cache = model(step_input, kv_cache)
            next_id = torch.argmax(logits[:, -1], dim=-1, keepdim=True)  # (B, 1)
            idx = torch.cat([idx, next_id], dim=1)
            step_input = next_id

    return "".join(itos[i] for i in idx[0].tolist())

In [19]:
# Greedy sample from the trained model, seeded with a single character.
print("\nGenerated:")
print(generate(model, start="b", length=25))


Generated:
bcdefghijklmnopqrstuvwxyza
