# 🧠 TinyRecursiveModel: H100 Implementation

This notebook provides a **complete, faithful reproduction** of the **TinyRecursiveModels** (TRM) architecture.

**Key Features:**
- ✅ **Reproduction** of the original TinyRecursiveModels codebase, with PyTorch's built-in AdamW in place of AdamAtan2
- ✅ **H100 optimizations** (torch.compile, TF32, PyTorch scaled-dot-product attention)
- ✅ **No distributed code**: single-GPU implementation
- ✅ **Ready to run** on H100 (and A100) GPUs

**Based on:** [TinyRecursiveModels](https://github.com/AlexiaJM/TinyRecursiveModels)


## Tokens

In [None]:
# Hugging Face Token
HF_TOKEN = ""
# Wandb Token
WANDB_API_KEY = ''


## Part 1: Install Dependencies

In [1]:
!uv pip install -q torch einops tqdm numpy pydantic wandb coolname ninja wheel triton nvidia-ml-py
print("✅ Dependencies installed!")
print("📝 Note: Using PyTorch's built-in AdamW optimizer (no need for adam-atan2-pytorch)")
print("📝 Note: Installed nvidia-ml-py for GPU monitoring")

✅ Dependencies installed!
📝 Note: Using PyTorch's built-in AdamW optimizer (no need for adam-atan2-pytorch)
📝 Note: Installed nvidia-ml-py for GPU monitoring


## Part 2: Imports

In [None]:
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil
import copy
import importlib
import inspect
import time

import torch
from torch import nn
from torch.utils.data import DataLoader

from tqdm import tqdm
import wandb
import coolname
import pydantic
from pydantic import BaseModel

# Using PyTorch's built-in AdamW optimizer instead of AdamAtan2
# AdamW is more standard and widely used

# H100 optimizations
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    if "H100" in gpu_name or "A100" in gpu_name:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("🚀 H100/A100 detected! TF32 enabled")
    torch.cuda.set_device(0)

# ============================================================================
# Wandb Configuration (Similar to TRM_Baseline.ipynb)
# ============================================================================
print("="*70)
print("📊 Configuring Weights & Biases")
print("="*70)



# Project configuration
WANDB_PROJECT = "TRM-A100-Sudoku"
wandb.login(key=WANDB_API_KEY)
WANDB_ENTITY = None  # Set to your W&B username/team if needed (e.g., "jarviszhang-new-york-university")

print(f"📁 W&B Project: {WANDB_PROJECT}")
if WANDB_ENTITY:
    print(f"👤 W&B Entity: {WANDB_ENTITY}")
print("="*70)

GPU: NVIDIA A100-SXM4-40GB
🚀 H100/A100 detected! TF32 enabled
📊 Configuring Weights & Biases


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjarviszhang[0m ([33mjarviszhang-new-york-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


📁 W&B Project: TRM-A100-Sudoku


## Part 3: Common Utilities

In [3]:
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch's nn.init.trunc_normal_ does not compensate for the truncation, so the std of the initialized tensor is smaller than the requested std
    # This function is a PyTorch version of jax truncated normal init (default init method in flax)
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

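            c = (2 * math.pi) ** -0.5
            pdf_l = c * math.exp(-0.5 * lower ** 2)  # standard normal pdf at the lower bound
            pdf_u = c * math.exp(-0.5 * upper ** 2)  # standard normal pdf at the upper bound
            # Std of N(0, 1) truncated to [lower, upper]; identical to the original expression for symmetric bounds (the default)
            comp_std = std / math.sqrt(1 + (lower * pdf_l - upper * pdf_u) / z - ((pdf_l - pdf_u) / z) ** 2)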

            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
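
As a sanity check on the compensation factor in `trunc_normal_init_`, the standard-library sketch below computes the std of a standard normal truncated to [-2, 2] (about 0.8796, the same constant JAX's `variance_scaling` truncated-normal initializer divides by) and verifies that rescaling by its inverse yields samples with the requested std. It swaps `statistics.NormalDist` inverse-CDF sampling in for the `uniform_` + `erfinv_` trick; `truncated_std_factor` and `sample_trunc_normal` are hypothetical helper names, not part of the notebook.

```python
import math
import random
from statistics import NormalDist, pstdev

_STD_NORMAL = NormalDist()  # standard normal N(0, 1)


def truncated_std_factor(lower: float = -2.0, upper: float = 2.0) -> float:
    """Standard deviation of N(0, 1) truncated to [lower, upper]."""
    z = _STD_NORMAL.cdf(upper) - _STD_NORMAL.cdf(lower)  # probability mass kept
    pdf_l, pdf_u = _STD_NORMAL.pdf(lower), _STD_NORMAL.pdf(upper)
    var = 1 + (lower * pdf_l - upper * pdf_u) / z - ((pdf_l - pdf_u) / z) ** 2
    return math.sqrt(var)


def sample_trunc_normal(n: int, std: float = 1.0, lower: float = -2.0,
                        upper: float = 2.0, seed: int = 0) -> list:
    """Inverse-CDF sampling, rescaled so the truncated samples have std `std`."""
    rng = random.Random(seed)
    comp_std = std / truncated_std_factor(lower, upper)
    # Uniform draws restricted to [cdf(lower), cdf(upper)] map back into [lower, upper]
    a, b = _STD_NORMAL.cdf(lower), _STD_NORMAL.cdf(upper)
    return [_STD_NORMAL.inv_cdf(rng.uniform(a, b)) * comp_std for _ in range(n)]


print(truncated_std_factor())       # std of N(0, 1) cut at +/-2, roughly 0.88
samples = sample_trunc_normal(50_000, std=0.02)
print(pstdev(samples))              # close to the requested 0.02
```

Without the `comp_std` rescaling, the sample std would land near 0.88 × 0.02 instead of 0.02, which is the discrepancy the NOTE in `trunc_normal_init_` refers to.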


## Part 4: Layers
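
The `Attention` module in this part delegates the core computation to `torch.nn.functional.scaled_dot_product_attention` with `is_causal=False`. For readers who want to see what that call computes, here is a NumPy reference sketch of non-causal attention with the default `1/sqrt(head_dim)` scaling. It is illustrative only and says nothing about which fused kernel PyTorch selects at runtime; `sdpa_reference` is a hypothetical name.

```python
import numpy as np


def sdpa_reference(q, k, v):
    """Non-causal scaled dot-product attention; q, k, v: [B, H, S, D]."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])  # [B, H, S, S]
    scores = scores - scores.max(axis=-1, keepdims=True)         # stabilise exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)               # softmax over keys
    return weights @ v                                           # [B, H, S, D]


rng = np.random.default_rng(0)
B, H, S, D = 2, 4, 5, 8
q, k, v = (rng.standard_normal((B, H, S, D)) for _ in range(3))
out = sdpa_reference(q, k, v)
print(out.shape)
```

A quick check: if every key is identical, all attention weights are uniform, so each output row reduces to the mean of the values.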

In [4]:
from typing import Tuple
import einops
import torch
from torch import nn
import torch.nn.functional as F

#try:
#    from flash_attn_interface import flash_attn_func  # type: ignore[import]
#except ImportError:
#    # Fallback to FlashAttention 2
#    from flash_attn import flash_attn_func  # type: ignore[import]
from torch.nn.functional import scaled_dot_product_attention

# trunc_normal_init_ defined in Part 3


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )
        
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, hidden_size]
        qkv = self.qkv_proj(hidden_states)

        # Split head
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Attention via PyTorch SDPA (dispatches to a fused FlashAttention / memory-efficient kernel when eligible)
        query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
        attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
        attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
        # reshape (not view): the rearranged output is not guaranteed to be contiguous (e.g. with the math backend)
        attn_output = attn_output.reshape(batch_size, seq_len, self.output_size)
        return self.o_proj(attn_output)

class LinearSwish(nn.Module):
    def __init__(self, hidden_size: int, reverse=False):
        super().__init__()

        self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
        self.reverse = reverse

    def forward(self, x):
        if self.reverse:
            return F.silu(self.linear(x))
        else:
            return self.linear(F.silu(x))


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj    = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
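
`RotaryEmbedding` and `apply_rotary_pos_emb` use the half-split layout (`rotate_half`): dimension `i` is paired with `i + head_dim/2`, and that pair is rotated by `position * inv_freq[i]`. The NumPy sketch below re-creates the same construction (function names are illustrative, not part of the notebook) and checks the two properties that make RoPE useful: rotations preserve vector norms, and the query-key score depends only on the offset between positions.

```python
import numpy as np


def rope_cos_sin(dim: int, max_pos: int, base: float = 10000.0):
    # Same recipe as RotaryEmbedding: frequencies duplicated for the half-split layout
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
    freqs = np.outer(np.arange(max_pos, dtype=np.float64), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)  # [max_pos, dim]
    return np.cos(emb), np.sin(emb)


def rotate_half(x: np.ndarray) -> np.ndarray:
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)


def apply_rope(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    # x: [seq_len, head_dim]; rotates each (x[i], x[i + dim/2]) pair by pos * inv_freq[i]
    return x * cos + rotate_half(x) * sin


rng = np.random.default_rng(0)
dim, max_pos = 8, 16
cos, sin = rope_cos_sin(dim, max_pos)
q, k = rng.standard_normal(dim), rng.standard_normal(dim)

# Place the same q and k vectors at every position, then rotate
q_rot = apply_rope(np.tile(q, (max_pos, 1)), cos, sin)
k_rot = apply_rope(np.tile(k, (max_pos, 1)), cos, sin)

# Score depends only on the offset: (pos 3, pos 1) matches (pos 9, pos 7)
print(np.isclose(q_rot[3] @ k_rot[1], q_rot[9] @ k_rot[7]))
# Rotation preserves norms at every position
print(np.allclose(np.linalg.norm(q_rot, axis=-1), np.linalg.norm(q)))
```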

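Two details of the layers above are easy to misread: how `SwiGLU` picks its inner width, and what `rms_norm` normalizes. The sketch below recomputes both in Python/NumPy (helper names are illustrative). The `2/3` factor is the usual way to keep a three-matrix SwiGLU (gate, up, down) at roughly the parameter count of a two-matrix MLP with the same `expansion`; `_find_multiple` then rounds the width up to a multiple of 256.

```python
import numpy as np


def find_multiple(a: int, b: int) -> int:
    # Smallest multiple of b that is >= a (ceil division via negated floor division)
    return (-(a // -b)) * b


def swiglu_inter_size(hidden_size: int, expansion: float) -> int:
    # Same formula as SwiGLU.__init__
    return find_multiple(round(expansion * hidden_size * 2 / 3), 256)


def rms_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # No mean subtraction and no learnable gain, like the notebook's rms_norm
    x32 = x.astype(np.float32)
    return (x32 / np.sqrt(np.mean(x32 ** 2, axis=-1, keepdims=True) + eps)).astype(x.dtype)


print(swiglu_inter_size(512, 4))  # round(1365.33) = 1365, rounded up to 1536
y = rms_norm(np.random.default_rng(0).standard_normal((2, 3, 512)))
print(np.sqrt(np.mean(y ** 2, axis=-1)))  # every row now has RMS close to 1
```

Because the blocks apply `rms_norm` after each residual add (post-norm), every block's output has unit RMS per token regardless of how large the residual update was.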

## Part 5: Sparse Embedding (Non-distributed)

In [5]:
from typing import Union

import torch
from torch import nn
# import torch.distributed as dist  # Removed for single GPU
from torch.optim.optimizer import Optimizer, ParamsT

# trunc_normal_init_ defined in Part 3


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)
            
        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=1  # Single GPU, no distributed training
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None
            
            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False
                
            assert local_ids is not None
            assert weights is not None
        
            # Apply SignSGD
            # Adam ≈ SignSGD when the gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,
                    
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,
    
    lr: float,
    weight_decay: float,
    world_size: int = 1,  # Single GPU, not used but kept for compatibility
    ) -> None:
    _, D = local_weights_grad.shape

    # Single GPU: the local gradients and IDs are already the complete set, so no all-gather is needed
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p


# Distributed version (for compatibility with original code, but works for single GPU)
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    """
    Distributed version of CastedSparseEmbeddingSignSGD.
    For single GPU, this is equivalent to CastedSparseEmbeddingSignSGD.
    """
    def __init__(
        self,
        params: ParamsT,
        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None
            
            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False
                
            assert local_ids is not None
            assert weights is not None
        
            # Apply SignSGD
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )
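
`_sparse_emb_signsgd_dist` first merges the gradients of puzzle IDs that appear more than once in the batch, then applies decoupled weight decay and a fixed-size step in the sign direction to only those rows. The NumPy sketch below mirrors that update with `np.add.at` standing in for `scatter_add_`; `sparse_signsgd_step` is a hypothetical name and the toy numbers are made up.

```python
import numpy as np


def sparse_signsgd_step(weights, local_ids, local_grad, lr, weight_decay):
    """In-place SignSGD update of the rows of `weights` touched by this batch."""
    D = local_grad.shape[1]
    # Merge duplicate IDs: sum gradients per unique row
    grad_ids, inv = np.unique(local_ids, return_inverse=True)
    grad = np.zeros((len(grad_ids), D), dtype=local_grad.dtype)
    np.add.at(grad, inv, local_grad)
    # Decoupled weight decay, then a step of size lr against the gradient sign
    p = weights[grad_ids] * (1.0 - lr * weight_decay) - lr * np.sign(grad)
    weights[grad_ids] = p


weights = np.ones((5, 2))
ids = np.array([3, 1, 3])  # puzzle 3 appears twice in the batch
grads = np.array([[0.5, -0.1], [-2.0, 0.0], [-0.2, 0.3]])
sparse_signsgd_step(weights, ids, grads, lr=0.1, weight_decay=0.0)
print(weights)
```

Row 3 receives the summed gradient [0.3, 0.2], so both coordinates move by `-lr`; row 1 has a zero gradient in its second coordinate, and `sign(0) = 0` leaves it unchanged. Rows outside the batch are not touched at all, including by weight decay.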


## Part 6: Losses

In [6]:
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn
import math

IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x<0,
        1/(1-x+ epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


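def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # ACTLossHead passes valid_mask to every loss_fn; honor it by marking masked positions as ignored
    if valid_mask is not None:
        labels = torch.where(valid_mask, labels, ignore_index)
    # Cast logits to f32 and flatten to [N, vocab] (reshape: sliced logits may be non-contiguous)
    return F.cross_entropy(logits.to(torch.float32).reshape(-1, logits.shape[-1]), labels.to(torch.long).reshape(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)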


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]
        
    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts
            
            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                
                "accuracy":       torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps":          torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses

        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })
        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()
        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
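
`stablemax_cross_entropy` replaces `exp` in the softmax with `s(x)`, which is `x + 1` for non-negative inputs and `1 / (1 - x)` for negative ones, so scores grow linearly rather than exponentially. The NumPy sketch below is a re-implementation for illustration (not the notebook's torch code) that compares the two normalizations on a single row of logits.

```python
import numpy as np


def s(x, epsilon=1e-30):
    # Positive, monotone stand-in for exp(x)
    return np.where(x < 0, 1 / (1 - x + epsilon), x + 1)


def log_stablemax(x, axis=-1):
    s_x = s(x)
    return np.log(s_x / np.sum(s_x, axis=axis, keepdims=True))


def log_softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=True))


logits = np.array([[2.0, 0.0, -1.0]], dtype=np.float64)
print(np.exp(log_stablemax(logits)))  # s = [3, 1, 0.5], normalized by 4.5
print(np.exp(log_softmax(logits)))    # exp-based probabilities for comparison
```

Both produce valid distributions that preserve the ranking of the logits; stablemax is noticeably flatter here because `s` grows only linearly for positive logits.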



## Part 7: TRM Model
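
Before the model code, it helps to count what one call of `TinyRecursiveReasoningModel_ACTV1_Inner.forward` does. It runs `H_cycles - 1` cycles under `torch.no_grad()` and one final cycle with gradients; each cycle makes `L_cycles` updates of `z_L` followed by one update of `z_H`, and every update reuses the same `L_level` stack. The pure-Python sketch below enumerates that schedule; `H_cycles=3, L_cycles=6` are illustrative values, not settings taken from this notebook.

```python
def trm_schedule(H_cycles: int, L_cycles: int):
    """List the L_level calls in one inner forward, flagging which ones carry gradients."""
    calls = []
    for h in range(H_cycles):
        with_grad = (h == H_cycles - 1)  # only the last H cycle is differentiated
        for _ in range(L_cycles):
            calls.append(("z_L <- L_level(z_L, z_H + x)", with_grad))
        calls.append(("z_H <- L_level(z_H, z_L)", with_grad))
    return calls


calls = trm_schedule(H_cycles=3, L_cycles=6)
n_grad = sum(g for _, g in calls)
print(len(calls), n_grad)  # 21 calls per ACT step, 7 of them with gradients
```

Each listed call passes through all `L_layers` blocks, so with `L_layers=2` (again only an example) one ACT step applies 42 block evaluations, of which only the last 14 are backpropagated through.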

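The halting logic in `TinyRecursiveReasoningModel_ACTV1.forward` (with the default `no_ACT_continue=True`) can be read in isolation: every sequence stops at `halt_max_steps`; during training it may also stop once `q_halt_logits > 0`, unless exploration has drawn a larger minimum step count for it. The sketch below mirrors those rules, with NumPy's random generator standing in for `torch.rand_like`/`torch.randint_like`; `act_halt` is a hypothetical helper.

```python
import numpy as np


def act_halt(steps, q_halt_logits, halt_max_steps, exploration_prob, training, rng):
    """Mirror of the no_ACT_continue halting rule; returns (new_steps, halted)."""
    new_steps = steps + 1
    halted = new_steps >= halt_max_steps  # always stop at the step cap
    if training and halt_max_steps > 1:
        halted = halted | (q_halt_logits > 0)  # Q-head says "done"
        # Exploration: some sequences get a random minimum step count in [2, halt_max_steps]
        explore = rng.random(steps.shape) < exploration_prob
        min_halt = explore * rng.integers(2, halt_max_steps + 1, size=steps.shape)
        halted = halted & (new_steps >= min_halt)
    return new_steps, halted


rng = np.random.default_rng(0)
steps = np.array([0, 0, 15])
q_halt = np.array([3.0, -3.0, -3.0])
# Evaluation: only the step cap applies
print(act_halt(steps, q_halt, 16, 0.1, training=False, rng=rng))
# Training with exploration off: element 0 halts early because its halt logit is positive
print(act_halt(steps, q_halt, 16, 0.0, training=True, rng=rng))
```

In evaluation mode every sequence in a batch halts on the same step, which is what the NOTE in the model code relies on for batching.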
In [7]:
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
# trunc_normal_init_ defined in Part 3
# All layer classes defined in Part 4
# CastedSparseEmbedding defined in Part 5

IGNORE_LABEL_ID = -100

@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
    
    steps: torch.Tensor
    halted: torch.Tensor
    
    current_data: Dict[str, torch.Tensor]


class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    
    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False # use mlp on L instead of transformer
    puzzle_emb_len: int = 16 # if non-zero, use this value directly; if 0, use ceil(puzzle_emb_ndim / hidden_size)
    no_ACT_continue: bool =  True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense

class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
        # Fully Connected
        out = self.mlp(hidden_states)
        hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states

class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden_states = hidden_states + input_injection
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O

        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head      = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head       = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )
        
    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        it = 0
        z_H, z_L = carry.z_H, carry.z_L
        # H_cycles-1 without grad
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles-1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset on the first pass, since all sequences start halted.
            
            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted
            
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )
        
    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:

        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        
        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps
            
            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):

                # Halt signal
                # NOTE: During evaluation the model always runs halt_max_steps, so all sequences in a batch halt at the same step (required for batching).
                
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer or target network is used to compute the target Q-value.
                    # With a large batch there are many parallel environments.
                    # Similar concept to PQN https://arxiv.org/abs/2407.04811
                    # self.inner returns (carry, logits, (q_halt_logits, q_continue_logits)); only the Q-values are needed here.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
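The training-time halting rule and the bootstrapped Q-target above can be traced on scalar values. The following is a minimal pure-Python sketch, assuming one sequence at a time with plain floats in place of the batched `q_halt_logits` / `q_continue_logits` tensors; `act_step` and `target_q_continue` are hypothetical helper names, not part of the TRM codebase.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def act_step(step: int, q_halt: float, q_continue: float, halt_max_steps: int,
             exploration_prob: float, no_act_continue: bool,
             rng: random.Random, training: bool = True):
    """Hypothetical per-sequence mirror of the halting logic in TinyRecursiveReasoningModel_ACTV1.forward."""
    step += 1
    is_last_step = step >= halt_max_steps
    halted = is_last_step
    # Learned halting only applies in training; evaluation always runs halt_max_steps
    if training and halt_max_steps > 1:
        if no_act_continue:
            halted = halted or q_halt > 0
        else:
            halted = halted or q_halt > q_continue
        # Exploration: with probability exploration_prob, require a random
        # minimum of 2..halt_max_steps steps (inclusive) before halting is allowed
        min_halt_steps = rng.randint(2, halt_max_steps) if rng.random() < exploration_prob else 0
        halted = halted and step >= min_halt_steps
    return step, halted


def target_q_continue(next_q_halt: float, next_q_continue: float, is_last_step: bool) -> float:
    """Bootstrapped target for the 'continue' Q-value, in probability space."""
    # On the last step continuing is impossible, so only the halt value counts
    return sigmoid(next_q_halt if is_last_step else max(next_q_halt, next_q_continue))
```

Because `min_halt_steps` never exceeds `halt_max_steps`, exploration can delay halting but never blocks it at the final step.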


## Part 8: Dataset Common

In [8]:
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int
    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int
    total_groups: int
    mean_puzzle_examples: float
    total_puzzles: int
    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """8 dihedral symmetries by rotate, flip and mirror"""
    
    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)       # horizontal flip
    elif tid == 5:
        return np.flipud(arr)       # vertical flip
    elif tid == 6:
        return arr.T                # transpose (reflection along main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr
    
    
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
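The transform table and `DIHEDRAL_INVERSE` can be checked without NumPy. This is a sketch on plain nested lists, assuming the same `tid` numbering as `dihedral_transform`; `rot90`, `dihedral` and the other helpers are hypothetical re-implementations that mirror `np.rot90(arr, k=1)`, `np.fliplr`, `np.flipud` and `arr.T`. Unlike `dihedral_transform`, it raises on an unknown `tid` instead of returning the input unchanged.

```python
from typing import List

Grid = List[List[int]]

# tid -> inverse tid, same table as DIHEDRAL_INVERSE above
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


def transpose(g: Grid) -> Grid:
    return [list(col) for col in zip(*g)]


def rot90(g: Grid) -> Grid:
    # Counter-clockwise quarter turn, matching np.rot90(arr, k=1)
    return transpose(g)[::-1]


def fliplr(g: Grid) -> Grid:
    return [row[::-1] for row in g]


def flipud(g: Grid) -> Grid:
    return [row[:] for row in g[::-1]]


def dihedral(g: Grid, tid: int) -> Grid:
    transforms = [
        lambda x: [row[:] for row in x],   # 0: identity
        rot90,                              # 1: rotate 90
        lambda x: rot90(rot90(x)),          # 2: rotate 180
        lambda x: rot90(rot90(rot90(x))),   # 3: rotate 270
        fliplr,                             # 4: horizontal flip
        flipud,                             # 5: vertical flip
        transpose,                          # 6: main-diagonal reflection
        lambda x: fliplr(rot90(x)),         # 7: anti-diagonal reflection
    ]
    if not 0 <= tid < len(transforms):
        raise ValueError(f"unknown dihedral tid {tid}")
    return transforms[tid](g)
```

Applying `dihedral(dihedral(g, tid), DIHEDRAL_INVERSE[tid])` returns `g` for every `tid`, including non-square grids: only the quarter turns (1 and 3) are each other's inverses, and every other element is its own inverse.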


## Part 9: Puzzle Dataset (Non-distributed)

In [9]:
import os
import json
from typing import Tuple, List, Dict, Optional
import numpy as np
import pydantic

import torch
from torch.utils.data import IterableDataset, get_worker_info

# IGNORE_LABEL_ID defined in Part 6
# PuzzleDatasetMetadata defined in Part 8

from pydantic import BaseModel

def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
    # Pack examples into a full batch
    batch = []
    batch_puzzle_indices = []
    current_size = 0

    while (start_index < group_order.size) and (current_size < global_batch_size):
        # Pick a group and a puzzle from that group
        group_id = group_order[start_index]
        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Get range of the puzzle
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)

        append_size = min(puzzle_size, global_batch_size - current_size)

        # Put into batch
        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
        # NOTE: example selection uses NumPy's global RNG, not the seeded `rng` passed in
        batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))

        current_size += append_size

    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)


class PuzzleDatasetConfig(pydantic.BaseModel):
    seed: int
    dataset_paths: List[str]
    global_batch_size: int
    test_set_mode: bool
    epochs_per_iter: int  # Batch X epochs in an iteration to reduce overhead.
    rank: int = 0  # Single GPU
    num_replicas: int = 1  # Single GPU


class PuzzleDataset(IterableDataset):
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        super().__init__()
        self.config = config
        self.split = split

        # Merge multiple metadata
        prev_seq_len = None
        prev_vocab_size = None
        prev_pad_id = None
        prev_ignore_label_id = None
        prev_blank_identifier_id = None
        prev_sets = None
        prev_num_identifiers = None
        mean_puzzle_examples = 0
        total_puzzles = 0
        total_groups = 0
        num_identifiers = 0
        for dataset_path in config.dataset_paths:
            current_metadata = self._load_metadata(dataset_path)
            if prev_seq_len is None:
                prev_seq_len = current_metadata.seq_len
                prev_vocab_size = current_metadata.vocab_size
                prev_pad_id = current_metadata.pad_id
                prev_ignore_label_id = current_metadata.ignore_label_id
                prev_blank_identifier_id = current_metadata.blank_identifier_id
                prev_sets = current_metadata.sets
                prev_num_identifiers = current_metadata.num_puzzle_identifiers
            else:
                assert prev_seq_len == current_metadata.seq_len
                assert prev_vocab_size == current_metadata.vocab_size
                assert prev_pad_id == current_metadata.pad_id
                assert prev_ignore_label_id == current_metadata.ignore_label_id
                assert prev_blank_identifier_id == current_metadata.blank_identifier_id
                assert prev_sets == current_metadata.sets
                assert prev_num_identifiers == current_metadata.num_puzzle_identifiers
            mean_puzzle_examples += current_metadata.mean_puzzle_examples*current_metadata.total_puzzles
            total_puzzles += current_metadata.total_puzzles
            total_groups += current_metadata.total_groups
            num_identifiers += current_metadata.num_puzzle_identifiers
        mean_puzzle_examples = mean_puzzle_examples / total_puzzles

        self.metadata = PuzzleDatasetMetadata(
            seq_len=prev_seq_len,
            vocab_size=prev_vocab_size,
            pad_id=prev_pad_id,
            ignore_label_id=prev_ignore_label_id,
            blank_identifier_id=prev_blank_identifier_id,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=mean_puzzle_examples,
            total_puzzles=total_puzzles,
            sets=prev_sets
        )

        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be a multiple of num_replicas ({self.config.num_replicas})."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None
        self._iters = 0

    def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
        with open(os.path.join(dataset_path, self.split, "dataset.json"), "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        self._data = {}
        for set_name in self.metadata.sets: # Load subset
            for i, dataset_path in enumerate(self.config.dataset_paths):
                if i > 0:
                    set_name_ = set_name + str(i)
                else:
                    set_name_ = set_name
                self._data[set_name_] = {
                    field_name: np.load(os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                    for field_name, mmap_mode in field_mmap_modes.items()
                }


    def _collate_batch(self, batch):
        # Convert dtype
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,
                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}
    
    def _iter_test(self):
        for set_i, (set_name, dataset) in enumerate(self._data.items()):  # type: ignore
            total_examples = len(dataset["inputs"])

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                
                local_start = start_index + 0  # Single GPU
                local_end   = min(start_index + self.local_batch_size, end_index)  # Single GPU
                
                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1

                    puzzle_indices.append(puzzle_index)
                
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index
                
                # Advance to next batch
                start_index += self.config.global_batch_size

    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0
            
            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices        = batch_indices       [0: self.local_batch_size]  # Single GPU
                batch_puzzle_indices = batch_puzzle_indices[0: self.local_batch_size]  # Single GPU
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size
                
    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multi-worker data loading (num_workers > 1) is not currently supported."
        
        self._lazy_load_dataset()
        
        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()
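`_sample_batch` treats `group_indices` and `puzzle_indices` as offset arrays: puzzles `group_indices[g]` up to (but not including) `group_indices[g + 1]` belong to group `g`, and examples `puzzle_indices[p]` up to `puzzle_indices[p + 1]` belong to puzzle `p`. A pure-Python sketch of the same packing loop, on a hypothetical toy layout, shows how one random puzzle is drawn per group until the batch is full. It uses the seeded `random.Random` for both draws, whereas the notebook draws examples with NumPy's global RNG; `sample_batch` is a hypothetical name.

```python
import random
from typing import List, Tuple


def sample_batch(rng: random.Random, group_order: List[int], puzzle_indices: List[int],
                 group_indices: List[int], start_index: int,
                 batch_size: int) -> Tuple[int, List[int], List[int]]:
    """Hypothetical pure-Python mirror of _sample_batch."""
    batch: List[int] = []
    batch_puzzles: List[int] = []
    size = 0
    while start_index < len(group_order) and size < batch_size:
        # One random puzzle from the next group in the shuffled order
        group_id = group_order[start_index]
        puzzle_id = rng.randrange(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Examples of this puzzle occupy [puzzle_start, puzzle_start + puzzle_size)
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = puzzle_indices[puzzle_id + 1] - puzzle_start
        take = min(puzzle_size, batch_size - size)  # truncate so the batch is not overfilled

        batch_puzzles += [puzzle_id] * take
        batch += [puzzle_start + i for i in rng.sample(range(puzzle_size), take)]
        size += take
    return start_index, batch, batch_puzzles
```

With `puzzle_indices = [0, 3, 5, 9]`, one puzzle per group, `group_order = [2, 0, 1]` and `batch_size = 4`, puzzle 2 (four examples) fills the first batch alone; the second batch takes all three examples of puzzle 0 plus one of puzzle 1's two examples.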



## Part 10: EMA Helper

In [10]:
import copy
import torch.nn as nn

class EMAHelper(object):
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        module_copy = copy.deepcopy(module)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict
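`EMAHelper.update` applies the recurrence `shadow = mu * shadow + (1 - mu) * param`. With a constant parameter `p`, the shadow converges geometrically: after `n` updates it equals `p + mu**n * (shadow_0 - p)`, so `mu = 0.999` averages over roughly `1 / (1 - mu) = 1000` steps. A scalar sketch (the `ema_update` name is hypothetical) checks the closed form:

```python
from typing import Dict


def ema_update(shadow: Dict[str, float], params: Dict[str, float], mu: float) -> Dict[str, float]:
    """Same arithmetic as EMAHelper.update, on plain floats."""
    return {name: (1.0 - mu) * params[name] + mu * shadow[name] for name in shadow}


shadow = {"w": 0.0}
params = {"w": 1.0}
for _ in range(10):
    shadow = ema_update(shadow, params, mu=0.9)
print(shadow["w"])  # approximately 0.6513, i.e. 1 - 0.9**10
```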



## Part 11: Utils Functions

In [11]:
import importlib
import inspect


# Class registry for notebook environment
_MODEL_CLASS_REGISTRY = {}

def register_model_class(identifier: str, cls):
    """Register a model class for notebook environment."""
    _MODEL_CLASS_REGISTRY[identifier] = cls

def load_model_class(identifier: str, prefix: str = "models."):
    """Load model class from identifier. Works in notebook environment."""
    # Check registry first
    if identifier in _MODEL_CLASS_REGISTRY:
        return _MODEL_CLASS_REGISTRY[identifier]
    
    module_path, class_name = identifier.split('@')
    
    # Get from global namespace (notebook environment)
    import sys
    frame = sys._getframe(1)
    globals_dict = frame.f_globals
    
    # Try to find class in globals
    if class_name in globals_dict:
        cls = globals_dict[class_name]
        if isinstance(cls, type):
            return cls
    
    # Fallback: try importing (for evaluators)
    try:
        if prefix.startswith('evaluators.'):
            # For evaluators, try to import
            module = importlib.import_module(prefix + module_path)
            return getattr(module, class_name)
    except ImportError:
        pass
    
    raise ValueError(f'Class {class_name} not found. Make sure all cells are executed in order.')

def get_model_source_path(identifier: str, prefix: str = "models."):
    """Get source path. In notebook, return None."""
    # In notebook environment, we don't have source files
    return None
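`load_model_class` resolves an identifier in two steps: an explicit registry lookup first, then the class name after `@` in the caller's globals (found via `sys._getframe(1)`). A minimal sketch of the same lookup order, passing the namespace explicitly instead of inspecting the call stack; `resolve_class`, `register`, `Dummy` and the `pkg.module@Dummy` identifier are all hypothetical:

```python
from typing import Any, Dict

_REGISTRY: Dict[str, type] = {}


def register(identifier: str, cls: type) -> None:
    _REGISTRY[identifier] = cls


def resolve_class(identifier: str, namespace: Dict[str, Any]) -> type:
    # 1) Explicit registrations win
    if identifier in _REGISTRY:
        return _REGISTRY[identifier]
    # 2) Otherwise treat the part after '@' as a class name in the given namespace
    _module_path, class_name = identifier.split("@")
    cls = namespace.get(class_name)
    if isinstance(cls, type):
        return cls
    raise ValueError(f"Class {class_name} not found")


class Dummy:
    pass
```

In the notebook itself, calling `register_model_class(config.arch.name, TinyRecursiveReasoningModel_ACTV1)` before `create_model` takes the registry path and skips the globals lookup entirely.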


## Part 12: Training Framework (Non-distributed, H100 Optimized)

In [12]:
# ============================================================================
# Part 12: Training Framework (Non-distributed, H100 Optimized)
# ============================================================================

# Configuration classes
class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig

class EvaluatorConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="allow")
    name: str

class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_paths: List[str]
    data_paths_test: List[str] = []
    # Evaluators
    evaluators: List[EvaluatorConfig] = []

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    load_checkpoint: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    min_eval_interval: Optional[int] = 0
    eval_save_outputs: List[str] = []
    max_eval_batches: Optional[int] = None

    ema: bool = False
    ema_rate: float = 0.999
    freeze_weights: bool = False

@dataclass
class TrainState:
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    carry: Any

    step: int
    total_steps: int

def create_dataloader(config: PretrainConfig, split: str, rank: int = 0, world_size: int = 1, **kwargs):
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,
        dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata

def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int = 0, world_size: int = 1):
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=config.global_batch_size // world_size,
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        print(model)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model)  # type: ignore

        # Load checkpoint
        if rank == 0:
            load_checkpoint(model, config)

    # Optimizers and lr (using AdamW instead of AdamAtan2)
    if config.arch.puzzle_emb_ndim == 0:
        optimizers = [
            torch.optim.AdamW(
                model.parameters(),
                lr=0,  # Needs to be set by scheduler
                weight_decay=config.weight_decay,
                betas=(config.beta1, config.beta2),
                eps=1e-8
            )
        ]
        optimizer_lrs = [config.lr]
    elif config.freeze_weights:
        # For frozen weights, we still need an optimizer for puzzle_emb
        # CastedSparseEmbeddingSignSGD_Distributed is defined in Part 5 (same cell)
        optimizers = [
            CastedSparseEmbeddingSignSGD_Distributed(
                model.model.puzzle_emb.buffers(),  # type: ignore
                lr=0,
                weight_decay=config.puzzle_emb_weight_decay,
                world_size=world_size
            )
        ]
        optimizer_lrs = [config.puzzle_emb_lr]
    else:
        # CastedSparseEmbeddingSignSGD_Distributed is defined in Part 5 (same cell)
        optimizers = [
            CastedSparseEmbeddingSignSGD_Distributed(
                model.model.puzzle_emb.buffers(),  # type: ignore
                lr=0,
                weight_decay=config.puzzle_emb_weight_decay,
                world_size=world_size
            ),
            torch.optim.AdamW(
                model.parameters(),
                lr=0,
                weight_decay=config.weight_decay,
                betas=(config.beta1, config.beta2),
                eps=1e-8
            )
        ]
        optimizer_lrs = [config.puzzle_emb_lr, config.lr]

    return model, optimizers, optimizer_lrs

def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))

def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int = 0, world_size: int = 1):
    # Estimated total training steps
    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

    # Model
    model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)

    return TrainState(
        step=0,
        total_steps=total_steps,
        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None
    )

def save_train_state(config: PretrainConfig, train_state: TrainState):
    if config.checkpoint_path is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)
    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))

def load_checkpoint(model: nn.Module, config: PretrainConfig):
    if config.load_checkpoint is not None:
        print(f"Loading checkpoint {config.load_checkpoint}")
        state_dict = torch.load(config.load_checkpoint, map_location="cuda")
        
        # Resize and reset puzzle emb if needed
        puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
        expected_shape: torch.Size = model.model.puzzle_emb.weights.shape  # type: ignore
        if puzzle_emb_name in state_dict:
            puzzle_emb = state_dict[puzzle_emb_name]
            if puzzle_emb.shape != expected_shape:
                print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}")
                state_dict[puzzle_emb_name] = (
                    torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
                )
        model.load_state_dict(state_dict, assign=True)

def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    return cosine_schedule_with_warmup_lr_lambda(
        current_step=train_state.step,
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio
    )

def compute_grad_norm(model: nn.Module) -> float:
    """
    Compute the total L2 norm of all parameter gradients (monitoring only).
    Per-parameter norms are reduced on the device, so only one host sync is needed.
    """
    grad_norms = [
        torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32)
        for p in model.parameters()
        if p.grad is not None
    ]
    if not grad_norms:
        return 0.0
    return torch.linalg.vector_norm(torch.stack(grad_norms), 2).item()

def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int = 0, world_size: int = 1):
    train_state.step += 1
    if train_state.step > train_state.total_steps:
        return
    
    # To device
    batch = {k: v.cuda() for k, v in batch.items()}
    
    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore
    
    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
    
    # Skip the update if the loss is NaN or Inf
    loss_value = loss.item() if isinstance(loss, torch.Tensor) else float(loss)
    if not math.isfinite(loss_value):
        print(f"⚠️ WARNING: Step {train_state.step} - Loss is NaN or Inf: {loss_value}")
        return None
    
    ((1 / global_batch_size) * loss).backward()
    
    # Gradient all-reduce is skipped: single-GPU training
    
    # Compute gradient norm for monitoring (no clipping, just recording)
    grad_norm = compute_grad_norm(train_state.model)
    
    # Apply optimizer
    lr_this_step = None    
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for param_group in optim.param_groups:
            param_group['lr'] = lr_this_step
            
        optim.step()
        optim.zero_grad()
    
    # Reduce metrics (single GPU: no all-reduce needed, this process is rank 0)
    if len(metrics):
        assert not any(v.requires_grad for v in metrics.values())
        metric_keys = list(sorted(metrics.keys()))
        metric_values = torch.stack([metrics[k] for k in metric_keys]).cpu().numpy()
        reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}

        # Postprocess: divide losses by the global batch size and other metrics by count
        count = max(reduced_metrics["count"], 1)
        reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
        reduced_metrics["train/lr"] = lr_this_step

        # Add gradient norm monitoring
        if grad_norm is not None:
            reduced_metrics["train/grad_norm"] = float(grad_norm)

        # Note: GPU/system stats are logged automatically by wandb (_disable_stats=False)
        return reduced_metrics

def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
    data_paths = config.data_paths_test if len(config.data_paths_test)>0 else config.data_paths
    evaluators = []
    for cfg in config.evaluators:
        for data_path in data_paths:
            cls = load_model_class(cfg.name, "evaluators.")(
                data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__
            )  # type: ignore
            evaluators.append(cls)
    return evaluators

def evaluate(
    config: PretrainConfig,
    train_state: TrainState,
    eval_loader: torch.utils.data.DataLoader,
    eval_metadata: PuzzleDatasetMetadata,
    evaluators: List[Any],
    rank: int = 0,
    world_size: int = 1,
    cpu_group: Optional[Any] = None,
):
    reduced_metrics = None

    with torch.inference_mode():
        return_keys = set(config.eval_save_outputs)
        for evaluator in evaluators:
            evaluator.begin_eval()
            return_keys.update(evaluator.required_outputs)
        
        # Run evaluation
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
        save_preds = {}
        metric_keys = []
        metric_values = None
        carry = None
        processed_batches = 0
        
        for set_name, batch, global_batch_size in eval_loader:
            if config.max_eval_batches is not None and processed_batches >= config.max_eval_batches:
                break
            processed_batches += 1
            if rank == 0:
                print(f"Processing batch {processed_batches}: {set_name}")
            
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward
            inference_steps = 0
            while True:
                carry, loss, metrics, preds, all_finish = train_state.model(
                    carry=carry, batch=batch, return_keys=return_keys
                )
                inference_steps += 1
                if all_finish:
                    break

            if rank == 0:
                print(f"  Completed inference in {inference_steps} steps")

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        save_preds.setdefault(k, [])
                        save_preds[k].append(v.cpu())

            for evaluator in evaluators:
                evaluator.update_batch(batch, preds)

            del carry, loss, preds, batch, all_finish

            # Aggregate metrics
            set_id = set_ids[set_name]
            if metric_values is None:
                metric_keys = list(sorted(metrics.keys()))
                metric_values = torch.zeros(
                    (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda"
                )
            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
            del metrics

        # Concatenate save preds
        save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}

        # Save preds
        if config.checkpoint_path is not None and len(save_preds):
            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(
                save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
            )
        del save_preds

        # Reduce to rank 0 (single GPU: no all-reduce needed)
        if metric_values is not None:
            reduced_metrics = metric_values.cpu().numpy()
            reduced_metrics = {
                set_name: {
                    metric_name: reduced_metrics[set_id, metric_id]
                    for metric_id, metric_name in enumerate(metric_keys)
                }
                for set_id, set_name in enumerate(set_ids)
            }
            # Postprocess
            for set_name, m in reduced_metrics.items():
                count = m.pop("count")
                reduced_metrics[set_name] = {k: v / count for k, v in m.items()}

        # Run evaluators
        if rank == 0:
            print(f"\nRunning {len(evaluators)} evaluator(s)...")
        for i, evaluator in enumerate(evaluators):
            if rank == 0:
                print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")
            evaluator_save_path = None
            if config.checkpoint_path is not None:
                evaluator_save_path = os.path.join(
                    config.checkpoint_path,
                    f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
                )
                os.makedirs(evaluator_save_path, exist_ok=True)
            metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
            if rank == 0 and metrics is not None:
                if reduced_metrics is None:
                    reduced_metrics = {}
                reduced_metrics.update(metrics)
                print(f"  Completed {evaluator.__class__.__name__}")
        if rank == 0:
            print("All evaluators completed!")

    return reduced_metrics

def launch(config_dict: dict):
    """
    Launch training with a configuration dictionary.
    Single GPU, non-distributed version.
    """
    RANK = 0
    WORLD_SIZE = 1
    CPU_PROCESS_GROUP = None

    # Load config
    config = PretrainConfig(**config_dict)
    
    # Naming
    if config.project_name is None:
        config.project_name = "TRM-A100-Sudoku"
    if config.run_name is None:
        config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
    if config.checkpoint_path is None:
        config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

    # Seed RNGs
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter
    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(
        config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter,
        global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE
    )
    try:
        eval_loader, eval_metadata = create_dataloader(
            config, "test", test_set_mode=True, epochs_per_iter=1,
            global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE
        )
    except Exception as e:
        print(f"NO EVAL DATA FOUND ({e})")
        eval_loader = eval_metadata = None

    try:
        evaluators = create_evaluators(config, eval_metadata)
    except Exception as e:
        print(f"No evaluator found ({e})")
        evaluators = []

    # Train state
    train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    ema_helper = None
    if RANK == 0:
        progress_bar = tqdm(total=train_state.total_steps)
        wandb.init(
            project=config.project_name,
            name=config.run_name,
            config=config.model_dump(),
            settings=wandb.Settings(_disable_stats=False)  # Enable automatic system stats
        )
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
    if config.ema:
        print('Setup EMA')
        # EMAHelper is defined in Part 10 (cell 10) of this notebook
        # No need to import from models.ema - it's already in the global namespace
        ema_helper = EMAHelper(mu=config.ema_rate)
        ema_helper.register(train_state.model)

    # Training Loop
    for _iter_id in range(total_iters):
        print(f"Epoch {_iter_id * train_epochs_per_iter}")

        # Train Iter
        if RANK == 0:
            print("TRAIN")
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore
            if config.ema:
                ema_helper.update(train_state.model)

        if _iter_id >= config.min_eval_interval:
            # Evaluation
            if RANK == 0:
                print("EVALUATE")
            if config.ema:
                print("SWITCH TO EMA")
                train_state_eval = copy.deepcopy(train_state)
                train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
            else:
                train_state_eval = train_state
            train_state_eval.model.eval()
            metrics = evaluate(
                config, train_state_eval, eval_loader, eval_metadata, evaluators,
                rank=RANK, world_size=WORLD_SIZE, cpu_group=CPU_PROCESS_GROUP
            )
            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
            
            # Checkpointing (only announce a save when one actually happens)
            if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
                print("SAVE CHECKPOINT")
                save_train_state(config, train_state_eval)
            if config.ema:
                del train_state_eval

    wandb.finish()
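In `evaluate`, each batch adds its metrics into one row of `metric_values` per evaluation set, and the final step divides every entry by that set's summed `count`. Assuming the per-batch metrics are sums (which is what makes this division produce averages), the normalization can be sketched in plain Python. `normalize_metric_sums` is a hypothetical helper and the numbers are made up for illustration:

```python
from typing import Dict, List

def normalize_metric_sums(
    sums: List[List[float]], metric_keys: List[str], set_names: List[str]
) -> Dict[str, Dict[str, float]]:
    """Turn per-set metric sums into per-set averages.

    sums[set_id][metric_id] plays the role of metric_values in evaluate():
    one row per evaluation set, one column per metric, columns ordered like
    metric_keys. The "count" column holds the number of contributing examples.
    """
    result = {}
    for set_id, set_name in enumerate(set_names):
        row = {name: sums[set_id][i] for i, name in enumerate(metric_keys)}
        count = row.pop("count")
        # Guard against an empty set; the notebook's evaluate() does not guard this
        result[set_name] = {k: (v / count if count else float("nan")) for k, v in row.items()}
    return result

# Illustrative numbers only: 512 examples, 450 token-level hits, 400 fully solved
metric_keys = sorted(["accuracy", "count", "exact_accuracy"])
print(normalize_metric_sums([[450.0, 512.0, 400.0]], metric_keys, ["all"]))
```

Summing first and dividing once at the end keeps the result correct even when the last batch is smaller than the others.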

## Part 13: Example Usage & Dataset Preparation

### 📦 Dataset Preparation

Before training, prepare the dataset. For the **Sudoku-Extreme** dataset:

```bash
# Build Sudoku dataset (1000 training puzzles, 1000 augmentations each)
python TinyRecursiveModels/dataset/build_sudoku_dataset.py \
  --output-dir data/sudoku-extreme-1k-aug-1000 \
  --subsample-size 1000 \
  --num-aug 1000
```

This creates the dataset in the `data/sudoku-extreme-1k-aug-1000/` directory.

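**Note:** Make sure the `TinyRecursiveModels` folder is in your workspace, or adjust the path accordingly. Alternatively, the dataset-build cell further down downloads `sapientinc/sudoku-extreme` from the Hugging Face Hub and writes the same directory without needing the repository.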
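The in-notebook build cell writes one `dataset.json` plus one `all__<field>.npy` file per array into each split directory, and an `identifiers.json` at the top level. A small helper to check that a build finished before calling `launch` might look like this; `expected_dataset_files` and `missing_dataset_files` are hypothetical names, and the file list is taken from that build cell (assumed, not verified, to match the script's output):

```python
import os
from typing import List, Sequence

# Arrays saved by the notebook's build cell as all__<field>.npy
FIELDS = ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]

def expected_dataset_files(dataset_dir: str, splits: Sequence[str] = ("train", "test")) -> List[str]:
    """List the files a finished Sudoku dataset build should contain."""
    paths = [os.path.join(dataset_dir, "identifiers.json")]
    for split in splits:
        split_dir = os.path.join(dataset_dir, split)
        paths.append(os.path.join(split_dir, "dataset.json"))
        paths.extend(os.path.join(split_dir, f"all__{field}.npy") for field in FIELDS)
    return paths

def missing_dataset_files(dataset_dir: str) -> List[str]:
    """Return the expected files that do not exist yet (empty list means ready)."""
    return [p for p in expected_dataset_files(dataset_dir) if not os.path.exists(p)]

missing = missing_dataset_files("data/sudoku-extreme-1k-aug-1000")
print("Dataset ready" if not missing else f"{len(missing)} file(s) missing, e.g. {missing[0]}")
```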

### ⚙️ Configuration Options
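The `lr`, `lr_warmup_steps` and `lr_min_ratio` keys in the configurations below drive the learning-rate schedule. The scheduler itself lives in an earlier part of the notebook that is not shown here; the sketch assumes it follows the cosine-with-warmup shape of the original TRM code: a linear ramp to `lr`, then a cosine decay toward `lr * lr_min_ratio`. With `lr_min_ratio: 1.0`, as in every config below, the rate simply stays at `lr` after warmup.

```python
import math

def cosine_with_warmup_lr(
    step: int, base_lr: float, warmup_steps: int, total_steps: int,
    min_ratio: float = 0.0, num_cycles: float = 0.5,
) -> float:
    """Linear warmup to base_lr, then cosine decay toward base_lr * min_ratio.

    Assumed to mirror the original TRM scheduler; check the notebook's own
    implementation before relying on exact values.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    return base_lr * (min_ratio + max(0.0, (1.0 - min_ratio) * cosine))

# lr_min_ratio = 1.0: warmup over 2000 steps, then a constant learning rate
for step in (0, 1000, 2000, 50000):
    print(step, cosine_with_warmup_lr(step, 1e-4, 2000, 100000, min_ratio=1.0))
```

Setting `lr_min_ratio` below 1.0 would turn the flat segment into a decay that ends at `lr * lr_min_ratio` on the last step.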

Below are three pre-configured setups ready to use:

In [None]:
# ============================================================================
# Example Configurations
# ============================================================================

# ----------------------------------------------------------------------------
# Configuration 1: ARC-AGI Dataset (Original)
# ----------------------------------------------------------------------------
arc_config = {
    'arch': {
        'name': 'recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1',
        'loss': {
            'name': 'losses@ACTLossHead',
            'loss_type': 'stablemax_cross_entropy'
        },
        'halt_exploration_prob': 0.1,
        'halt_max_steps': 16,
        'H_cycles': 3,
        'L_cycles': 6,
        'H_layers': 0,
        'L_layers': 2,
        'hidden_size': 512,
        'num_heads': 8,
        'expansion': 4,
        'puzzle_emb_ndim': 512,
        'pos_encodings': 'rope',
        'forward_dtype': 'bfloat16',
        'mlp_t': False,
        'puzzle_emb_len': 16,
        'no_ACT_continue': True
    },
    'data_paths': ['data/arc1concept-aug-1000'],
    'data_paths_test': [],
    'evaluators': [{'name': 'arc@ARC'}],
    'global_batch_size': 768,
    'epochs': 100000,
    'eval_interval': 10000,
    'checkpoint_every_eval': True,
    'lr': 1e-4,
    'lr_min_ratio': 1.0,
    'lr_warmup_steps': 2000,
    'beta1': 0.9,
    'beta2': 0.95,
    'weight_decay': 0.1,
    'puzzle_emb_weight_decay': 0.1,
    'puzzle_emb_lr': 1e-2,
    'seed': 0,
    'min_eval_interval': 0,
    'max_eval_batches': 1000,  # Cap on eval batches (up to 768,000 examples with batch_size=768)
    'ema': False,
    'ema_rate': 0.999,
    'freeze_weights': False
}

# ----------------------------------------------------------------------------
# Configuration 2: Sudoku-Extreme Dataset (Attention Version) ⭐ RECOMMENDED
# Based on: pretrain_att_sudoku from README
# ----------------------------------------------------------------------------
sudoku_att_config = {
    'arch': {
        'name': 'recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1',
        'loss': {
            'name': 'losses@ACTLossHead',
            'loss_type': 'stablemax_cross_entropy'
        },
        'halt_exploration_prob': 0.1,
        'halt_max_steps': 16,
        'H_cycles': 3,
        'L_cycles': 6,
        'H_layers': 0,
        'L_layers': 2,
        'hidden_size': 512,
        'num_heads': 8,
        'expansion': 4,
        'puzzle_emb_ndim': 512,
        'pos_encodings': 'rope',  # Use RoPE for attention
        'forward_dtype': 'bfloat16',
        'mlp_t': False,  # Use attention (transformer)
        'puzzle_emb_len': 16,
        'no_ACT_continue': True
    },
    'data_paths': ['data/sudoku-extreme-1k-aug-1000'],
    'data_paths_test': [],
    'evaluators': [],  # No evaluator for Sudoku
    'global_batch_size': 768,
    'epochs': 50000,
    'eval_interval': 5000,
    'checkpoint_every_eval': True,
    'lr': 1e-4,
    'lr_min_ratio': 1.0,
    'lr_warmup_steps': 2000,
    'beta1': 0.9,
    'beta2': 0.95,
    'weight_decay': 1.0,  # Higher weight decay for Sudoku
    'puzzle_emb_weight_decay': 1.0,  # Higher weight decay for Sudoku
    'puzzle_emb_lr': 1e-4,  # Same as main lr for Sudoku
    'seed': 0,
    'min_eval_interval': 0,
    'max_eval_batches': 1000,  # Cap on eval batches (up to 768,000 examples; the 422,786-puzzle test set needs only 551)
    'ema': True,  # Use EMA for Sudoku
    'ema_rate': 0.999,
    'freeze_weights': False,
    'project_name': 'TRM-A100-Sudoku'  # Wandb project name
}

# ----------------------------------------------------------------------------
# Configuration 3: Sudoku-Extreme Dataset (MLP Version)
# Based on: pretrain_mlp_t_sudoku from README
# ----------------------------------------------------------------------------
sudoku_mlp_config = {
    'arch': {
        'name': 'recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1',
        'loss': {
            'name': 'losses@ACTLossHead',
            'loss_type': 'stablemax_cross_entropy'
        },
        'halt_exploration_prob': 0.1,
        'halt_max_steps': 16,
        'H_cycles': 3,
        'L_cycles': 6,
        'H_layers': 0,
        'L_layers': 2,
        'hidden_size': 512,
        'num_heads': 8,
        'expansion': 4,
        'puzzle_emb_ndim': 512,
        'pos_encodings': 'none',  # No positional encoding for MLP
        'forward_dtype': 'bfloat16',
        'mlp_t': True,  # Use MLP instead of transformer
        'puzzle_emb_len': 16,
        'no_ACT_continue': True
    },
    'data_paths': ['data/sudoku-extreme-1k-aug-1000'],
    'data_paths_test': [],
    'evaluators': [],  # No evaluator for Sudoku
    'global_batch_size': 768,
    'epochs': 50000,
    'eval_interval': 5000,
    'checkpoint_every_eval': True,
    'lr': 1e-4,
    'lr_min_ratio': 1.0,
    'lr_warmup_steps': 2000,
    'beta1': 0.9,
    'beta2': 0.95,
    'weight_decay': 1.0,  # Higher weight decay for Sudoku
    'puzzle_emb_weight_decay': 1.0,  # Higher weight decay for Sudoku
    'puzzle_emb_lr': 1e-4,  # Same as main lr for Sudoku
    'seed': 0,
    'min_eval_interval': 0,
    'max_eval_batches': 1000,  # Cap on eval batches (up to 768,000 examples; the 422,786-puzzle test set needs only 551)
    'ema': True,  # Use EMA for Sudoku
    'ema_rate': 0.999,
    'freeze_weights': False,
    'project_name': 'TRM-A100-Sudoku'  # Wandb project name
}

# ----------------------------------------------------------------------------
# Select Configuration
# ----------------------------------------------------------------------------
# Choose one of the configurations above:
# - arc_config: For ARC-AGI dataset
# - sudoku_att_config: For Sudoku with Attention (Transformer) ⭐ RECOMMENDED
# - sudoku_mlp_config: For Sudoku with MLP

# Default to Sudoku Attention version
example_config = sudoku_att_config

print("="*70)
print("📋 Available Configurations:")
print("="*70)
print("1. arc_config - ARC-AGI dataset")
print("2. sudoku_att_config - Sudoku-Extreme with Attention (⭐ Recommended)")
print("3. sudoku_mlp_config - Sudoku-Extreme with MLP")
print("4. sudoku_test_config_A100 - Sudoku Minimal Test (🧪 Quick test: 5000 epochs, batch 256; defined in a later cell)")
print("="*70)
config_name = 'Sudoku (Attention)' if example_config == sudoku_att_config else 'Sudoku (MLP)' if example_config == sudoku_mlp_config else 'ARC-AGI'
print(f"✅ Current configuration: {config_name}")
print(f"📁 Data path: {example_config['data_paths']}")
print(f"🔄 Epochs: {example_config['epochs']}, Eval interval: {example_config['eval_interval']}")
print(f"📊 Batch size: {example_config['global_batch_size']}")
print(f"🎯 Architecture: {'MLP' if example_config['arch']['mlp_t'] else 'Attention (Transformer)'}")
print(f"💾 EMA: {'Enabled' if example_config['ema'] else 'Disabled'}")
print("🚀 To start training, call: launch(example_config)")
print("  Or use: launch(sudoku_att_config), launch(sudoku_mlp_config), or launch(sudoku_test_config_A100)")




📋 Available Configurations:
1. arc_config - ARC-AGI dataset
2. sudoku_att_config - Sudoku-Extreme with Attention (⭐ Recommended)
3. sudoku_mlp_config - Sudoku-Extreme with MLP
4. sudoku_test_config_A100 - Sudoku Minimal Test (🧪 Quick test: 5000 epochs, batch 256; defined in a later cell)
✅ Current configuration: Sudoku (Attention)
📁 Data path: ['data/sudoku-extreme-1k-aug-1000']
🔄 Epochs: 50000, Eval interval: 5000
📊 Batch size: 768
🎯 Architecture: Attention (Transformer)
💾 EMA: Enabled
🚀 To start training, call: launch(example_config)
  Or use: launch(sudoku_att_config), launch(sudoku_mlp_config), or launch(sudoku_test_config_A100)
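Both Sudoku configs set `ema: True` with `ema_rate: 0.999`. `EMAHelper` is defined in Part 10 of the notebook (not shown in this section); assuming it applies the usual update `shadow = mu * shadow + (1 - mu) * param` after every training step, the weights used for evaluation average roughly the last `1 / (1 - mu)` = 1000 steps. A framework-free sketch of that update on plain floats, with `ScalarEMA` as a hypothetical stand-in:

```python
from typing import Dict

class ScalarEMA:
    """Exponential moving average of named scalar 'parameters' (illustrative only)."""

    def __init__(self, mu: float = 0.999):
        self.mu = mu
        self.shadow: Dict[str, float] = {}

    def register(self, params: Dict[str, float]) -> None:
        # Start the shadow copy at the current values
        self.shadow = dict(params)

    def update(self, params: Dict[str, float]) -> None:
        # shadow <- mu * shadow + (1 - mu) * current value
        for name, value in params.items():
            self.shadow[name] = self.mu * self.shadow[name] + (1.0 - self.mu) * value

ema = ScalarEMA(mu=0.999)
ema.register({"w": 0.0})
for _ in range(1000):  # the live weight jumps to 1.0 and stays there
    ema.update({"w": 1.0})
print(round(ema.shadow["w"], 3))  # 1 - 0.999**1000, about 0.632
```

After 1000 steps the average has covered only about 63% of a sudden change, which is why the EMA weights trail the live weights early in a run.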


## Configuration 4: Sudoku-Extreme Minimal Test Configuration 

In [14]:
# ----------------------------------------------------------------------------
# Configuration 4: Sudoku-Extreme Minimal Test Configuration üß™
# For quick testing - reduced epochs and batch size
# ----------------------------------------------------------------------------
sudoku_test_config_A100 = {
    'arch': {
        'name': 'recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1',
        'loss': {
            'name': 'losses@ACTLossHead',
            'loss_type': 'stablemax_cross_entropy'
        },
        'halt_exploration_prob': 0.1,
        'halt_max_steps': 16,
        'H_cycles': 3,
        'L_cycles': 6,
        'H_layers': 0,
        'L_layers': 2,
        'hidden_size': 512,
        'num_heads': 8,
        'expansion': 4,
        'puzzle_emb_ndim': 512,
        'pos_encodings': 'rope',
        'forward_dtype': 'bfloat16',
        'mlp_t': False,
        'puzzle_emb_len': 16,
        'no_ACT_continue': True
    },
    'data_paths': ['data/sudoku-extreme-1k-aug-1000'],
    'data_paths_test': [],
    'evaluators': [],
    'global_batch_size': 256,  # Reduced for testing
    'epochs': 5000,  # Reduced from 50,000 for a quick test
    'eval_interval': 1000,  # Evaluate every 1000 epochs (5 evaluations in total)
    'checkpoint_every_eval': True,
    'lr': 2e-4,
    'lr_min_ratio': 1.0,
    'lr_warmup_steps': 100,  # Reduced warmup
    'beta1': 0.9,
    'beta2': 0.95,
    'weight_decay': 1.0,
    'puzzle_emb_weight_decay': 1.0,
    'puzzle_emb_lr': 1e-4,
    'seed': 0,
    'min_eval_interval': 0,
    'max_eval_batches': 100,  # Limit to 100 batches for faster evaluation (~25,600 examples with batch_size=256)
    'ema': True,
    'ema_rate': 0.999,
    'freeze_weights': False,
    'project_name': 'TRM-A100-Sudoku'  # Wandb project name
}

print("\n" + "="*70)
print("🧪 Minimal Test Configuration Available:")
print("="*70)
print("sudoku_test_config_A100 - Quick test with reduced epochs (5000) and batch size (256)")
print("To run test: launch(sudoku_test_config_A100)")
print("="*70)


🧪 Minimal Test Configuration Available:
sudoku_test_config_A100 - Quick test with reduced epochs (5000) and batch size (256)
To run test: launch(sudoku_test_config_A100)
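The progress-bar total in the test run further down (19,531 steps) is consistent with the original TRM step estimate `epochs * total_groups * mean_puzzle_examples / global_batch_size`, assumed here to be what `init_train_state` (defined earlier, not shown) computes. Because the Sudoku builder fixes `mean_puzzle_examples = 1`, one "epoch" amounts to roughly one example per puzzle group (1,000 groups), not all 1,001,000 augmented boards. The hypothetical `run_budget` helper below reproduces that arithmetic and, assuming `max_eval_batches` simply stops evaluation after that many batches, the evaluation size:

```python
import math

def run_budget(epochs: int, eval_interval: int, total_groups: int,
               mean_puzzle_examples: float, global_batch_size: int,
               num_test_examples: int, max_eval_batches: int) -> dict:
    """Estimate training steps and evaluation size for a config (illustrative helper)."""
    # Same divisibility rule that launch() asserts
    assert epochs % eval_interval == 0, "eval_interval must divide epochs"
    # Assumed to match the original TRM estimate behind the progress-bar total
    total_steps = int(epochs * total_groups * mean_puzzle_examples / global_batch_size)
    # Assumes evaluation stops after max_eval_batches batches
    eval_batches = min(math.ceil(num_test_examples / global_batch_size), max_eval_batches)
    return {
        "total_steps": total_steps,
        "train_iters": epochs // eval_interval,  # one evaluation per iteration when min_eval_interval = 0
        "eval_batches": eval_batches,
        "eval_examples": min(eval_batches * global_batch_size, num_test_examples),
    }

# Minimal test config: 1,000 training groups, 422,786 test puzzles (from the build log)
print(run_budget(5000, 1000, 1000, 1, 256, 422_786, 100))
# Full attention config
print(run_budget(50000, 5000, 1000, 1, 768, 422_786, 1000))
```

For the full Sudoku configs the 1000-batch cap is never reached: the test split fits in 551 batches of 768.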


## Build Dataset
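The build cell stores every augmented board as its own puzzle and groups each original puzzle with its augmentations. `puzzle_indices` and `group_indices` are offset arrays, much like CSR row pointers: puzzle `p` owns examples `puzzle_indices[p]:puzzle_indices[p+1]`, and group `g` owns puzzles `group_indices[g]:group_indices[g+1]`. A pure-Python sketch of the same bookkeeping for a toy case (2 puzzles, 2 augmentations each), with `build_index_arrays` as a hypothetical name:

```python
from typing import Dict, List

def build_index_arrays(num_puzzles: int, num_aug: int) -> Dict[str, List[int]]:
    """Mirror the index bookkeeping in convert_subset, without the board data."""
    puzzle_indices = [0]  # example offsets per puzzle
    group_indices = [0]   # puzzle offsets per group
    example_id = puzzle_id = 0
    for _ in range(num_puzzles):
        for _ in range(1 + num_aug):  # original board + num_aug augmented copies
            example_id += 1           # each puzzle holds exactly one example
            puzzle_id += 1
            puzzle_indices.append(example_id)
        group_indices.append(puzzle_id)  # close the group for this original puzzle
    return {"puzzle_indices": puzzle_indices, "group_indices": group_indices}

idx = build_index_arrays(num_puzzles=2, num_aug=2)
print(idx["puzzle_indices"])  # [0, 1, 2, 3, 4, 5, 6]
print(idx["group_indices"])   # [0, 3, 6]
```

For the real train split (1,000 puzzles, `NUM_AUG = 1000`) this gives 1,001,000 examples and `total_groups = len(group_indices) - 1 = 1000`, matching the build log.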

In [None]:
# Check and build dataset if needed
import os
import csv
import json
import numpy as np
from tqdm import tqdm
from huggingface_hub import hf_hub_download, login
import warnings

# Set Hugging Face Token for authentication
os.environ["HF_TOKEN"] = HF_TOKEN

# Login to Hugging Face Hub
try:
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("✅ Successfully authenticated with Hugging Face Hub")
except Exception as e:
    print(f"⚠️ Warning: Could not login to Hugging Face Hub: {e}")
    print("   Continuing with token in environment variable...")

DATASET_DIR = "data/sudoku-extreme-1k-aug-1000"
TRAIN_SUBSAMPLE_SIZE = 1000
NUM_AUG = 1000  # Augmentation count
MIN_DIFFICULTY = None  # Optional: filter by minimum difficulty rating

def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    """Apply equivalent transformations to Sudoku (preserves validity)"""
    # Digit mapping: random permutation of 1-9; index 0 (blank) stays 0 thanks to the padding
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
    
    # Random transpose
    transpose_flag = np.random.rand() < 0.5

    # Row permutation: shuffle 3 bands, then shuffle rows within each band
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Column permutation: same for columns
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build 81->81 position mapping
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        if transpose_flag:
            x = x.T
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)

def convert_subset(set_name: str):
    """Process train or test set"""
    print(f"\nüì• Processing {set_name} set...")

    # Download CSV from HuggingFace
    # Use HF_TOKEN from environment variable for authentication
    csv_path = hf_hub_download("sapientinc/sudoku-extreme", f"{set_name}.csv", repo_type="dataset", token=HF_TOKEN)

    # Read CSV
    inputs, labels = [], []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for source, q, a, rating in tqdm(reader, desc="Reading CSV"):
            # Filter by difficulty if specified (matching original implementation)
            if (MIN_DIFFICULTY is None) or (int(rating) >= MIN_DIFFICULTY):
                assert len(q) == 81 and len(a) == 81
                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    print(f"  Loaded {len(inputs)} puzzles")

    # Dataset subsampling (only for train set)
    if set_name == "train" and TRAIN_SUBSAMPLE_SIZE is not None and TRAIN_SUBSAMPLE_SIZE < len(inputs):
        indices = np.random.choice(len(inputs), size=TRAIN_SUBSAMPLE_SIZE, replace=False)
        inputs = [inputs[i] for i in indices]
        labels = [labels[i] for i in indices]
        print(f"  Subsampled to {len(inputs)} puzzles")

    # Data augmentation (only for train set)
    num_augments = NUM_AUG if set_name == "train" else 0

    # Build results
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in tqdm(zip(inputs, labels), total=len(inputs), desc="Augmenting"):
        for aug_idx in range(1 + num_augments):
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        results["group_indices"].append(puzzle_id)

    # Convert to NumPy arrays
    def seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)
        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1  # Offset +1, 0 reserved for PAD

    results = {
        "inputs": seq_to_numpy(results["inputs"]),
        "labels": seq_to_numpy(results["labels"]),
        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata (matching original implementation exactly)
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9" (matching original)
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,  # Fixed value as in original (even with augmentation)
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save
    save_dir = os.path.join(DATASET_DIR, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)  # No indent to match original

    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    print(f"  ✅ Saved to {save_dir}")
    print(f"  📊 Total examples: {results['inputs'].shape[0]}")
    
    return metadata

# Check if dataset exists
train_metadata_path = os.path.join(DATASET_DIR, "train", "dataset.json")
test_metadata_path = os.path.join(DATASET_DIR, "test", "dataset.json")

if os.path.exists(train_metadata_path) and os.path.exists(test_metadata_path):
    print("="*70)
    print("✅ Dataset already exists!")
    print(f"📁 Path: {DATASET_DIR}")
    print("="*70)
else:
    print("="*70)
    print("üì¶ Building Sudoku Dataset")
    print("="*70)
    print(f"Source: sapientinc/sudoku-extreme")
    print(f"Output: {DATASET_DIR}")
    print(f"Train subsample: {TRAIN_SUBSAMPLE_SIZE}")
    print(f"Augmentation: {NUM_AUG}")
    print("="*70)
    
    # Build train set
    train_metadata = convert_subset("train")
    
    # Build test set
    test_metadata = convert_subset("test")
    
    # Save identifiers.json
    with open(os.path.join(DATASET_DIR, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)
    
    print("\n" + "="*70)
    print("✅ Dataset build complete!")
    print("="*70)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✅ Successfully authenticated with Hugging Face Hub
📦 Building Sudoku Dataset
Source: sapientinc/sudoku-extreme
Output: data/sudoku-extreme-1k-aug-1000
Train subsample: 1000
Augmentation: 1000

📥 Processing train set...


train.csv:   0%|          | 0.00/719M [00:00<?, ?B/s]

Reading CSV: 3831994it [00:32, 118403.51it/s]


  Loaded 3831994 puzzles
  Subsampled to 1000 puzzles


Augmenting: 100%|██████████| 1000/1000 [02:15<00:00,  7.36it/s]


  ✅ Saved to data/sudoku-extreme-1k-aug-1000/train
  📊 Total examples: 1001000

📥 Processing test set...


test.csv:   0%|          | 0.00/79.4M [00:00<?, ?B/s]

Reading CSV: 422786it [00:03, 119913.88it/s]


  Loaded 422786 puzzles


Augmenting: 100%|██████████| 422786/422786 [00:00<00:00, 1836918.35it/s]


  ✅ Saved to data/sudoku-extreme-1k-aug-1000/test
  📊 Total examples: 422786

✅ Dataset build complete!
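The augmentations in `shuffle_sudoku` (digit relabeling, an optional transpose, and permuting the three bands/stacks plus the rows/columns inside each) all map a valid solution to another valid solution. The sketch below re-implements the same transformation family with plain lists and checks that property on a generated grid. It is an independent illustration, not the notebook's function; `augment` and `is_valid_solution` are hypothetical names, and unlike `shuffle_sudoku` (which draws one transformation and applies it to both board and solution so they stay aligned) it transforms a single grid.

```python
import random
from typing import List

Grid = List[List[int]]

def is_valid_solution(g: Grid) -> bool:
    """True if every row, column and 3x3 box contains the digits 1-9 exactly once."""
    digits = set(range(1, 10))
    rows = all(set(row) == digits for row in g)
    cols = all({g[r][c] for r in range(9)} == digits for c in range(9))
    boxes = all(
        {g[br + r][bc + c] for r in range(3) for c in range(3)} == digits
        for br in (0, 3, 6) for bc in (0, 3, 6)
    )
    return rows and cols and boxes

def augment(g: Grid, rng: random.Random) -> Grid:
    """Same transformation family as shuffle_sudoku, written with plain lists."""
    digit_map = [0] + rng.sample(range(1, 10), 9)  # 0 (blank) stays 0
    if rng.random() < 0.5:
        g = [list(col) for col in zip(*g)]          # transpose
    # Shuffle bands, then rows inside each band (a fresh permutation per band)
    row_perm = [b * 3 + r for b in rng.sample(range(3), 3) for r in rng.sample(range(3), 3)]
    col_perm = [s * 3 + c for s in rng.sample(range(3), 3) for c in rng.sample(range(3), 3)]
    return [[digit_map[g[row_perm[r]][col_perm[c]]] for c in range(9)] for r in range(9)]

# A valid solved grid from the standard shifted pattern
base = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
rng = random.Random(0)
print(is_valid_solution(base), all(is_valid_solution(augment(base, rng)) for _ in range(100)))
```

Permuting rows only within a band (and bands as whole units) is what keeps each 3x3 box intact; an arbitrary row shuffle would break the box constraint.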


In [16]:
# Run the minimal test configuration for quick testing
print("üöÄ Starting minimal test configuration...")
print("="*70)
print("üìä Test Configuration:")
print(f"  - Epochs: {sudoku_test_config_A100['epochs']} (quick test)")
print(f"  - Batch size: {sudoku_test_config_A100['global_batch_size']}")
print(f"  - Eval interval: {sudoku_test_config_A100['eval_interval']}")
print(f"  - Architecture: Attention (Transformer)")
print(f"  - EMA: {'Enabled' if sudoku_test_config_A100['ema'] else 'Disabled'}")
print("="*70)
print("üí° This is a quick test. For full training, use: launch(sudoku_att_config)")
print("="*70)
launch(sudoku_test_config_A100)

🚀 Starting minimal test configuration...
📊 Test Configuration:
  - Epochs: 5000 (quick test)
  - Batch size: 256
  - Eval interval: 1000
  - Architecture: Attention (Transformer)
  - EMA: Enabled
💡 This is a quick test. For full training, use: launch(sudoku_att_config)
TinyRecursiveReasoningModel_ACTV1(
  (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
    (embed_tokens): CastedEmbedding()
    (lm_head): CastedLinear()
    (q_head): CastedLinear()
    (puzzle_emb): CastedSparseEmbedding()
    (rotary_emb): RotaryEmbedding()
    (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
      (layers): ModuleList(
        (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
          (self_attn): Attention(
            (qkv_proj): CastedLinear()
            (o_proj): CastedLinear()
          )
          (mlp): SwiGLU(
            (gate_up_proj): CastedLinear()
            (down_proj): CastedLinear()
          )
        )
      )
    )
  )
)


  0%|          | 0/19531 [00:00<?, ?it/s]

Setup EMA
Epoch 0
TRAIN


 20%|█▉        | 3905/19531 [07:10<24:54, 10.46it/s] 

EVALUATE
SWITCH TO EMA
Processing batch 1: all


 20%|█▉        | 3906/19531 [07:21<24:54, 10.46it/s]

  Completed inference in 16 steps
Processing batch 2: all
  Completed inference in 16 steps
Processing batch 3: all
  Completed inference in 16 steps
Processing batch 4: all
  Completed inference in 16 steps
Processing batch 5: all
  Completed inference in 16 steps
Processing batch 6: all
  Completed inference in 16 steps
Processing batch 7: all
  Completed inference in 16 steps
Processing batch 8: all
  Completed inference in 16 steps
Processing batch 9: all
  Completed inference in 16 steps
Processing batch 10: all
  Completed inference in 16 steps
Processing batch 11: all
  Completed inference in 16 steps
Processing batch 12: all
  Completed inference in 16 steps
Processing batch 13: all
  Completed inference in 16 steps
Processing batch 14: all
  Completed inference in 16 steps
Processing batch 15: all
  Completed inference in 16 steps
Processing batch 16: all
  Completed inference in 16 steps
Processing batch 17: all
  Completed inference in 16 steps
Processing batch 18: all
  Com

 40%|███▉      | 7811/19531 [15:27<18:37, 10.49it/s]   

EVALUATE
SWITCH TO EMA
Processing batch 1: all
  Completed inference in 16 steps
Processing batch 2: all
  Completed inference in 16 steps
Processing batch 3: all
  Completed inference in 16 steps
Processing batch 4: all
  Completed inference in 16 steps
Processing batch 5: all
  Completed inference in 16 steps
Processing batch 6: all
  Completed inference in 16 steps
Processing batch 7: all
  Completed inference in 16 steps
Processing batch 8: all
  Completed inference in 16 steps
Processing batch 9: all
  Completed inference in 16 steps
Processing batch 10: all
  Completed inference in 16 steps
Processing batch 11: all
  Completed inference in 16 steps
Processing batch 12: all
  Completed inference in 16 steps
Processing batch 13: all


 40%|███▉      | 7812/19531 [15:39<18:37, 10.49it/s]

  Completed inference in 16 steps
Processing batch 14: all
  Completed inference in 16 steps
Processing batch 15: all
  Completed inference in 16 steps
Processing batch 16: all
  Completed inference in 16 steps
Processing batch 17: all
  Completed inference in 16 steps
Processing batch 18: all
  Completed inference in 16 steps
Processing batch 19: all
  Completed inference in 16 steps
Processing batch 20: all
  Completed inference in 16 steps
Processing batch 21: all
  Completed inference in 16 steps
Processing batch 22: all
  Completed inference in 16 steps
Processing batch 23: all
  Completed inference in 16 steps
Processing batch 24: all
  Completed inference in 16 steps
Processing batch 25: all
  Completed inference in 16 steps
Processing batch 26: all
  Completed inference in 16 steps
Processing batch 27: all
  Completed inference in 16 steps
Processing batch 28: all
  Completed inference in 16 steps
Processing batch 29: all
  Completed inference in 16 steps
Processing batch 30: a

 60%|█████▉    | 11717/19531 [23:10<12:25, 10.48it/s]  

EVALUATE
SWITCH TO EMA
Processing batch 1: all
  Completed inference in 16 steps
Processing batch 2: all
  Completed inference in 16 steps
Processing batch 3: all
  Completed inference in 16 steps
Processing batch 4: all
  Completed inference in 16 steps
Processing batch 5: all
  Completed inference in 16 steps
Processing batch 6: all
  Completed inference in 16 steps
Processing batch 7: all
  Completed inference in 16 steps
Processing batch 8: all
  Completed inference in 16 steps
Processing batch 9: all
  Completed inference in 16 steps
Processing batch 10: all
  Completed inference in 16 steps
Processing batch 11: all
  Completed inference in 16 steps
Processing batch 12: all
  Completed inference in 16 steps
Processing batch 13: all


 60%|█████▉    | 11718/19531 [23:21<12:25, 10.48it/s]

  Completed inference in 16 steps
Processing batch 14: all
  Completed inference in 16 steps
Processing batch 15: all
  Completed inference in 16 steps
Processing batch 16: all
  Completed inference in 16 steps
Processing batch 17: all
  Completed inference in 16 steps
Processing batch 18: all
  Completed inference in 16 steps
Processing batch 19: all
  Completed inference in 16 steps
Processing batch 20: all
  Completed inference in 16 steps
Processing batch 21: all
  Completed inference in 16 steps
Processing batch 22: all
  Completed inference in 16 steps
Processing batch 23: all
  Completed inference in 16 steps
Processing batch 24: all
  Completed inference in 16 steps
Processing batch 25: all
  Completed inference in 16 steps
Processing batch 26: all
  Completed inference in 16 steps
Processing batch 27: all
  Completed inference in 16 steps
Processing batch 28: all
  Completed inference in 16 steps
Processing batch 29: all
  Completed inference in 16 steps
Processing batch 30: a

 60%|██████    | 11719/19531 [24:39<29:03:38, 13.39s/it]

  Completed inference in 16 steps

Running 0 evaluator(s)...
All evaluators completed!
SAVE CHECKPOINT
Epoch 3000
TRAIN


 80%|███████▉  | 15623/19531 [30:53<06:15, 10.41it/s]   

EVALUATE
SWITCH TO EMA
Processing batch 1: all
  Completed inference in 16 steps
Processing batch 2: all
  Completed inference in 16 steps
Processing batch 3: all
  Completed inference in 16 steps
Processing batch 4: all
  Completed inference in 16 steps
Processing batch 5: all
  Completed inference in 16 steps
Processing batch 6: all
  Completed inference in 16 steps
Processing batch 7: all
  Completed inference in 16 steps
Processing batch 8: all
  Completed inference in 16 steps
Processing batch 9: all
  Completed inference in 16 steps
Processing batch 10: all
  Completed inference in 16 steps
Processing batch 11: all
  Completed inference in 16 steps
Processing batch 12: all
  Completed inference in 16 steps
Processing batch 13: all
  Completed inference in 16 steps
Processing batch 14: all
  Completed inference in 16 steps
Processing batch 15: all
  Completed inference in 16 steps
Processing batch 16: all
  Completed inference in 16 steps
Processing batch 17: all
  Completed inference in 16 steps

 80%|███████▉  | 15624/19531 [31:09<06:15, 10.41it/s]

  Completed inference in 16 steps
Processing batch 19: all
  Completed inference in 16 steps
Processing batch 20: all
  Completed inference in 16 steps
Processing batch 21: all
  Completed inference in 16 steps
Processing batch 22: all
  Completed inference in 16 steps
Processing batch 23: all
  Completed inference in 16 steps
Processing batch 24: all
  Completed inference in 16 steps
Processing batch 25: all
  Completed inference in 16 steps
Processing batch 26: all
  Completed inference in 16 steps
Processing batch 27: all
  Completed inference in 16 steps
Processing batch 28: all
  Completed inference in 16 steps
Processing batch 29: all
  Completed inference in 16 steps
Processing batch 30: all
  Completed inference in 16 steps
Processing batch 31: all
  Completed inference in 16 steps
Processing batch 32: all
  Completed inference in 16 steps
Processing batch 33: all
  Completed inference in 16 steps
Processing batch 34: all
  Completed inference in 16 steps
Processing batch 35: all

 80%|████████  | 15625/19531 [32:22<14:32:02, 13.40s/it]

  Completed inference in 16 steps

Running 0 evaluator(s)...
All evaluators completed!
SAVE CHECKPOINT
Epoch 4000
TRAIN


100%|█████████▉| 19529/19531 [38:35<00:00, 10.40it/s]

EVALUATE
SWITCH TO EMA
Processing batch 1: all
  Completed inference in 16 steps
Processing batch 2: all
  Completed inference in 16 steps
Processing batch 3: all
  Completed inference in 16 steps
Processing batch 4: all
  Completed inference in 16 steps
Processing batch 5: all
  Completed inference in 16 steps
Processing batch 6: all
  Completed inference in 16 steps
Processing batch 7: all
  Completed inference in 16 steps
Processing batch 8: all
  Completed inference in 16 steps
Processing batch 9: all
  Completed inference in 16 steps
Processing batch 10: all
  Completed inference in 16 steps
Processing batch 11: all
  Completed inference in 16 steps
Processing batch 12: all
  Completed inference in 16 steps
Processing batch 13: all
  Completed inference in 16 steps
Processing batch 14: all
  Completed inference in 16 steps
Processing batch 15: all
  Completed inference in 16 steps
Processing batch 16: all


100%|█████████▉| 19530/19531 [38:49<00:00, 10.40it/s]

  Completed inference in 16 steps
Processing batch 17: all
  Completed inference in 16 steps
Processing batch 18: all
  Completed inference in 16 steps
Processing batch 19: all
  Completed inference in 16 steps
Processing batch 20: all
  Completed inference in 16 steps
Processing batch 21: all
  Completed inference in 16 steps
Processing batch 22: all
  Completed inference in 16 steps
Processing batch 23: all
  Completed inference in 16 steps
Processing batch 24: all
  Completed inference in 16 steps
Processing batch 25: all
  Completed inference in 16 steps
Processing batch 26: all
  Completed inference in 16 steps
Processing batch 27: all
  Completed inference in 16 steps
Processing batch 28: all
  Completed inference in 16 steps
Processing batch 29: all
  Completed inference in 16 steps
Processing batch 30: all
  Completed inference in 16 steps
Processing batch 31: all
  Completed inference in 16 steps
Processing batch 32: all
  Completed inference in 16 steps
Processing batch 33: all

Run history:

| Metric | Trend |
|---|---|
| num_params | ▁ |
| train/accuracy | ▁▁▁▆█▆▆▆▇▆▆▇▆▆▆▇▆▆▆▆▆▆▆▆▆▇▆▇▇▆▇▆▆▇▆▆▇▆▆▇ |
| train/count | ▁▁▁▁█▁▁▁████████████████████████████████ |
| train/exact_accuracy | ▁▁▁▁▁▁▃▄▁▂▂▁▁▂▃▄▄▄▇▂▅▄▄▃▁▃▅█▃▅▆▄▆▆▇▆▅▆█▇ |
| train/grad_norm | ▅▂▁▁▅▁▂▂▃▃▂▃▂▂▄▃▄▁▃▁▂▂▃▃▁▄▂▂▃▁▂▁▁█▆▅▁▄▄▁ |
| train/lm_loss | █▄▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train/lr | ▁███████████████████████████████████████ |
| train/q_halt_accuracy | █▁▁▁████▇▇▇██████████▇██▇██████████▇████ |
| train/q_halt_loss | ▁▁▁▁▁▄▁▁▁▁▄▃▂▂▂▄▁▂▃▁▅▁▄▂▄▃▁▇▁▁▁▇█▅▁▅▅▄▂▂ |
| train/steps | ▁▁▁███▇█▇██████▇█▇▇▆▇▇▇▇▇█▇█▇▇█▇▇█▇▆▇▇▆▇ |

Run summary:

| Metric | Value |
|---|---|
| num_params | 6828034.0 |
| train/accuracy | 0.81191 |
| train/count | 1.0 |
| train/exact_accuracy | 0.35294 |
| train/grad_norm | 0.21474 |
| train/lm_loss | 0.74323 |
| train/lr | 0.0002 |
| train/q_halt_accuracy | 0.94118 |
| train/q_halt_loss | 0.0184 |
| train/steps | 12.70588 |


100%|█████████▉| 19530/19531 [40:06<00:00,  8.12it/s]

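Each evaluation pass in the log above pauses the training loop, so its wall-clock time lands inside a single tqdm iteration. The displayed rate therefore collapses from about 10.4 it/s to about 13.4 s/it, and the ETA jumps to 29:03:38 (the 7,812 remaining iterations × 13.39 s/it ≈ 29 h). Once training resumes, the rate returns to about 10.4 it/s. The following is a minimal sketch for estimating how long each pause took from two consecutive progress-bar lines. It assumes the default tqdm `done/total [elapsed<remaining, rate]` format seen in this log. `parse_tqdm` and `eval_overhead` are hypothetical helpers and are not part of the TRM codebase. Results are only accurate to the one-second resolution of the displayed elapsed time.

```python
import re

# Matches "<done>/<total> [<elapsed><" inside a tqdm status line, e.g.
# "11718/19531 [23:21<12:25, 10.48it/s]"; elapsed may be MM:SS or H:MM:SS.
_TQDM_RE = re.compile(r"(\d+)/(\d+) \[([\d:]+)<")


def parse_tqdm(line: str) -> tuple[int, float]:
    """Return (iterations_done, elapsed_seconds) from one tqdm status line."""
    match = _TQDM_RE.search(line)
    if match is None:
        raise ValueError(f"not a tqdm status line: {line!r}")
    elapsed = 0
    for field in match.group(3).split(":"):
        elapsed = elapsed * 60 + int(field)  # fold H:MM:SS / MM:SS into seconds
    return int(match.group(1)), float(elapsed)


def eval_overhead(before: str, after: str, train_rate: float) -> float:
    """Approximate seconds spent outside training steps between two tqdm lines.

    Subtracts the time the completed iterations should have taken at the
    pre-pause `train_rate` (it/s) from the elapsed wall-clock difference.
    """
    n0, t0 = parse_tqdm(before)
    n1, t1 = parse_tqdm(after)
    return (t1 - t0) - (n1 - n0) / train_rate


# Status lines copied from the log around step 11,718.
before = "11718/19531 [23:21<12:25, 10.48it/s]"
after = "11719/19531 [24:39<29:03:38, 13.39s/it]"
print(f"pause took about {eval_overhead(before, after, 10.48):.1f} s")
```

For this pair, the snippet reports roughly 78 s. Applying it to the 15623 → 15625 lines at 10.41 it/s gives roughly 89 s. This suggests that each evaluation adds about a minute and a half to the roughly 40-minute run.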