# 🧠 TRM LoRA Efficiency Study

**One-Click Runnable** notebook for LoRA experiments on TinyRecursiveModels.

## ✅ Alignment with Trelis Implementation
This notebook is aligned with `TinyRecursiveModels-Trelis` (trm.yaml + cfg_pretrain_lora.yaml):

| Parameter | This Notebook | Trelis | Status |
|-----------|---------------|--------|--------|
| `L_cycles` | 4 | 4 | ✅ Aligned |
| `puzzle_emb_len` | 1 | 1 | ✅ Aligned |
| `lr_warmup_steps` | 2000 | 2000 | ✅ Aligned |
| `global_batch_size` | 768 | 768 | ✅ Aligned |
| `H_cycles` | 3 | 3 | ✅ Aligned |
| `L_layers` | 2 | 2 | ✅ Aligned |
| `hidden_size` | 512 | 512 | ✅ Aligned |
| `num_heads` | 8 | 8 | ✅ Aligned |
| `expansion` | 4 | 4 | ✅ Aligned |
| `halt_max_steps` | 16 | 16 | ✅ Aligned |
| `halt_exploration_prob` | 0.1 | 0.1 | ✅ Aligned |
| `lr` | 1e-4 | 1e-4 | ✅ Aligned |
| `weight_decay` | 0.1 | 0.1 | ✅ Aligned |
| `ema_rate` | 0.999 | 0.999 | ✅ Aligned |

**Differences (intentional):**
- Single GPU (Trelis uses multi-GPU)
- AdamW optimizer (Trelis uses Muon)

## LoRA Experiments
| Config | LoRA Rank | Alpha | Train Base |
|--------|-----------|-------|------------|
| baseline | 0 | - | True |
| lora-r1 | 1 | 16 | False |
| lora-r4 | 4 | 16 | False |
| lora-r16 | 16 | 32 | False |
| lora-r4-full | 4 | 16 | True |
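
To make the Rank and Alpha columns concrete, here is a minimal sketch of the two quantities they control: the scaling factor `alpha / rank` applied to the low-rank update, and the number of adapter parameters added to a single linear layer. The helper names `lora_scaling` and `lora_param_count` and the 512 x 512 layer shape (matching `hidden_size` = 512 above) are illustrative assumptions, not part of the notebook.

```python
# Rank/alpha pairs copied from the table above; helper names are hypothetical.
CONFIGS = {
    "lora-r1": (1, 16.0),
    "lora-r4": (4, 16.0),
    "lora-r16": (16, 32.0),
}

def lora_scaling(rank: int, alpha: float) -> float:
    """Multiplier applied to the low-rank update: alpha / rank."""
    return alpha / rank

def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """A has shape (rank, in_features), B has shape (out_features, rank)."""
    return rank * (in_features + out_features)

dense = 512 * 512  # weights in one assumed 512 x 512 projection
for name, (rank, alpha) in CONFIGS.items():
    added = lora_param_count(512, 512, rank)
    print(name, lora_scaling(rank, alpha), added, f"{added / dense:.2%}")
```

Note that lower ranks receive a larger multiplier here: rank 1 is scaled by 16, while rank 16 is scaled by 2.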

In [None]:
# ============================================================================
# Cell 1: Configuration - EDIT THIS
# ============================================================================
import os

# API Keys
WANDB_API_KEY = os.environ.get('WANDB_API_KEY', '')  # Read from the environment, or set directly
HF_TOKEN = ""

# Experiment Selection
SELECTED_CONFIG = 'lora-r1'  # Options: 'baseline', 'lora-r1', 'lora-r4', 'lora-r16', 'lora-r4-full'

# Training Settings (Aligned with Trelis cfg_pretrain_lora.yaml)
TRAIN_EPOCHS = 10000
BATCH_SIZE = 1024  # Aligned: Trelis cfg_pretrain_lora.yaml
EVAL_INTERVAL = 1000

# Dataset Settings
TRAIN_SUBSAMPLE = 1000
NUM_AUGMENT = 1000
FORCE_REBUILD = False

print(f'Selected: {SELECTED_CONFIG}, Epochs: {TRAIN_EPOCHS}, BatchSize: {BATCH_SIZE} (Aligned with Trelis)')

Selected: lora-r1, Epochs: 10000, BatchSize: 1024 (Aligned with Trelis)


In [14]:
# ============================================================================
# Cell 2: Install Dependencies
# ============================================================================
!pip install -q torch einops tqdm numpy pydantic wandb coolname datasets
print('Dependencies installed')

Dependencies installed


In [15]:
# ============================================================================
# Cell 3: Imports & GPU Setup
# ============================================================================
from typing import Optional, Any, Sequence, List, Tuple, Dict
from dataclasses import dataclass
import os, math, json, shutil, copy, time
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset
import numpy as np
from tqdm import tqdm
import wandb
import coolname
import pydantic
from pydantic import BaseModel
import einops
from torch.nn.functional import scaled_dot_product_attention

IGNORE_LABEL_ID = -100

if torch.cuda.is_available():
    gpu = torch.cuda.get_device_name(0)
    print(f'GPU: {gpu}')
    if 'H100' in gpu or 'A100' in gpu:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print('TF32 enabled')
    torch.cuda.set_device(0)

if WANDB_API_KEY:
    wandb.login(key=WANDB_API_KEY)
    print('Wandb logged in')



GPU: NVIDIA A100-SXM4-40GB
TF32 enabled
Wandb logged in


In [16]:
# ============================================================================
# Cell 4: Common Utilities
# ============================================================================
def trunc_normal_init_(tensor, std=1.0, lower=-2.0, upper=2.0):
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a, b = math.erf(lower/sqrt2), math.erf(upper/sqrt2)
            z = (b-a)/2
            c = (2*math.pi)**-0.5
            pdf_u, pdf_l = c*math.exp(-0.5*lower**2), c*math.exp(-0.5*upper**2)
            comp_std = std/math.sqrt(1-(upper*pdf_u-lower*pdf_l)/z-((pdf_u-pdf_l)/z)**2)
            tensor.uniform_(a, b).erfinv_().mul_(sqrt2*comp_std).clip_(lower*comp_std, upper*comp_std)
    return tensor

CosSin = Tuple[torch.Tensor, torch.Tensor]
def _find_multiple(a, b): return (-(a//-b))*b
def rotate_half(x): return torch.cat((-x[..., x.shape[-1]//2:], x[..., :x.shape[-1]//2]), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    orig = q.dtype; q, k = q.to(cos.dtype), k.to(cos.dtype)
    return ((q*cos.unsqueeze(-2))+(rotate_half(q)*sin.unsqueeze(-2))).to(orig), ((k*cos.unsqueeze(-2))+(rotate_half(k)*sin.unsqueeze(-2))).to(orig)
def rms_norm(x, eps):
    dt = x.dtype; x = x.float(); return (x*torch.rsqrt(x.square().mean(-1,keepdim=True)+eps)).to(dt)
print('Utilities loaded')

Utilities loaded


In [17]:
# ============================================================================
# Cell 5: Layers with LoRA Support
# ============================================================================
class CastedLinear(nn.Module):
    def __init__(self, in_f, out_f, bias):
        super().__init__()
        self.weight = nn.Parameter(trunc_normal_init_(torch.empty((out_f, in_f)), std=1.0/(in_f**0.5)))
        self.bias = nn.Parameter(torch.zeros(out_f)) if bias else None
        self._lora_rank = 0; self._lora_alpha = 1.0; self._lora_scaling = 1.0
        self._lora_dropout = None; self._lora_A = None; self._lora_B = None
    def forward(self, x):
        out = F.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
        if self._lora_rank > 0 and self._lora_A is not None:
            lx = self._lora_dropout(x) if self._lora_dropout is not None else x
            out = out + F.linear(F.linear(lx.to(self._lora_A.dtype), self._lora_A), self._lora_B).to(out.dtype) * self._lora_scaling
        return out
    def enable_lora(self, rank, alpha=None, dropout=0.0, train_base=False, train_bias=False):
        if rank <= 0: return
        if self._lora_rank > 0: raise RuntimeError('LoRA already enabled')
        self._lora_rank = rank
        self._lora_alpha = float(alpha if alpha else rank)
        self._lora_scaling = self._lora_alpha / rank
        self._lora_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self._lora_A = nn.Parameter(torch.zeros(rank, self.weight.shape[1]))
        self._lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], rank))
        nn.init.kaiming_uniform_(self._lora_A, a=math.sqrt(5)); nn.init.zeros_(self._lora_B)
        if not train_base: self.weight.requires_grad = False
        if self.bias is not None and not train_bias: self.bias.requires_grad = False

class CastedEmbedding(nn.Module):
    def __init__(self, n, d, std, dtype):
        super().__init__()
        self.cast_to = dtype
        self.embedding_weight = nn.Parameter(trunc_normal_init_(torch.empty(n, d), std=std))
    def forward(self, x): return F.embedding(x, self.embedding_weight.to(self.cast_to))

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_pos, base, device=None):
        super().__init__()
        inv_freq = 1.0/(base**(torch.arange(0, dim, 2, dtype=torch.float32, device=device)/dim))
        freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32, device=device), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
    def forward(self): return self.cos_cached, self.sin_cached

class Attention(nn.Module):
    def __init__(self, hidden, head_dim, heads, kv_heads, causal=False):
        super().__init__()
        self.hidden, self.head_dim, self.heads, self.kv_heads, self.causal = hidden, head_dim, heads, kv_heads, causal
        self.out_size = head_dim * heads
        self.qkv_proj = CastedLinear(hidden, (heads + 2*kv_heads)*head_dim, False)
        self.o_proj = CastedLinear(self.out_size, hidden, False)
    def forward(self, cos_sin, x):
        B, L, _ = x.shape
        qkv = self.qkv_proj(x).view(B, L, self.heads + 2*self.kv_heads, self.head_dim)
        q, k, v = qkv[:,:,:self.heads], qkv[:,:,self.heads:self.heads+self.kv_heads], qkv[:,:,self.heads+self.kv_heads:]
        if cos_sin: q, k = apply_rotary_pos_emb(q, k, *cos_sin)
        q, k, v = (einops.rearrange(t, 'B S H D -> B H S D') for t in (q, k, v))
        out = scaled_dot_product_attention(q, k, v, is_causal=self.causal)
        return self.o_proj(einops.rearrange(out, 'B H S D -> B S H D').reshape(B, L, self.out_size))

class SwiGLU(nn.Module):
    def __init__(self, hidden, expansion):
        super().__init__()
        inter = _find_multiple(round(expansion*hidden*2/3), 256)
        self.gate_up = CastedLinear(hidden, inter*2, False)
        self.down = CastedLinear(inter, hidden, False)
    def forward(self, x):
        g, u = self.gate_up(x).chunk(2, dim=-1)
        return self.down(F.silu(g) * u)

def enable_lora_for_model(model, rank, alpha=None, dropout=0.0, train_base=False, train_bias=False):
    if rank <= 0: return 0
    cnt = 0
    for m in model.modules():
        if isinstance(m, CastedLinear):
            m.enable_lora(rank, alpha, dropout, train_base, train_bias); cnt += 1
    return cnt

def count_parameters(model):
    total = trainable = lora = 0
    for n, p in model.named_parameters():
        total += p.numel()
        if p.requires_grad: trainable += p.numel()
        if '_lora_' in n: lora += p.numel()
    return {'total': total, 'trainable': trainable, 'lora': lora, 'ratio': trainable/total if total else 0}

print('Layers loaded')

Layers loaded
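
The LoRA branch in `CastedLinear.forward` computes `W x + scaling * B (A x)`. A small pure-Python sketch (no PyTorch, tiny hand-picked matrices, hypothetical helper names `matvec`, `matmul`, `lora_forward`) illustrates two properties of that formula: with `B` initialized to zeros the output equals the base layer, and the adapter can be folded into a single merged weight `W + scaling * B A`.

```python
def matvec(M, v):
    """Matrix-vector product on nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def matmul(X, Y):
    """Matrix-matrix product on nested lists."""
    cols = list(zip(*Y))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in X]

def lora_forward(W, A, B, x, scaling):
    """Base output plus the scaled low-rank update."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]

W = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0]]  # (out=2, in=3)
A = [[0.5, -1.0, 2.0]]                   # (rank=1, in=3)
B_zero = [[0.0], [0.0]]                  # (out=2, rank=1), as at initialization
B = [[1.0], [-2.0]]                      # pretend values after some training
scaling = 16.0 / 1                       # alpha / rank for lora-r1
x = [1.0, 2.0, 3.0]

# Zero-initialized B leaves the base output untouched.
out_init = lora_forward(W, A, B_zero, x, scaling)

# Folding the adapter into the weight gives the same result as the two-branch form.
BA = matmul(B, A)
W_merged = [[w + scaling * d for w, d in zip(wr, dr)] for wr, dr in zip(W, BA)]
out_lora = lora_forward(W, A, B, x, scaling)
out_merged = matvec(W_merged, x)
print(out_init, out_lora, out_merged)
```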


In [18]:
# ============================================================================
# Cell 6: TRM Model (Aligned with TinyRecursiveModels-Trelis)
# ============================================================================
@dataclass
class TRMInnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor

@dataclass
class TRMCarry:
    inner: TRMInnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    data: Dict[str, torch.Tensor]

class TRMConfig(BaseModel):
    """Config aligned with Trelis TinyRecursiveModels trm.yaml"""
    batch_size: int; seq_len: int; vocab_size: int; num_puzzle_identifiers: int
    H_cycles: int; L_cycles: int; H_layers: int; L_layers: int
    hidden_size: int; expansion: float; num_heads: int; pos_encodings: str
    rms_norm_eps: float = 1e-5; rope_theta: float = 10000.0
    halt_max_steps: int; halt_exploration_prob: float
    halt_max_steps_eval: Optional[int] = None
    forward_dtype: str = 'bfloat16'
    puzzle_emb_ndim: int = 0; puzzle_emb_len: int = 1  # Aligned: Trelis uses 1
    mlp_t: bool = False; no_ACT_continue: bool = True
    puzzle_emb_dropout: float = 0.0; grid_token_dropout: float = 0.0
    lora_rank: int = 0; lora_alpha: Optional[float] = 1.0  # Optional to handle baseline (None)
    lora_dropout: float = 0.0; lora_train_base: bool = False; lora_train_bias: bool = False

class TRMBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        if cfg.mlp_t:
            self.mlp_t = SwiGLU(cfg.seq_len + cfg.puzzle_emb_len, cfg.expansion)
        else:
            self.attn = Attention(cfg.hidden_size, cfg.hidden_size//cfg.num_heads, cfg.num_heads, cfg.num_heads, False)
        self.mlp = SwiGLU(cfg.hidden_size, cfg.expansion)
        self.eps = cfg.rms_norm_eps
    def forward(self, cos_sin, x):
        if self.cfg.mlp_t:
            x = x.transpose(1,2); x = rms_norm(x + self.mlp_t(x), self.eps); x = x.transpose(1,2)
        else:
            x = rms_norm(x + self.attn(cos_sin, x), self.eps)
        return rms_norm(x + self.mlp(x), self.eps)

class TRMReasoning(nn.Module):
    def __init__(self, layers): super().__init__(); self.layers = nn.ModuleList(layers)
    def forward(self, x, inj, **kw):
        x = x + inj
        for l in self.layers: x = l(hidden_states=x, **kw) if hasattr(l, 'hidden_states') else l(kw.get('cos_sin'), x)
        return x

class TRMInner(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.dtype = getattr(torch, cfg.forward_dtype)
        self.puzzle_emb_len = cfg.puzzle_emb_len
        
        # Embedding scale (aligned with original)
        self.embed_scale = math.sqrt(cfg.hidden_size)
        embed_init_std = 1.0 / self.embed_scale
        
        self.embedding = CastedEmbedding(cfg.vocab_size, cfg.hidden_size, embed_init_std, self.dtype)
        
        if cfg.pos_encodings == 'rope':
            self.rotary = RotaryEmbedding(cfg.hidden_size//cfg.num_heads, cfg.seq_len + self.puzzle_emb_len, cfg.rope_theta)
        else:
            self.pos_emb = CastedEmbedding(cfg.seq_len + self.puzzle_emb_len, cfg.hidden_size, embed_init_std, self.dtype)
        
        self.L_level = TRMReasoning([TRMBlock(cfg) for _ in range(cfg.L_layers)])
        self.lm_head = CastedLinear(cfg.hidden_size, cfg.vocab_size, False)
        self.q_head = CastedLinear(cfg.hidden_size, 2, True)  # bias=True (aligned)
        
        # Initial states as Buffer (aligned with original)
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(cfg.hidden_size, dtype=self.dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(cfg.hidden_size, dtype=self.dtype), std=1), persistent=True)
        
        # Q head special init (aligned with original)
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)
    
    def _input_embeddings(self, inputs):
        """Token embedding with scale."""
        emb = self.embedding(inputs.to(torch.int32))
        # Position embeddings (learned)
        if self.cfg.pos_encodings == 'learned':
            emb = 0.707106781 * (emb + self.pos_emb.embedding_weight.to(self.dtype))
        return self.embed_scale * emb
    
    def empty_carry(self, batch_size):
        return TRMInnerCarry(
            z_H=torch.empty(batch_size, self.cfg.seq_len + self.puzzle_emb_len, self.cfg.hidden_size, dtype=self.dtype),
            z_L=torch.empty(batch_size, self.cfg.seq_len + self.puzzle_emb_len, self.cfg.hidden_size, dtype=self.dtype),
        )
    
    def reset_carry(self, reset_flag, carry):
        return TRMInnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )
    
    def forward(self, carry, batch):
        cs = self.rotary() if hasattr(self, 'rotary') else None
        
        # Input encoding (no puzzle embedding for simplicity in Sudoku)
        # Pad zeros for puzzle_emb_len positions at the beginning
        input_emb = self._input_embeddings(batch['inputs'])
        B = input_emb.shape[0]
        pad = torch.zeros(B, self.puzzle_emb_len, self.cfg.hidden_size, dtype=self.dtype, device=input_emb.device)
        input_embeddings = torch.cat([pad, input_emb], dim=1)
        
        z_H, z_L = carry.z_H, carry.z_L
        
        # Forward iterations (aligned with original: H_cycles-1 no_grad, 1 with grad)
        with torch.no_grad():
            for _ in range(self.cfg.H_cycles - 1):
                for _ in range(self.cfg.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, cos_sin=cs)
                z_H = self.L_level(z_H, z_L, cos_sin=cs)
        
        # Last H cycle with grad
        for _ in range(self.cfg.L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, cos_sin=cs)
        z_H = self.L_level(z_H, z_L, cos_sin=cs)
        
        # Output from z_H (aligned with original)
        new_carry = TRMInnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]  # Remove puzzle_emb positions
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Use first position (aligned)
        
        return new_carry, {'logits': logits, 'q_halt_logits': q_logits[:, 0], 'q_continue_logits': q_logits[:, 1]}

class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    def __init__(self, cfg_dict):
        super().__init__()
        self.config = TRMConfig(**cfg_dict)
        self.inner = TRMInner(self.config)
    
    def initial_carry(self, batch):
        B = batch['inputs'].shape[0]
        return TRMCarry(
            inner=self.inner.empty_carry(B),
            steps=torch.zeros(B, dtype=torch.int32, device=batch['inputs'].device),
            halted=torch.ones(B, dtype=torch.bool, device=batch['inputs'].device),  # Default halted (aligned)
            data={k: torch.empty_like(v) for k, v in batch.items()}
        )
    
    def forward(self, carry, batch, **kw):
        # Reset carry for halted sequences (aligned with original)
        new_inner = self.inner.reset_carry(carry.halted, carry.inner)
        new_steps = torch.where(carry.halted, torch.zeros_like(carry.steps), carry.steps)
        new_data = {k: torch.where(carry.halted.view((-1,) + (1,)*(v.ndim-1)), batch[k], v) for k, v in carry.data.items()}
        
        # Forward inner model
        new_inner, out = self.inner(new_inner, new_data)
        
        with torch.no_grad():
            new_steps = new_steps + 1
            halt_limit = self.config.halt_max_steps if self.training or self.config.halt_max_steps_eval is None else self.config.halt_max_steps_eval
            is_last_step = new_steps >= halt_limit
            halted = is_last_step
            
            if self.training and self.config.halt_max_steps > 1:
                if self.config.no_ACT_continue:
                    halted = halted | (out['q_halt_logits'] > 0)
                else:
                    halted = halted | (out['q_halt_logits'] > out['q_continue_logits'])
                
                # Exploration (aligned with original)
                min_halt_steps = (torch.rand_like(out['q_halt_logits']) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)
        
        return TRMCarry(new_inner, new_steps, halted, new_data), out

print('TRM Model loaded (aligned with Trelis)')

TRM Model loaded (aligned with Trelis)
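
The nested loops in `TRMInner.forward` call the shared `L_level` stack `L_cycles` times to update `z_L` and once to update `z_H` in every H cycle, and only the final H cycle runs with gradients. The sketch below just counts those calls for the values in the alignment table (`H_cycles` = 3, `L_cycles` = 4, `halt_max_steps` = 16). The function name `l_level_calls` is a hypothetical helper for this illustration.

```python
def l_level_calls(h_cycles: int, l_cycles: int):
    """Return (total L_level calls, calls with gradients) for one forward pass."""
    per_h_cycle = l_cycles + 1      # l_cycles z_L updates + 1 z_H update
    total = h_cycles * per_h_cycle
    with_grad = per_h_cycle         # only the last H cycle is outside no_grad
    return total, with_grad

total, with_grad = l_level_calls(h_cycles=3, l_cycles=4)
print(total, with_grad)             # 15 5

# Upper bound when a sequence runs all ACT steps before halting.
max_calls = 16 * total
print(max_calls)                    # 240
```

Each call applies `L_layers` = 2 blocks, so the effective depth is considerably larger than the two-layer parameter count suggests.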


In [19]:
# ============================================================================
# Cell 7: Loss Head (Aligned with Trelis)
# ============================================================================
def s(x, eps=1e-30): return torch.where(x<0, 1/(1-x+eps), x+1)
def log_stablemax(x, dim=-1): sx = s(x); return torch.log(sx/sx.sum(dim=dim, keepdim=True))
def stablemax_ce(logits, labels, ignore_index=-100, valid_mask=None):
    lp = log_stablemax(logits.double(), -1)
    if valid_mask is None: valid_mask = labels != ignore_index
    tl = torch.where(valid_mask, labels, 0)
    plp = torch.gather(lp, index=tl.long().unsqueeze(-1), dim=-1).squeeze(-1)
    return -torch.where(valid_mask, plp, 0.0)

class ACTLossHead(nn.Module):
    def __init__(self, model, loss_type): super().__init__(); self.model = model; self.loss_fn = stablemax_ce
    def initial_carry(self, *a, **k): return self.model.initial_carry(*a, **k)
    def forward(self, return_keys, **kw):
        carry, out = self.model(**kw)
        labels = carry.data['labels']  # Updated field name matches TRMCarry
        
        with torch.no_grad():
            out['preds'] = out['logits'].argmax(-1)
            mask = labels != IGNORE_LABEL_ID
            cnt = mask.sum(-1); div = cnt.clamp_min(1).unsqueeze(-1)
            correct = mask & (out['preds'] == labels)
            seq_correct = correct.sum(-1) == cnt
            
            # Metrics only for halted sequences (aligned with original)
            valid = carry.halted & (cnt > 0)
            metrics = {
                'count': valid.sum(),
                'accuracy': torch.where(valid, (correct.float()/div).sum(-1), 0.0).sum(),
                'exact_accuracy': (valid & seq_correct).sum(),
                'q_halt_accuracy': (valid & ((out['q_halt_logits'] >= 0) == seq_correct)).sum(),  # Added (aligned)
                'steps': torch.where(valid, carry.steps, 0).sum()
            }
        
        # Losses
        lm_loss = (self.loss_fn(out['logits'], labels, valid_mask=mask)/div).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(out['q_halt_logits'], seq_correct.float(), reduction='sum')
        
        # Q continue loss (only if target exists, aligned with original)
        q_continue_loss = 0
        if 'target_q_continue' in out:
            q_continue_loss = F.binary_cross_entropy_with_logits(out['q_continue_logits'], out['target_q_continue'], reduction='sum')
            metrics['q_continue_loss'] = q_continue_loss.detach()
        
        metrics.update({'lm_loss': lm_loss.detach(), 'q_halt_loss': q_halt_loss.detach()})
        
        total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
        return carry, total_loss, metrics, {k: out[k].detach() for k in return_keys if k in out}, carry.halted.all()

print('Loss Head loaded (aligned with Trelis)')

Loss Head loaded (aligned with Trelis)
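
The loss above uses stablemax instead of softmax: each logit passes through `s(x) = x + 1` for `x >= 0` and `1 / (1 - x)` for `x < 0`, and the results are normalized to sum to one. The following scalar sketch (pure Python, omitting the `eps` term of the notebook's `s`, with a hypothetical `stablemax` helper) shows the resulting probabilities and the matching cross-entropy for one label.

```python
import math

def s(x: float) -> float:
    """Piecewise map: linear for x >= 0, reciprocal for x < 0 (eps omitted)."""
    return x + 1.0 if x >= 0 else 1.0 / (1.0 - x)

def stablemax(logits):
    """Normalize s(logit) values so they sum to 1."""
    vals = [s(v) for v in logits]
    total = sum(vals)
    return [v / total for v in vals]

probs = stablemax([2.0, 0.0, -1.0])   # s values are 3.0, 1.0, 0.5
loss = -math.log(probs[0])            # cross-entropy if the label is class 0
print(probs, loss)
```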


In [20]:
# ============================================================================
# Cell 8: Dataset Classes
# ============================================================================
class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int; ignore_label_id: Optional[int]; blank_identifier_id: int
    vocab_size: int; seq_len: int; num_puzzle_identifiers: int
    total_groups: int; mean_puzzle_examples: float; total_puzzles: int; sets: List[str]

class PuzzleDatasetConfig(pydantic.BaseModel):
    seed: int; dataset_paths: List[str]; global_batch_size: int
    test_set_mode: bool; epochs_per_iter: int; rank: int = 0; num_replicas: int = 1

class PuzzleDataset(IterableDataset):
    """Aligned with Trelis puzzle_dataset.py"""
    def __init__(self, cfg, split='train'):
        super().__init__()
        self.cfg, self.split = cfg, split
        with open(os.path.join(cfg.dataset_paths[0], split, 'dataset.json')) as f:
            self.metadata = PuzzleDatasetMetadata(**json.load(f))
        self.local_bs = cfg.global_batch_size // cfg.num_replicas
        self._data = None; self._iters = 0
    def _load(self):
        if self._data: return
        self._data = {}
        for s in self.metadata.sets:
            p = self.cfg.dataset_paths[0]
            self._data[s] = {k: np.load(os.path.join(p, self.split, f'{s}__{k}.npy'), mmap_mode='r' if k in ['inputs','labels'] else None)
                            for k in ['inputs','labels','puzzle_identifiers','puzzle_indices','group_indices']}
            # Aligned: compute puzzle_group_ids for task_identifiers
            gi = self._data[s]['group_indices']
            puzzle_group_ids = np.empty(int(gi[-1]), dtype=np.int32)
            for gid in range(gi.size - 1):
                puzzle_group_ids[int(gi[gid]):int(gi[gid+1])] = gid
            self._data[s]['puzzle_group_ids'] = puzzle_group_ids
    def _collate(self, b):
        """Aligned: include task_identifiers with pad=-1"""
        b = {k: v.astype(np.int32) for k,v in b.items()}
        if self.metadata.ignore_label_id is not None: b['labels'][b['labels']==self.metadata.ignore_label_id] = IGNORE_LABEL_ID
        if b['puzzle_identifiers'].size < self.local_bs:
            pad = self.local_bs - b['puzzle_identifiers'].size
            pv = {'inputs': self.metadata.pad_id, 'labels': IGNORE_LABEL_ID, 'puzzle_identifiers': self.metadata.blank_identifier_id, 'task_identifiers': -1}
            b = {k: np.pad(v, ((0,pad),)+((0,0),)*(v.ndim-1), constant_values=pv[k]) for k,v in b.items()}
        return {k: torch.from_numpy(v) for k,v in b.items()}
    def __iter__(self):
        self._load()
        if self.cfg.test_set_mode: yield from self._test()
        else: yield from self._train()
    def _test(self):
        for sn, d in self._data.items():
            for i in range(0, len(d['inputs']), self.cfg.global_batch_size):
                j = min(len(d['inputs']), i+self.local_bs)
                pid = d['puzzle_identifiers'][i:j]
                pid = pid if pid.ndim==1 else pid[:,0]
                # Aligned: include task_identifiers
                yield sn, self._collate({'inputs': d['inputs'][i:j], 'labels': d['labels'][i:j],
                      'puzzle_identifiers': pid, 'task_identifiers': d['puzzle_group_ids'][pid]}), j-i
    def _train(self):
        for sn, d in self._data.items():
            self._iters += 1
            rng = np.random.default_rng(self.cfg.seed + self._iters)
            gi, pi = d['group_indices'], d['puzzle_indices']
            for _ in range(self.cfg.epochs_per_iter):
                order = rng.permutation(gi.size-1); idx = 0
                while idx < len(order):
                    bi, bp, bg, sz = [], [], [], 0  # Aligned: bg for group ids
                    while idx < len(order) and sz < self.cfg.global_batch_size:
                        gid = order[idx]; pid = rng.integers(gi[gid], gi[gid+1]); idx += 1
                        ps, pe = pi[pid], pi[pid+1]; psz = int(pe-ps)
                        add = min(psz, self.cfg.global_batch_size-sz)
                        bi.append(ps + rng.choice(psz, add, replace=False))
                        bp.append(np.full(add, pid, np.int32))
                        bg.append(np.full(add, gid, np.int32))  # Aligned: task_identifiers
                        sz += add
                    if bi:
                        ii = np.concatenate(bi)
                        # Aligned: include task_identifiers
                        yield sn, self._collate({'inputs': d['inputs'][ii], 'labels': d['labels'][ii], 
                              'puzzle_identifiers': np.concatenate(bp), 'task_identifiers': np.concatenate(bg)}), len(ii)

print('Dataset loaded')

Dataset loaded
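
`PuzzleDataset._load` turns the boundary array `group_indices` into one group id per puzzle (`puzzle_group_ids`). The sketch below shows the same expansion on a plain Python list instead of a NumPy array. The function name `expand_group_ids` and the sample boundaries are made up for illustration.

```python
def expand_group_ids(group_indices):
    """group_indices[g]..group_indices[g+1] is the puzzle range of group g."""
    ids = []
    for gid in range(len(group_indices) - 1):
        start, end = group_indices[gid], group_indices[gid + 1]
        ids.extend([gid] * (end - start))
    return ids

# Three groups holding 2, 3, and 1 puzzles respectively.
group_ids = expand_group_ids([0, 2, 5, 6])
print(group_ids)  # [0, 0, 1, 1, 1, 2]
```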


In [21]:
# ============================================================================
# Cell 9: Training Framework (Single GPU + AdamW, Aligned with Trelis)
# ============================================================================
class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str; loss: LossConfig

class PretrainConfig(pydantic.BaseModel):
    """Aligned with Trelis PretrainConfig (pretrain.py)"""
    arch: ArchConfig; data_paths: List[str]; data_paths_test: List[str] = []
    global_batch_size: int; epochs: int; lr: float; lr_min_ratio: float; lr_warmup_steps: int
    weight_decay: float; beta1: float; beta2: float
    puzzle_emb_ndim: int = 0; puzzle_emb_lr: float = 0.0; puzzle_emb_weight_decay: float = 0.0
    freeze_weights: bool = False; checkpoint_path: Optional[str] = None
    checkpoint_every_eval: bool = False; load_checkpoint: Optional[str] = None
    project_name: Optional[str] = None; run_name: Optional[str] = None
    eval_interval: Optional[int] = None; min_eval_interval: int = 0  # Aligned: when to start eval
    ema: bool = False; ema_rate: float = 0.999; seed: int = 0

@dataclass
class TrainState:
    model: nn.Module
    optimizers: List[Any]
    optimizer_lrs: List[float]
    carry: Any  # Carry persists across batches (aligned with original)
    step: int
    total_steps: int

class EMAHelper:
    def __init__(self, mu=0.999): self.mu = mu; self.shadow = {}
    def register(self, m):
        for n, p in m.named_parameters():
            if p.requires_grad: self.shadow[n] = p.data.clone()
    def update(self, m):
        for n, p in m.named_parameters():
            if p.requires_grad and n in self.shadow: self.shadow[n] = self.mu*self.shadow[n] + (1-self.mu)*p.data
    def ema_copy(self, m):
        mc = copy.deepcopy(m)
        for n, p in mc.named_parameters():
            if n in self.shadow: p.data.copy_(self.shadow[n])
        return mc

def cosine_lr(step, base_lr, warmup, total, min_ratio=0.0):
    if step < warmup: return base_lr * step / max(1, warmup)
    prog = (step - warmup) / max(1, total - warmup)
    return base_lr * (min_ratio + (1-min_ratio) * 0.5 * (1 + math.cos(math.pi * prog)))

def create_dataloader(cfg, split, **kw):
    ds = PuzzleDataset(PuzzleDatasetConfig(seed=cfg.seed, dataset_paths=cfg.data_paths_test if split=='test' and cfg.data_paths_test else cfg.data_paths, **kw), split)
    return DataLoader(ds, batch_size=None, num_workers=1, prefetch_factor=8, pin_memory=True, persistent_workers=True), ds.metadata

def create_model(cfg, meta):
    model_cfg = {**cfg.arch.__pydantic_extra__, 'batch_size': cfg.global_batch_size, 'vocab_size': meta.vocab_size,
                'seq_len': meta.seq_len, 'num_puzzle_identifiers': meta.num_puzzle_identifiers, 'causal': False}
    lora_rank = cfg.arch.__pydantic_extra__.get('lora_rank', 0)
    lora_alpha = cfg.arch.__pydantic_extra__.get('lora_alpha', None)
    lora_train_base = cfg.arch.__pydantic_extra__.get('lora_train_base', False)
    with torch.device('cuda'):
        model = TinyRecursiveReasoningModel_ACTV1(model_cfg)
        model = ACTLossHead(model, cfg.arch.loss.__pydantic_extra__.get('loss_type', 'stablemax_cross_entropy'))
        if lora_rank > 0:
            cnt = enable_lora_for_model(model, lora_rank, lora_alpha, 0.0, lora_train_base, False)
            print(f'LoRA enabled: {cnt} layers, rank={lora_rank}, alpha={lora_alpha}, train_base={lora_train_base}')
        stats = count_parameters(model)
        print(f'Params: {stats["total"]:,} total, {stats["trainable"]:,} trainable ({stats["ratio"]:.2%}), {stats["lora"]:,} LoRA')
        if 'DISABLE_COMPILE' not in os.environ: model = torch.compile(model)
    # AdamW optimizer (only trainable params)
    opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(cfg.beta1, cfg.beta2))
    return model, [opt], [cfg.lr]

def init_train_state(cfg, meta):
    total = int(cfg.epochs * meta.total_groups * meta.mean_puzzle_examples / cfg.global_batch_size)
    model, opts, lrs = create_model(cfg, meta)
    return TrainState(model, opts, lrs, None, 0, total)

def train_batch(cfg, ts, batch, gbs):
    """
    Train one batch - aligned with original Trelis implementation.
    
    Key: Only ONE forward pass per batch. ACT loop happens across batches via carry.
    """
    ts.step += 1
    if ts.step > ts.total_steps: return None
    batch = {k: v.cuda() for k,v in batch.items()}
    
    # Init carry if None (first batch)
    if ts.carry is None:
        with torch.device('cuda'):
            ts.carry = ts.model.initial_carry(batch)
    
    # Single forward pass (carry handles ACT state across batches)
    ts.carry, loss, metrics, _, _ = ts.model(carry=ts.carry, batch=batch, return_keys=[])
    
    # Backward
    (loss / gbs).backward()
    
    # Update LR and step optimizers
    lr_now = None
    for opt, base_lr in zip(ts.optimizers, ts.optimizer_lrs):
        lr_now = cosine_lr(ts.step, base_lr, cfg.lr_warmup_steps, ts.total_steps, cfg.lr_min_ratio)
        for pg in opt.param_groups: pg['lr'] = lr_now
        opt.step(); opt.zero_grad()
    
    if metrics:
        keys = sorted(metrics.keys())
        vals = torch.stack([metrics[k] for k in keys]).cpu().numpy()
        rm = {k: vals[i] for i,k in enumerate(keys)}
        cnt = max(rm.get('count', 1), 1)
        return {f'train/{k}': v/(gbs if k.endswith('loss') else cnt) for k,v in rm.items()} | {'train/lr': lr_now}
    return None

def evaluate(cfg, ts, loader, meta, max_batches=100):
    """Evaluate with optional batch limit for faster iteration."""
    metrics = None
    with torch.inference_mode():
        for i, (sn, batch, gbs) in enumerate(loader):
            if max_batches and i >= max_batches: break  # Limit for faster eval
            batch = {k: v.cuda() for k,v in batch.items()}
            with torch.device('cuda'): carry = ts.model.initial_carry(batch)
            # Full ACT loop for evaluation (aligned with Trelis)
            while True:
                carry, _, m, _, done = ts.model(carry=carry, batch=batch, return_keys=[])
                if done: break
            if metrics is None: metrics = {k: 0.0 for k in m}
            for k in m: metrics[k] += m[k].item() if torch.is_tensor(m[k]) else m[k]
    if metrics:
        cnt = max(metrics.get('count', 1), 1)
        return {f'eval/{k}': v/cnt for k,v in metrics.items() if k != 'count'}
    return {}

print('Training framework loaded (aligned with Trelis)')

Training framework loaded (aligned with Trelis)
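
The step budget in `init_train_state` is plain arithmetic, so it can be checked by hand. The sketch below repeats that formula with the values this notebook uses for the Sudoku run (10000 epochs, 1000 puzzle groups from `TRAIN_SUBSAMPLE`, `mean_puzzle_examples=1`, batch size 1024). The helper name `total_training_steps` is made up for illustration and is not part of the notebook.

```python
def total_training_steps(epochs, total_groups, mean_puzzle_examples, global_batch_size):
    """Same formula as init_train_state: examples seen overall divided by batch size."""
    return int(epochs * total_groups * mean_puzzle_examples / global_batch_size)

# Values taken from Cell 1 and the dataset metadata built in Cell 11.
steps = total_training_steps(
    epochs=10000,
    total_groups=1000,
    mean_puzzle_examples=1,
    global_batch_size=1024,
)
print(steps)
```

The result is the total that the training progress bar is created with (`tqdm(total=ts.total_steps)`).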


In [22]:
# ============================================================================
# Cell 10: LoRA Configurations
# ============================================================================
LORA_CONFIGS = {
    'baseline': {'lora_rank': 0, 'lora_alpha': None, 'train_base': True, 'desc': 'Full training'},
    'lora-r1': {'lora_rank': 1, 'lora_alpha': 16.0, 'train_base': False, 'desc': 'LoRA r=1'},
    'lora-r4': {'lora_rank': 4, 'lora_alpha': 16.0, 'train_base': False, 'desc': 'LoRA r=4'},
    'lora-r16': {'lora_rank': 16, 'lora_alpha': 32.0, 'train_base': False, 'desc': 'LoRA r=16'},
    'lora-r4-full': {'lora_rank': 4, 'lora_alpha': 16.0, 'train_base': True, 'desc': 'LoRA r=4 + base'},
}
print('LoRA configs:', list(LORA_CONFIGS.keys()))

LoRA configs: ['baseline', 'lora-r1', 'lora-r4', 'lora-r16', 'lora-r4-full']
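
To make the rank and alpha columns easier to compare, here is a minimal sketch of the conventional LoRA bookkeeping: the low-rank update is scaled by `alpha / rank`, and a rank-`r` adapter on a linear layer adds `r * (in_features + out_features)` weights. This assumes the usual LoRA formulation; the notebook's own LoRA layer is defined elsewhere and may differ in detail, and the helper names below are hypothetical.

```python
def lora_scale(rank, alpha):
    """Conventional LoRA scaling factor applied to the low-rank update (assumed)."""
    return alpha / rank

def lora_extra_params(in_features, out_features, rank):
    """Weights added by one adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)

# (rank, alpha) pairs copied from LORA_CONFIGS; baseline has no adapter.
configs = {'lora-r1': (1, 16.0), 'lora-r4': (4, 16.0), 'lora-r16': (16, 32.0)}
scales = {name: lora_scale(rank, alpha) for name, (rank, alpha) in configs.items()}

for name, (rank, _) in configs.items():
    # 512 x 512 is used only as an example shape matching hidden_size=512.
    print(name, 'scale =', scales[name], 'extra params =', lora_extra_params(512, 512, rank))
```

Under this convention the effective scale shrinks as rank grows (16, 4, 2 for the three configs), which is worth keeping in mind when comparing runs across ranks.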


In [23]:
# ============================================================================
# Cell 11: Build Sudoku Dataset
# ============================================================================
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login
import warnings

# Set Hugging Face Token for authentication
os.environ["HF_TOKEN"] = HF_TOKEN

# Login to Hugging Face Hub
try:
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("✅ Successfully authenticated with Hugging Face Hub")
except Exception as e:
    print(f"⚠️ Warning: Could not login to Hugging Face Hub: {e}")
    print("   Continuing with token in environment variable...")

DATASET_DIR = './data/sudoku_lora'

def build_sudoku_dataset():
    if os.path.exists(os.path.join(DATASET_DIR, 'train', 'dataset.json')) and not FORCE_REBUILD:
        print('Dataset exists, skipping build')
        return
    print('Building Sudoku dataset...')
    # Explicitly pass token to avoid Colab secrets timeout
    ds = load_dataset('sapientinc/sudoku-extreme', token=HF_TOKEN)
    
    # Debug: print available columns
    print(f'Available columns: {ds["train"].column_names}')
    
    # For sapientinc/sudoku-extreme: columns are ['source', 'question', 'answer', 'rating']
    # 'question' = puzzle (81 chars with '.' for empty), 'answer' = solution (81 digits)
    sample = ds['train'][0]
    
    # Try common column name patterns
    puzzle_key = next((k for k in sample.keys() if k in ['question', 'puzzle', 'quiz', 'input']), None)
    solution_key = next((k for k in sample.keys() if k in ['answer', 'solution', 'output', 'target']), None)
    
    if puzzle_key is None or solution_key is None:
        print(f'ERROR: Could not find puzzle/solution keys. Available: {list(sample.keys())}')
        print(f'Sample data: {sample}')
        raise ValueError('Cannot determine puzzle/solution columns')
    
    print(f'Using: puzzle_key={puzzle_key}, solution_key={solution_key}')
    
    def shuffle_sudoku(inp, out):
        perm = np.random.permutation(9) + 1
        mapping = np.zeros(11, dtype=inp.dtype); mapping[1:10] = perm
        return mapping[inp], mapping[out]
    
    def convert(split):
        data = ds[split]
        if split == 'train': data = data.select(range(min(TRAIN_SUBSAMPLE, len(data))))
        # Use detected keys, handle '.' as 0 for empty cells
        inputs = [np.array([0 if c == '.' else int(c) for c in str(r[puzzle_key])], dtype=np.int8) for r in data]
        labels = [np.array([int(c) for c in str(r[solution_key])], dtype=np.int8) for r in data]
        aug = 0 if split == 'test' else NUM_AUGMENT
        res = {k: [] for k in ['inputs', 'labels', 'puzzle_identifiers', 'puzzle_indices', 'group_indices']}
        res['puzzle_indices'].append(0); res['group_indices'].append(0)
        pid = eid = 0
        for inp, lab in zip(inputs, labels):
            for ai in range(1 + aug):
                i, l = (inp, lab) if ai == 0 else shuffle_sudoku(inp, lab)
                res['inputs'].append(i); res['labels'].append(l); eid += 1; pid += 1
                res['puzzle_indices'].append(eid); res['puzzle_identifiers'].append(0)
            res['group_indices'].append(pid)
        def to_np(s): arr = np.stack(s); return (arr + 1).astype(np.int32)
        res = {'inputs': to_np(res['inputs']), 'labels': to_np(res['labels']),
              'group_indices': np.array(res['group_indices'], np.int32),
              'puzzle_indices': np.array(res['puzzle_indices'], np.int32),
              'puzzle_identifiers': np.array(res['puzzle_identifiers'], np.int32)}
        meta = PuzzleDatasetMetadata(seq_len=81, vocab_size=11, pad_id=0, ignore_label_id=0, blank_identifier_id=0,
                                    num_puzzle_identifiers=1, total_groups=len(res['group_indices'])-1,
                                    mean_puzzle_examples=1, total_puzzles=len(res['group_indices'])-1, sets=['all'])
        path = os.path.join(DATASET_DIR, split)
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, 'dataset.json'), 'w') as f: json.dump(meta.model_dump(), f)
        for k, v in res.items(): np.save(os.path.join(path, f'all__{k}.npy'), v)
        print(f'{split}: {res["inputs"].shape[0]} examples')
    convert('train'); convert('test')
    print('Dataset built!')

build_sudoku_dataset()

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✅ Successfully authenticated with Hugging Face Hub
Building Sudoku dataset...
Available columns: ['source', 'question', 'answer', 'rating']
Using: puzzle_key=question, solution_key=answer
train: 1001000 examples
test: 422786 examples
Dataset built!
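
The training count follows from the settings: 1000 puzzles, each kept once and augmented `NUM_AUGMENT = 1000` times, gives 1000 × 1001 = 1,001,000 examples. The augmentation itself is a digit relabeling, which keeps a grid valid because Sudoku constraints only care that digits are distinct. The sketch below illustrates this on a synthetic grid; `permute_digits` is a stand-in written for this example (it takes an explicit random generator), not the notebook's `shuffle_sudoku`.

```python
import numpy as np

def permute_digits(puzzle, solution, rng):
    """Relabel digits 1..9 with a random permutation; 0 (blank) maps to 0."""
    mapping = np.zeros(10, dtype=puzzle.dtype)
    mapping[1:] = rng.permutation(9) + 1
    return mapping[puzzle], mapping[solution]

rng = np.random.default_rng(0)

# A synthetic valid solution built from the standard shifted-row pattern.
solution = np.array(
    [(3 * (r % 3) + r // 3 + c) % 9 + 1 for r in range(9) for c in range(9)],
    dtype=np.int8,
)
puzzle = solution.copy()
puzzle[::2] = 0  # blank out every other cell to mimic a puzzle

aug_puzzle, aug_solution = permute_digits(puzzle, solution, rng)

# Blanks stay blank, clues still agree with the solution, rows stay complete.
print((aug_puzzle == 0).sum() == (puzzle == 0).sum())
print(np.array_equal(aug_puzzle[puzzle != 0], aug_solution[puzzle != 0]))
print(all(sorted(row) == list(range(1, 10)) for row in aug_solution.reshape(9, 9)))
```

Because only the labels change, each augmented example has the same structure and difficulty as its source puzzle.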


In [None]:
# ============================================================================
# Cell 12: Launch Training
# ============================================================================
import time

def launch_experiment(config_name):
    # Clear GPU memory before starting experiment
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
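        # total_memory and memory_allocated are in bytes; convert to GB before printing
        free_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1e9
        print(f'GPU memory cleared. Free: {free_gb:.2f} GB')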
    
    lora_cfg = LORA_CONFIGS[config_name]
    print('='*70)
    print(f'Experiment: {config_name} - {lora_cfg["desc"]}')
    print('='*70)
    
    # Track efficiency metrics for report
    exp_start_time = time.time()
    
    # ======== ALIGNED WITH TRELIS (trm.yaml + cfg_pretrain_lora.yaml) ========
    cfg_dict = {
        'arch': {'name': 'TRM', 'loss': {'name': 'ACTLossHead', 'loss_type': 'stablemax_cross_entropy'},
                'halt_exploration_prob': 0.1, 'halt_max_steps': 16, 'H_cycles': 3, 'L_cycles': 4,  # Aligned: Trelis uses 4
                'H_layers': 0, 'L_layers': 2, 'hidden_size': 512, 'num_heads': 8, 'expansion': 4,
                'puzzle_emb_ndim': 0, 'pos_encodings': 'rope', 'forward_dtype': 'bfloat16',
                'mlp_t': False, 'puzzle_emb_len': 1, 'no_ACT_continue': True,  # Aligned: Trelis uses 1
                'lora_rank': lora_cfg['lora_rank'], 'lora_alpha': lora_cfg['lora_alpha'], 'lora_train_base': lora_cfg['train_base']},
        'data_paths': [DATASET_DIR], 'data_paths_test': [],
        'global_batch_size': BATCH_SIZE, 'epochs': TRAIN_EPOCHS, 'eval_interval': EVAL_INTERVAL,
        'lr': 1e-4, 'lr_min_ratio': 1.0, 'lr_warmup_steps': 1000,  # NOTE: the alignment table at the top lists lr_warmup_steps=2000; this run uses 1000
        'beta1': 0.9, 'beta2': 0.95, 'weight_decay': 0.1, 'ema': True, 'ema_rate': 0.999, 'seed': 0,
        'min_eval_interval': 0,  # Aligned: Trelis - when to start eval
        'project_name': 'TRM-A100-LoRA', 'run_name': f'sudoku-{config_name}',
        'checkpoint_path': f'./checkpoints/{config_name}', 'checkpoint_every_eval': True,
    }
    cfg = PretrainConfig(**cfg_dict)
    torch.manual_seed(cfg.seed)
    
    train_per_iter = cfg.eval_interval or cfg.epochs
    total_iters = cfg.epochs // train_per_iter
    train_loader, train_meta = create_dataloader(cfg, 'train', test_set_mode=False, epochs_per_iter=train_per_iter, global_batch_size=cfg.global_batch_size)
    try:
        eval_loader, _ = create_dataloader(cfg, 'test', test_set_mode=True, epochs_per_iter=1, global_batch_size=cfg.global_batch_size)
    except Exception:  # no usable test split: train without evaluation
        eval_loader = None
    
    ts = init_train_state(cfg, train_meta)
    pbar = tqdm(total=ts.total_steps)
    
    # Compute efficiency metrics for report
    param_stats = count_parameters(ts.model)
    memory_allocated = torch.cuda.max_memory_allocated() / 1e9  # GB
    
    wandb.init(project=cfg.project_name, name=cfg.run_name, config=cfg_dict)
    wandb.log({
        'params/total': param_stats['total'],
        'params/trainable': param_stats['trainable'],
        'params/lora': param_stats['lora'],
        'params/trainable_ratio': param_stats['ratio'],
        'efficiency/memory_gb': memory_allocated,
    }, step=0)
    
    print(f"📊 Efficiency Metrics:")
    print(f"   Total Params: {param_stats['total']:,}")
    print(f"   Trainable: {param_stats['trainable']:,} ({param_stats['ratio']:.2%})")
    print(f"   LoRA Params: {param_stats['lora']:,}")
    print(f"   Memory: {memory_allocated:.2f} GB")
    
    ema = EMAHelper(cfg.ema_rate) if cfg.ema else None
    if ema: ema.register(ts.model)
    
    for it in range(total_iters):
        print(f'Epoch {it * train_per_iter}')
        ts.model.train()
        for sn, batch, gbs in train_loader:
            m = train_batch(cfg, ts, batch, gbs)
            if m: wandb.log(m, step=ts.step); pbar.update(ts.step - pbar.n)
            if ema: ema.update(ts.model)
        
        # Aligned with Trelis: check min_eval_interval before evaluating
        if it >= cfg.min_eval_interval and eval_loader is not None:
            print('Evaluating...')
            # Aligned with Trelis: deep copy train_state for EMA evaluation
            if ema:
                ts_eval = copy.deepcopy(ts)
                ts_eval.model = ema.ema_copy(ts.model)
            else:
                ts_eval = ts
            ts_eval.model.eval()
            em = evaluate(cfg, ts_eval, eval_loader, train_meta)
            if em: wandb.log(em, step=ts.step); print(em)
            if ema: del ts_eval  # Aligned: clean up
    
    pbar.close()
    
    # Final efficiency summary for report
    total_time = time.time() - exp_start_time
    peak_memory = torch.cuda.max_memory_allocated() / 1e9
    time_per_step = total_time / max(ts.step, 1)
    
    summary = {
        'config': config_name,
        'total_params': param_stats['total'],
        'trainable_params': param_stats['trainable'],
        'lora_params': param_stats['lora'],
        'trainable_ratio': param_stats['ratio'],
        'peak_memory_gb': peak_memory,
        'total_time_s': total_time,
        'time_per_step_ms': time_per_step * 1000,
        'total_steps': ts.step,
    }
    
    wandb.log({
        'summary/peak_memory_gb': peak_memory,
        'summary/total_time_s': total_time,
        'summary/time_per_step_ms': time_per_step * 1000,
    }, step=ts.step)
    wandb.finish()
    
    print(f'\n{"="*70}')
    print(f'üìà EXPERIMENT SUMMARY: {config_name}')
    print(f'{"="*70}')
    print(f'   Trainable Params: {param_stats["trainable"]:,} / {param_stats["total"]:,} ({param_stats["ratio"]:.2%})')
    print(f'   Peak Memory: {peak_memory:.2f} GB')
    print(f'   Total Time: {total_time:.1f}s ({total_time/60:.1f} min)')
    print(f'   Time/Step: {time_per_step*1000:.1f} ms')
    print(f'   Total Steps: {ts.step}')
    print(f'{"="*70}\n')
    
    return ts, summary

# Run the selected experiment
print(f'Running: {SELECTED_CONFIG}')
train_state, experiment_summary = launch_experiment(SELECTED_CONFIG)

In [None]:
# ============================================================================
# Cell 13: Run All Experiments & Generate Comparison Table
# ============================================================================
import pandas as pd

# Collect all experiment summaries
all_summaries = []
for name in ['baseline', 'lora-r1', 'lora-r4', 'lora-r16', 'lora-r4-full']:
    _, summary = launch_experiment(name)
    all_summaries.append(summary)
    torch.cuda.empty_cache()  # Clear memory between experiments

# Create comparison DataFrame for report
df = pd.DataFrame(all_summaries)
df['param_reduction'] = df['total_params'] / df['trainable_params']
df = df.round({'trainable_ratio': 4, 'peak_memory_gb': 2, 'time_per_step_ms': 1, 'param_reduction': 1})

print("\n" + "="*80)
print("📊 LORA EFFICIENCY COMPARISON TABLE (For Report)")
print("="*80)
print(df[['config', 'trainable_params', 'trainable_ratio', 'param_reduction', 'peak_memory_gb', 'time_per_step_ms']].to_string(index=False))
print("="*80)

# Save to CSV for report
df.to_csv('lora_experiment_results.csv', index=False)
print("✅ Results saved to lora_experiment_results.csv")

GPU memory cleared. Free: 42255340544.00 GB
Experiment: baseline - Full training
Params: 6,828,034 total, 6,828,034 trainable (100.00%), 0 LoRA




Run history (W&B sparklines omitted): efficiency/memory_gb, params/lora, params/total, params/trainable, params/trainable_ratio, train/accuracy, train/count, train/exact_accuracy, train/lm_loss, train/lr

Run summary:
efficiency/memory_gb,0.02743
params/lora,18445
params/total,6846479
params/trainable,24077
params/trainable_ratio,0.00352
train/accuracy,0
train/count,0
train/exact_accuracy,0
train/lm_loss,2.53825
train/lr,0.0


📊 Efficiency Metrics:
   Total Params: 6,828,034
   Trainable: 6,828,034 (100.00%)
   LoRA Params: 0
   Memory: 0.25 GB
Epoch 0




Evaluating...
{'eval/accuracy': 0.4120477709174156, 'eval/exact_accuracy': 0.0, 'eval/q_halt_accuracy': 1.0, 'eval/steps': 16.0, 'eval/lm_loss': 1.6305569865019307, 'eval/q_halt_loss': 0.0016139806667342781}
Epoch 1000




Evaluating...


Evaluating...
{'eval/accuracy': 0.6021992009878159, 'eval/exact_accuracy': 9.765625e-06, 'eval/q_halt_accuracy': 0.999990234375, 'eval/steps': 16.0, 'eval/lm_loss': 0.8943529775610007, 'eval/q_halt_loss': 0.00011378311552107334}
Epoch 3000




Evaluating...
{'eval/accuracy': 0.6380148100852966, 'eval/exact_accuracy': 0.01474609375, 'eval/q_halt_accuracy': 0.98525390625, 'eval/steps': 16.0, 'eval/lm_loss': 0.8032578144405588, 'eval/q_halt_loss': 0.06528268065303564}
Epoch 4000




Evaluating...
{'eval/accuracy': 0.6436900877952576, 'eval/exact_accuracy': 0.0227734375, 'eval/q_halt_accuracy': 0.992333984375, 'eval/steps': 16.0, 'eval/lm_loss': 0.7942803491411385, 'eval/q_halt_loss': 0.015955868144519627}
Epoch 5000




Evaluating...
{'eval/accuracy': 0.6355900460481644, 'eval/exact_accuracy': 0.029365234375, 'eval/q_halt_accuracy': 0.99392578125, 'eval/steps': 16.0, 'eval/lm_loss': 0.8906812705269657, 'eval/q_halt_loss': 0.02274650321342051}
Epoch 6000




Evaluating...
{'eval/accuracy': 0.6267912179231644, 'eval/exact_accuracy': 0.04251953125, 'eval/q_halt_accuracy': 0.94431640625, 'eval/steps': 16.0, 'eval/lm_loss': 1.0942984497024826, 'eval/q_halt_loss': 0.14879593133926391}
Epoch 7000




Evaluating...
{'eval/accuracy': 0.6251004391908646, 'eval/exact_accuracy': 0.046142578125, 'eval/q_halt_accuracy': 0.99541015625, 'eval/steps': 16.0, 'eval/lm_loss': 1.0419413143575298, 'eval/q_halt_loss': 0.02488753373734653}
Epoch 8000




Evaluating...
{'eval/accuracy': 0.6287447053194046, 'eval/exact_accuracy': 0.0571484375, 'eval/q_halt_accuracy': 0.99806640625, 'eval/steps': 16.0, 'eval/lm_loss': 0.9874203723998375, 'eval/q_halt_loss': 0.012353496546857058}
Epoch 9000




Evaluating...


100%|██████████| 9765/9765 [1:10:38<00:00,  2.30it/s]

{'eval/accuracy': 0.6319086450338364, 'eval/exact_accuracy': 0.07302734375, 'eval/q_halt_accuracy': 0.998203125, 'eval/steps': 16.0, 'eval/lm_loss': 1.0002173564546688, 'eval/q_halt_loss': 0.011818661601282657}





Run history (W&B sparklines omitted): efficiency/memory_gb, eval/accuracy, eval/exact_accuracy, eval/lm_loss, eval/q_halt_accuracy, eval/q_halt_loss, eval/steps, params/lora, params/total, params/trainable

Run summary:
efficiency/memory_gb,0.24753
eval/accuracy,0.63191
eval/exact_accuracy,0.07303
eval/lm_loss,1.00022
eval/q_halt_accuracy,0.9982
eval/q_halt_loss,0.01182
eval/steps,16
params/lora,0
params/total,6828034
params/trainable,6828034



📈 EXPERIMENT SUMMARY: baseline
   Trainable Params: 6,828,034 / 6,828,034 (100.00%)
   Peak Memory: 17.41 GB
   Total Time: 4238.2s (70.6 min)
   Time/Step: 423.8 ms
   Total Steps: 10000

GPU memory cleared. Free: 41997966848.00 GB
Experiment: lora-r1 - LoRA r=1
LoRA enabled: 10 layers, rank=1, alpha=16.0, train_base=False
Params: 6,846,479 total, 24,077 trainable (0.35%), 18,445 LoRA




📊 Efficiency Metrics:
   Total Params: 6,846,479
   Trainable: 24,077 (0.35%)
   LoRA Params: 18,445
   Memory: 0.50 GB
Epoch 0




Evaluating...
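
The parameter counts logged above are enough to check the `param_reduction` and `trainable_ratio` columns by hand, since Cell 13 computes `param_reduction` as `total_params / trainable_params`. The sketch below redoes that arithmetic in plain Python for the two configs whose counts appear in the output; the dictionary names are illustrative only.

```python
# (total_params, trainable_params) copied from the printed "Params:" lines.
logged = {
    'baseline': (6_828_034, 6_828_034),
    'lora-r1': (6_846_479, 24_077),
}

ratio = {}      # trainable / total, as reported in the efficiency metrics
reduction = {}  # total / trainable, the param_reduction column
for name, (total, trainable) in logged.items():
    ratio[name] = trainable / total
    reduction[name] = total / trainable
    print(f'{name}: trainable_ratio={ratio[name]:.2%}, param_reduction={reduction[name]:.1f}x')
```

For `lora-r1` the ratio rounds to the 0.35% shown in the log, which corresponds to training a few hundred times fewer parameters than the full model.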


In [None]:
# ============================================================================
# Cell 14: Generate Report Figures
# ============================================================================
import matplotlib.pyplot as plt
import pandas as pd  # needed for the CSV fallback when Cell 13 was not run in this session

# Load results if not in memory
try:
    df
except NameError:
    df = pd.read_csv('lora_experiment_results.csv')

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Trainable Parameters
ax1 = axes[0]
colors = ['#FF6B6B' if 'baseline' in c else '#4ECDC4' for c in df['config']]
ax1.bar(df['config'], df['trainable_params'], color=colors)
ax1.set_ylabel('Trainable Parameters')
ax1.set_title('Parameter Efficiency')
ax1.tick_params(axis='x', rotation=45)
for i, v in enumerate(df['trainable_params']):
    ax1.text(i, v + 50000, f'{v/1e6:.2f}M', ha='center', fontsize=9)

# Plot 2: Memory Usage
ax2 = axes[1]
ax2.bar(df['config'], df['peak_memory_gb'], color=colors)
ax2.set_ylabel('Peak Memory (GB)')
ax2.set_title('Memory Efficiency')
ax2.tick_params(axis='x', rotation=45)

# Plot 3: Time per Step
ax3 = axes[2]
ax3.bar(df['config'], df['time_per_step_ms'], color=colors)
ax3.set_ylabel('Time per Step (ms)')
ax3.set_title('Training Speed')
ax3.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('lora_efficiency_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Figure saved to lora_efficiency_comparison.png")