In [None]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from torch.amp import autocast, GradScaler

# ==========================================
# 1. CONFIGURATION (Optimized for 5GB VRAM)
# ==========================================
@dataclass
class GemmaZeroConfig:
    """Hyper-parameters for the small Gemma-style decoder-only model.

    The defaults are sized for training on a GPU with roughly 5 GB of VRAM.
    Either soft-capping value may be set to ``None`` to disable that cap.

    Raises:
        ValueError: (from ``__post_init__``) if the attention-head layout is
            inconsistent, which would otherwise surface later as an opaque
            tensor-shape error inside the model.
    """

    vocab_size: int = 32000      # Reduced from 256k to save VRAM
    hidden_size: int = 768       # Decent size for reasoning
    intermediate_size: int = 2048  # GeGLU expansion
    num_hidden_layers: int = 16  # Deep enough for complex logic
    num_attention_heads: int = 8
    num_key_value_heads: int = 4  # GQA (Grouped Query Attention)
    head_dim: int = 96
    max_position_embeddings: int = 1024  # Context window
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    # tanh soft-cap on attention scores for stability; None disables it.
    attn_logit_softcapping: Optional[float] = 50.0
    # tanh soft-cap on the output logits; None disables it.
    final_logit_softcapping: Optional[float] = 30.0

    def __post_init__(self):
        """Validate the head layout that the attention and RoPE code rely on."""
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                "num_attention_heads and num_key_value_heads must be positive"
            )
        # Each K/V head is shared by an integer number of query heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be a multiple "
                f"of num_key_value_heads ({self.num_key_value_heads})"
            )
        # Rotary embeddings rotate the two halves of each head vector.
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim ({self.head_dim}) must be even for rotary embeddings"
            )

# ==========================================
# 2. LOW-LEVEL MODULES
# ==========================================
class GemmaRMSNorm(nn.Module):
    """Root-mean-square normalisation with Gemma's zero-centred scale.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * (1 + weight)`` in float32 and
    casts the result back to the input dtype. ``weight`` starts at zero, so at
    initialisation the layer is a plain (un-scaled) RMS normalisation. Weight
    decay then pulls the effective scale towards 1 rather than towards 0.

    Args:
        dim: Size of the last (normalised) dimension.
        eps: Added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-centred parameter: the effective per-channel scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Normalise in float32 for stability under fp16/bf16 autocast.
        x_float = x.float()
        variance = x_float.pow(2).mean(-1, keepdim=True)
        x_float = x_float * torch.rsqrt(variance + self.eps)
        # The "+1" belongs on the scale. Adding it to the output instead would
        # shift every normalised feature by a constant.
        return (x_float * (1.0 + self.weight.float())).type_as(x)

class GemmaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin table generator.

    Args:
        dim: Rotary dimension (the per-head size); expected to be even.
        max_position_embeddings: Accepted for API compatibility but unused. The
            tables are built for the requested length on every call.
        base: RoPE base (theta) controlling the frequency spectrum.
        device: Optional device for the inverse-frequency buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Calculate locally first to avoid attribute collision
        inv_freq_tensor = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        # Derived from (dim, base), so it is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq_tensor, persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each of shape ``(seq_len, 2 * len(inv_freq))``.

        Args:
            x: Used for its device and, when ``seq_len`` is None, for its
                second-to-last dimension as the sequence length. This assumes a
                ``(..., seq, head_dim)`` layout; confirm for new callers.
            seq_len: Number of positions. Defaults to ``x.shape[-2]``.
        """
        if seq_len is None:
            # Previously torch.arange(None) raised here despite the None default.
            seq_len = x.shape[-2]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate so the table lines up with rotate_half's two halves.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

def rotate_half(x):
    """Rotate the halves of the last dimension: ``[a, b] -> [-b, a]``.

    For an odd last dimension the first half has ``n // 2`` elements and the
    second half takes the remainder.
    """
    width = x.shape[-1]
    front, back = torch.split(x, [width // 2, width - width // 2], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to a query and a key tensor.

    Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``. ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

# ==========================================
# 3. ATTENTION & MLP (The Brain)
# ==========================================
class GemmaAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE and tanh logit soft-capping.

    Queries use ``num_attention_heads`` heads. Keys and values use
    ``num_key_value_heads`` heads, each shared by
    ``num_attention_heads // num_key_value_heads`` query heads.
    """

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        self.rotary_emb = GemmaRotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        A causal mask is always applied, so position ``i`` attends only to
        positions ``<= i``. ``attention_mask``, if given, is *added* to the
        soft-capped scores. It is presumably an additive mask broadcastable to
        ``(bsz, num_heads, q_len, q_len)``, e.g. 0 / large-negative for padding;
        confirm against callers.

        Returns:
            Tensor of shape ``(bsz, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Expand the shared K/V heads so every query head has a matching K/V head.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Soft-cap: smoothly squash scores into (-cap, cap) to avoid logit blow-up.
        cap = self.config.attn_logit_softcapping
        if cap is not None:
            attn_weights = torch.tanh(attn_weights / cap) * cap

        # Causal mask: forbid attending to future positions (strict upper
        # triangle). Without it a causal-LM loss can read the very token it is
        # asked to predict. Applied after soft-capping so masked scores stay far
        # below any capped score.
        future = torch.triu(
            torch.ones(q_len, q_len, dtype=torch.bool, device=attn_weights.device),
            diagonal=1,
        )
        attn_weights = attn_weights.masked_fill(future, torch.finfo(attn_weights.dtype).min)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the query dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(attn_output)

class GemmaMLP(nn.Module):
    """Gated feed-forward block (GeGLU): ``down(gelu(gate(x)) * up(x))``."""

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        # Projections are created in gate, up, down order.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.gelu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

# ==========================================
# 4. MAIN MODEL ARCHITECTURE
# ==========================================
class GemmaBlock(nn.Module):
    """Pre-norm decoder layer.

    Computes ``h = x + attn(norm1(x))``, then returns ``h + mlp(norm2(h))``.
    """

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(width, eps=eps)
        self.self_attn = GemmaAttention(config)
        self.post_attention_layernorm = GemmaRMSNorm(width, eps=eps)
        self.mlp = GemmaMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), attention_mask=attention_mask
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out

class GemmaZeroModel(nn.Module):
    """Decoder-only Gemma-style language model with tied input/output embeddings.

    ``forward`` embeds the tokens and runs them through ``config.num_hidden_layers``
    decoder blocks and a final RMS norm. It returns logits computed against the
    transposed token-embedding matrix, optionally tanh soft-capped by
    ``config.final_logit_softcapping``.
    """

    def __init__(self, config: GemmaZeroConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # nn.Embedding defaults to N(0, 1). This matrix is reused as the output
        # projection, so that init gives initial logits a std of roughly
        # sqrt(hidden_size) (~28 here). Such logits saturate the final soft-cap
        # and start the loss far above ln(vocab_size). A small std fixes this.
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        self.layers = nn.ModuleList([GemmaBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Input embeddings are multiplied by sqrt(hidden_size) (Gemma convention).
        self.embed_scale = math.sqrt(config.hidden_size)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        """Recompute each layer's activations during backward (training mode only)."""
        self.gradient_checkpointing = True

    def forward(self, input_ids, attention_mask=None):
        """Return logits whose last dimension is ``vocab_size``.

        Args:
            input_ids: Integer token ids; presumably ``(batch, seq)``.
            attention_mask: Passed unchanged to every decoder layer.
        """
        x = self.embed_tokens(input_ids) * self.embed_scale

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Trade recompute for activation memory.
                x = torch.utils.checkpoint.checkpoint(layer, x, attention_mask, use_reentrant=False)
            else:
                x = layer(x, attention_mask)

        x = self.norm(x)

        # Tied weights: project back onto the vocabulary with the embedding matrix.
        logits = torch.matmul(x, self.embed_tokens.weight.t())
        cap = self.config.final_logit_softcapping
        if cap is not None:
            logits = torch.tanh(logits / cap) * cap

        return logits

# ==========================================
# 5. TRAINING LOOP (Memory Optimized)
# ==========================================
def train(num_steps=50, grad_accum_steps=4, batch_size=1, seq_len=1024, max_grad_norm=1.0):
    """Run a short smoke-test training loop on one fixed random token batch.

    fp16 autocast, GradScaler loss scaling and the VRAM report are used only
    when CUDA is available. On CPU the loop runs in full precision.

    Args:
        num_steps: Number of micro-batches (forward/backward passes).
        grad_accum_steps: Micro-batches accumulated per optimizer step. A trailing
            partial group is still applied at the end instead of being discarded.
        batch_size: Sequences per micro-batch.
        seq_len: Tokens per sequence.
        max_grad_norm: Clip the global gradient norm to this value; ``None``
            disables clipping.

    Raises:
        ValueError: If ``grad_accum_steps`` is less than 1.
    """
    if grad_accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    print(f"Running on: {device}")

    config = GemmaZeroConfig()
    model = GemmaZeroModel(config).to(device)

    # Enable memory saving
    model.gradient_checkpointing_enable()

    print(f"Model Created. Params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    # fp16 gradients need loss scaling to avoid underflow; a disabled scaler is a
    # pass-through, so the same code path works on CPU.
    scaler = GradScaler("cuda", enabled=use_amp)

    # Dummy dataset: labels are the inputs themselves, shifted below.
    inputs = torch.randint(0, config.vocab_size, (batch_size, seq_len)).to(device)
    labels = inputs.clone()

    model.train()
    print("Starting Training Loop...")

    optimizer.zero_grad()

    for step in range(num_steps):
        # Mixed precision context (no-op when not on CUDA)
        with autocast("cuda", dtype=torch.float16, enabled=use_amp):
            logits = model(inputs)

            # Causal-LM objective: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = F.cross_entropy(shift_logits.view(-1, config.vocab_size), shift_labels.view(-1))
            loss = loss / grad_accum_steps  # average over the accumulation window

        scaler.scale(loss).backward()

        is_last_step = step + 1 == num_steps
        if (step + 1) % grad_accum_steps == 0 or is_last_step:
            if max_grad_norm is not None:
                # Clip the true gradients, not the loss-scaled ones.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if use_amp:
                mem_text = f"{torch.cuda.memory_allocated() / 1e9:.2f} GB"
            else:
                mem_text = "n/a (CPU)"
            print(f"Step {step+1} | Loss: {loss.item() * grad_accum_steps:.4f} | VRAM Used: {mem_text}")

if __name__ == "__main__":
    # Script / notebook entry point; importing the module does not start training.
    train()

Running on: cuda
Model Created. Params: 128.41M
Starting Training Loop...
Step 4 | Loss: 38.4384 | VRAM Used: 1.65 GB
Step 8 | Loss: 33.9174 | VRAM Used: 1.65 GB
Step 12 | Loss: 29.9557 | VRAM Used: 1.65 GB
Step 16 | Loss: 26.4443 | VRAM Used: 1.65 GB
Step 20 | Loss: 23.4977 | VRAM Used: 1.65 GB
Step 24 | Loss: 21.1215 | VRAM Used: 1.65 GB
Step 28 | Loss: 19.1322 | VRAM Used: 1.65 GB
Step 32 | Loss: 17.3895 | VRAM Used: 1.65 GB
Step 36 | Loss: 15.8175 | VRAM Used: 1.65 GB
Step 40 | Loss: 14.4026 | VRAM Used: 1.65 GB
