# 🧠 TRM Model Adapted for Text Generation and Reasoning (GSM8K)

**What is TRM?** The *Tiny Recursive Model (TRM)* is an experimental architecture originally designed to solve closed-ended problems (such as Sudoku) by recursively reusing its own weights to "think" more deeply. This *notebook* turns that architecture into an **autoregressive language model** (like GPT) that can read and solve math problems written in natural language.

### 🛠️ What the TRM_Math version for Google Colab includes:

* **Architecture adaptation:** Causal attention and dynamic memory are integrated so text can be generated token by token (like standard GPT models).
* **Tokenizer and dynamic RoPE:** Uses the GPT-2 tokenizer and automatically patches the rotary positional embeddings (RoPE) to support variable-length *prompts*.
* **Training on GSM8K:** Training logic optimized with *Smart Masking* to teach the model when to end its math answers.
* **Optimized for Google Colab:** Designed to get the most out of the free tier (**T4 GPU** 💪).
* **Smart checkpointing:** Saves and resumes complex training states (model, optimizer and configuration) directly to **Google Drive**. Disconnect-proof 😎
* Leave your comments to help improve this version 😉👍
* **Next step:** Model generalization (capacity tuning, recursion and *Early Stopping*) 🔍
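The training cell is not part of this section, so the following is only a sketch of what *Smart Masking* typically means for causal-LM fine-tuning, not the notebook's actual code. Prompt and padding positions get the label `-100` (the same value as `IGNORE_LABEL_ID` in the TRM code, and PyTorch's default `ignore_index` for `F.cross_entropy`). An end-of-sequence token is appended to the answer so the model learns where to stop. The helper `build_masked_example` and the toy token ids are hypothetical. GPT-2 has no pad token, so its EOS id (50256) is assumed to double as padding.

```python
import torch
import torch.nn.functional as F

IGNORE_LABEL_ID = -100  # same value as PyTorch's default ignore_index

def build_masked_example(prompt_ids, answer_ids, eos_id, seq_len, pad_id):
    """Hypothetical helper: only answer tokens and the final EOS are scored."""
    ids = prompt_ids + answer_ids + [eos_id]
    labels = [IGNORE_LABEL_ID] * len(prompt_ids) + answer_ids + [eos_id]
    # Truncate, then right-pad to a fixed length; padding is never scored
    ids, labels = ids[:seq_len], labels[:seq_len]
    pad = seq_len - len(ids)
    ids = ids + [pad_id] * pad
    labels = labels + [IGNORE_LABEL_ID] * pad
    return torch.tensor(ids), torch.tensor(labels)

input_ids, labels = build_masked_example(
    prompt_ids=[11, 12, 13], answer_ids=[21, 22], eos_id=50256, seq_len=8, pad_id=50256
)
print(labels.tolist())  # [-100, -100, -100, 21, 22, 50256, -100, -100]

# Causal-LM shift: the logits at position t are scored against the label at t + 1
vocab_size = 50257
logits = torch.randn(1, 8, vocab_size)  # stand-in for model output
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),
    labels[1:].reshape(-1),
    ignore_index=IGNORE_LABEL_ID,
)
```

Because the EOS token carries a real label, the model is explicitly trained to emit it right after the final answer. Prompt and padding positions contribute nothing to the loss.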

### 🚀 How to use this project
* Open the notebook `TRM_Math_Reasoning.ipynb` in Google Colab.
* Run the cell that mounts Google Drive and configures the project folder.
* Run the dependency cells, then the TRM and tokenizer setup cells.
* To train: adjust the parameters in the "Ejecutar entrenamiento con Seguridad" (Run Training Safely) form and press Play.
* For inference: go to the "Consola de Pruebas" (Test Console) section, type your problem in English, and run it.
---
###### *Just in case... the emojis are mine*
### 😑

## **Load Google Drive** 📂
* Google Drive

In [None]:
# @title 📂 Mount Google Drive and Configure the Project Folder
from google.colab import drive
import os

# @markdown Enter the name of the folder where checkpoints and the model will be saved.
# @markdown (If the folder does not exist, it will be created automatically.)
Project_Folder_Name = "TRM_Math_Project" # @param {type:"string"}

# 1. Mount Drive (skips the permission prompt if it is already mounted)
mount_path = '/content/drive'
if not os.path.exists(mount_path):
    print("🔌 Connecting to Google Drive...")
    drive.mount(mount_path)
else:
    print("✅ Google Drive was already connected.")

# 2. Configure and verify the path
# Drive contents appear under /content/drive/MyDrive (older runtimes used "My Drive")
drive_root = os.path.join(mount_path, "MyDrive")
if not os.path.isdir(drive_root):
    drive_root = os.path.join(mount_path, "My Drive")
CHECKPOINT_DIR = os.path.join(drive_root, Project_Folder_Name)

if os.path.exists(CHECKPOINT_DIR):
    print(f"📂 Folder found: {CHECKPOINT_DIR}")
    # Optional: count the entries so you can see what is already there
    num_files = len(os.listdir(CHECKPOINT_DIR))
    print(f"   Contains {num_files} items.")
else:
    try:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        print(f"🆕 Folder created successfully: {CHECKPOINT_DIR}")
    except OSError as e:
        print(f"❌ Error creating the folder: {e}")

print("🚀 Storage system ready.")
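The checkpointing feature listed at the top stores the model, optimizer and configuration in `CHECKPOINT_DIR`. The real saving code lives in the training cell, outside this section. What follows is only a minimal sketch of the common PyTorch pattern, assuming one file that bundles all three states. The helpers `save_checkpoint`/`load_checkpoint` and the file name `trm_math_checkpoint.pt` are hypothetical, and a temporary directory stands in for the Drive folder so the snippet runs anywhere.

```python
import os
import tempfile
import torch
from torch import nn

def save_checkpoint(path, model, optimizer, config, step):
    # Write to a temporary file first, then rename it into place, which lowers
    # the risk of leaving a truncated checkpoint if the runtime disconnects.
    tmp_path = path + ".tmp"
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "step": step,
    }, tmp_path)
    os.replace(tmp_path, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["config"], ckpt["step"]

ckpt_dir = tempfile.mkdtemp()  # stand-in for CHECKPOINT_DIR
path = os.path.join(ckpt_dir, "trm_math_checkpoint.pt")

# Train one step on a toy model so the optimizer has state worth saving
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()
save_checkpoint(path, model, optimizer, {"hidden_size": 4}, step=1)

# Resume into fresh objects, as after a Colab disconnect
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)
config, step = load_checkpoint(path, model2, optimizer2)
print(step, torch.equal(model.weight, model2.weight))  # 1 True
```

Saving the optimizer state alongside the weights matters for AdamW: without its moment estimates, a resumed run effectively restarts the optimizer.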

## **Dependencies 📚**
_Run the dependency cells in order so the model runs correctly:_
* Transformers library
* common
* layers
* sparse_embedding


In [None]:
# @title transformers
!pip install -q transformers einops pydantic

In [None]:
# @title common
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
    # This function is a PyTorch version of jax truncated normal init (default init method in flax)
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
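            pdf_lower = c * math.exp(-0.5 * lower ** 2)
            pdf_upper = c * math.exp(-0.5 * upper ** 2)
            # Std of a standard normal truncated to [lower, upper]; dividing by it compensates for truncation
            comp_std = std / math.sqrt(1 - (upper * pdf_upper - lower * pdf_lower) / z - ((pdf_upper - pdf_lower) / z) ** 2)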

            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
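The NOTE at the top of `trunc_normal_init_` can be checked numerically. A standard normal truncated to [-2, 2] has a standard deviation of about 0.88, so `nn.init.trunc_normal_(std=1.0)` produces samples whose empirical std is noticeably below 1. That is why `trunc_normal_init_` divides the requested `std` by this factor (`comp_std`). The standalone check below is not part of the model. It recomputes the factor with the same closed-form expression and compares it with PyTorch's initializer.

```python
import math
import torch

def truncated_normal_std(lower: float = -2.0, upper: float = 2.0) -> float:
    """Std of a standard normal truncated to [lower, upper] (closed form)."""
    sqrt2 = math.sqrt(2)
    z = (math.erf(upper / sqrt2) - math.erf(lower / sqrt2)) / 2  # probability mass kept
    c = (2 * math.pi) ** -0.5
    pdf_lower = c * math.exp(-0.5 * lower ** 2)
    pdf_upper = c * math.exp(-0.5 * upper ** 2)
    var = 1 - (upper * pdf_upper - lower * pdf_lower) / z - ((pdf_upper - pdf_lower) / z) ** 2
    return math.sqrt(var)

factor = truncated_normal_std()

torch.manual_seed(0)
samples = torch.empty(1_000_000)
torch.nn.init.trunc_normal_(samples, mean=0.0, std=1.0, a=-2.0, b=2.0)

print(f"closed form: {factor:.4f}, empirical: {samples.std().item():.4f}")
# A compensated init divides the target std by `factor` before truncating,
# so the resulting std matches the requested one.
```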

In [None]:
# @title layers_for_prompt
from typing import Tuple
import einops
import torch
from torch import nn
import torch.nn.functional as F

#try:
#    from flash_attn_interface import flash_attn_func  # type: ignore[import]
#except ImportError:
#    # Fallback to FlashAttention 2
#    from flash_attn import flash_attn_func  # type: ignore[import]
from torch.nn.functional import scaled_dot_product_attention

#from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Original TRM code: puzzles and Sudokus
#def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
#    # q, k: [bs, seq_len, num_heads, head_dim]
#    # cos, sin: [seq_len, head_dim]
#    orig_dtype = q.dtype
#    q = q.to(cos.dtype)
#    k = k.to(cos.dtype)

#    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
#    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

#    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)

# Function for TRM_Math: text generation and basic math reasoning
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [MAX_seq_len, head_dim] (this is where the 512 vs 7 mismatch came from)

    # 1. FIX: detect the actual length of the input sequence
    # q.shape[1] is the current length (e.g. 7 tokens or 512 tokens)
    seq_len = q.shape[1]

    # 2. FIX: trim the precomputed embeddings if they exceed the current length
    # This lets the model accept short ("2+2") or long prompts alike
    if cos.shape[0] > seq_len:
        cos = cos[:seq_len]
        sin = sin[:seq_len]

    # 3. Original math (unchanged)
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, hidden_size]
        qkv = self.qkv_proj(hidden_states)

        # Split head
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Attention via PyTorch SDPA (stands in for flash_attn_func)
        query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
        attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
        attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
        attn_output = attn_output.reshape(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)

class LinearSwish(nn.Module):
    def __init__(self, hidden_size: int, reverse=False):
        super().__init__()

        self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
        self.reverse = reverse

    def forward(self, x):
        if self.reverse:
            return F.silu(self.linear(x))
        else:
            return self.linear(F.silu(x))


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj    = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
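The `apply_rotary_pos_emb` patch above trims the precomputed `cos`/`sin` tables, which are built for `max_position_embeddings` positions, down to the current sequence length. A 7-token prompt therefore no longer clashes with a 512-row table. The self-contained sketch below reproduces that idea with its own `rope_tables` helper, a hypothetical stand-in for `RotaryEmbedding`. It also checks two properties: the rotation preserves vector norms, and position 0 is left unchanged. The tensor layout follows the cell's comments, `[batch, seq_len, num_heads, head_dim]`.

```python
import torch

def rope_tables(head_dim: int, max_positions: int, base: float = 10000.0):
    # Same construction as RotaryEmbedding: duplicated frequencies, cos/sin per position
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(torch.arange(max_positions, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()  # each [max_positions, head_dim]

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    seq_len = x.shape[1]
    cos, sin = cos[:seq_len], sin[:seq_len]  # trim the 512-row table to the prompt length
    return x * cos.unsqueeze(-2) + rotate_half(x) * sin.unsqueeze(-2)

cos, sin = rope_tables(head_dim=64, max_positions=512)
q = torch.randn(1, 7, 16, 64)  # a 7-token prompt, 16 heads of size 64
q_rot = apply_rope(q, cos, sin)
print(q_rot.shape)  # torch.Size([1, 7, 16, 64])
```

Slicing works because positions are always counted from 0. This matches the notebook's approach of re-encoding the full sequence on each forward pass.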

In [None]:
# @title sparse_embedding
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

#from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,

                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N,               dtype=local_ids.dtype,          device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids,          local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
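`_sparse_emb_signsgd_dist` first merges the gradients of repeated IDs (`unique` + `scatter_add_`). It then applies a sign-based step with decoupled weight decay to the touched rows only: `p ← p·(1 - lr·wd) - lr·sign(g)`. In `MathTRM` this path is inactive because `puzzle_emb_ndim = 0`. Even so, the single-process arithmetic (`world_size = 1`) is easy to check by hand. The snippet below reimplements just those steps on a toy table, with made-up IDs and gradients.

```python
import torch

weights = torch.ones(4, 2)               # 4 embeddings of dim 2
local_ids = torch.tensor([1, 3, 1])      # ID 1 appears twice in the batch
local_grad = torch.tensor([[0.5, -0.2],
                           [-1.0, 0.0],
                           [0.1, 0.4]])
lr, weight_decay = 0.1, 0.5

# Merge the gradients of duplicate IDs
grad_ids, inv = local_ids.unique(return_inverse=True)               # grad_ids = [1, 3]
grad = torch.zeros(grad_ids.shape[0], 2)
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, 2), local_grad)   # [[0.6, 0.2], [-1.0, 0.0]]

# SignSGD with decoupled weight decay, applied only to the rows seen in the batch
p = weights[grad_ids]                     # advanced indexing returns a copy
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
weights[grad_ids] = p                     # write the updated rows back
print(weights)
```

Rows 0 and 2 are untouched. Row 1 moves to 0.95 - 0.1 = 0.85 in both dimensions. Row 3 becomes [1.05, 0.95] because `sign(0) = 0`, so only the weight decay applies to its second entry.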

## **TRM and Tokenizer Setup 🧠**
* TRM base model
* GPT-2 Tokenizer
* Math_TRM
* Inference loop
* Load the lightweight tokenizer
* Verification test

In [None]:
# @title TRM
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
#from models.common import trunc_normal_init_
#from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
#from models.sparse_embedding import CastedSparseEmbedding

IGNORE_LABEL_ID = -100

@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False # use mlp on L instead of transformer
    puzzle_emb_len: int = 16 # if non-zero, its specified to this value
    no_ACT_continue: bool =  True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense

class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
        # Fully Connected
        out = self.mlp(hidden_states)
        hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states

class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden_states = hidden_states + input_injection
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O

        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head      = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head       = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        it = 0
        z_H, z_L = carry.z_H, carry.z_L
        # H_cycles-1 without grad
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles-1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; reads position 0 (the first puzzle_emb slot, or the first token when puzzle_emb_len == 0)
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset in the first pass because all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:

        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):

                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes

                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer and target networks for computing target Q-value.
                    # As batch_size is large, there are many parallel envs.
                    # Similar concept as PQN https://arxiv.org/abs/2407.04811
                    # self.inner returns (carry, logits, (q_halt_logits, q_continue_logits))
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
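Within one ACT step, `TinyRecursiveReasoningModel_ACTV1_Inner.forward` runs `H_cycles - 1` passes without gradient and one final pass with gradient. Each pass updates `z_L` `L_cycles` times and then `z_H` once, always reusing the same `L_level` weights. The toy sketch below mirrors only that loop structure, using a stand-in `CountingLevel` module in place of `L_level` (an assumption for illustration). With the notebook's `H_cycles = 2` and `L_cycles = 6`, that gives 2 × (6 + 1) = 14 calls to `L_level`, or 28 block applications with `L_layers = 2`. Only the last 7 calls are backpropagated through.

```python
import torch
from torch import nn

class CountingLevel(nn.Module):
    """Stand-in for L_level: one shared module that counts total and grad-tracked calls."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.calls = 0
        self.grad_calls = 0

    def forward(self, hidden, injection):
        self.calls += 1
        self.grad_calls += int(torch.is_grad_enabled())
        return torch.tanh(self.proj(hidden + injection))

def recursive_step(level, z_H, z_L, x, H_cycles, L_cycles):
    with torch.no_grad():                   # H_cycles - 1 passes without grad
        for _ in range(H_cycles - 1):
            for _ in range(L_cycles):
                z_L = level(z_L, z_H + x)   # latent reasoning updates
            z_H = level(z_H, z_L)           # answer refinement
    for _ in range(L_cycles):               # final pass with grad
        z_L = level(z_L, z_H + x)
    z_H = level(z_H, z_L)
    return z_H.detach(), z_L.detach(), z_H  # detached carry + output to train on

dim = 8
level = CountingLevel(dim)
x = torch.randn(1, 5, dim)
z_H, z_L, out = recursive_step(level, torch.zeros(1, 5, dim), torch.zeros(1, 5, dim), x,
                               H_cycles=2, L_cycles=6)
out.sum().backward()
print(level.calls, level.grad_calls)  # 14 7
```

Running most cycles under `torch.no_grad()` keeps activation memory proportional to one pass rather than all `H_cycles` passes, which is what makes deep recursion affordable on a T4.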

In [None]:
# @title GPT-2 Tokenizer
import torch
import torch.nn as nn
from transformers import AutoTokenizer

# @markdown ##### Modify these parameters with care.
# @markdown Larger values risk exhausting VRAM, in which case the model or training will not run.

#m_batch_size = 1 # @param {type:"number"}
#m_seq_len = 512 # @param {type:"number"}
#m_num_puzzle_identifiers = 0 # @param {type:"number"}
#m_vocab_size = "tokenizer.vocab_size" # @param {type:"string"}
#m_H_cycles = 2 # @param {type:"number"}
#m_L_cycles = 6 # @param {type:"number"}
#m_H_layers = 0 # @param {type:"number"}
#m_L_layers = 2 # @param {type:"number"}
#m_hidden_size = 1024 # @param {type:"number"}
#m_expansion = 4.0 # @param
#m_num_heads = 16 # @param {type:"number"}
#m_pos_encodings = "rope" # @param {type:"string"}
#m_halt_max_steps = 16 # @param {type:"number"}
#m_halt_exploration_prob = 0.0 # @param {type:"number"}
#m_forward_dtype = "\"float32\"" # @param {type:"string"}

class MathTRMConfig(TinyRecursiveReasoningModel_ACTV1Config):
    def __init__(self, vocab_size, **kwargs):
        # Settings forced for GPT-style operation
        kwargs["vocab_size"] = vocab_size
        kwargs["puzzle_emb_ndim"] = 0   # Disable puzzle embeddings
        kwargs["puzzle_emb_len"] = 0
        kwargs["mlp_t"] = False         # IMPORTANT: False so the Attention class is used
        kwargs["pos_encodings"] = "rope"

        super().__init__(**kwargs)

def get_math_config(tokenizer):
    return {
        "batch_size": 1,      # Do not change
        "seq_len": 512,       # Enough context for math questions
        "num_puzzle_identifiers": 0,
        "vocab_size": tokenizer.vocab_size,
        "H_cycles": 2,        # Answer-refinement cycles (T)
        "L_cycles": 6,        # Latent-reasoning cycles (n)
        "H_layers": 0,        # Unused
        "L_layers": 2,        # Keep the network "tiny"
        "hidden_size": 1024,  # Same width as GPT-2 Medium (GPT-2 Small uses 768)
        "expansion": 4.0,
        "num_heads": 16,
        "pos_encodings": "rope",
        "halt_max_steps": 16, # Maximum ACT "thinking" steps
        "halt_exploration_prob": 0.0,
        "forward_dtype": "float32"
    }
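The two cycle counts multiply. As a rough sketch, assuming the loop structure of the reference TRM implementation (each of the `H_cycles` runs `L_cycles` updates of `z_L` followed by one update of `z_H`, all through the same `L_level` stack, with gradients only in the last H cycle), the recursion budget and the memory of the main tensors for this configuration can be estimated as follows. `recursion_budget`, `carry_bytes` and `logits_bytes` are hypothetical helpers, not part of the TRM code.

```python
def recursion_budget(h_cycles: int, l_cycles: int, l_layers: int) -> dict:
    """Count passes through the shared L_level stack in one forward call.

    Assumes the reference TRM loop: every H cycle runs `l_cycles` updates of z_L
    plus one update of z_H; only the final H cycle is tracked by autograd.
    """
    passes = h_cycles * (l_cycles + 1)
    return {
        "L_level_passes": passes,
        "layer_applications": passes * l_layers,  # effective depth
        "passes_with_grad": l_cycles + 1,          # last H cycle only
    }


def carry_bytes(batch_size: int, seq_len: int, hidden_size: int, bytes_per_value: int = 4) -> int:
    """Memory held by the two latent states z_H and z_L (float32 = 4 bytes per value)."""
    return 2 * batch_size * seq_len * hidden_size * bytes_per_value


def logits_bytes(batch_size: int, seq_len: int, vocab_size: int, bytes_per_value: int = 4) -> int:
    """Memory of one float32 logits tensor of shape [batch, seq_len, vocab]."""
    return batch_size * seq_len * vocab_size * bytes_per_value


print(recursion_budget(h_cycles=2, l_cycles=6, l_layers=2))
# {'L_level_passes': 14, 'layer_applications': 28, 'passes_with_grad': 7}
print(carry_bytes(8, 512, 1024) / 2**20, "MiB")     # 32.0 MiB
print(logits_bytes(8, 512, 50257) / 2**20, "MiB")   # 785.265625 MiB (GPT-2 vocab)
```

Under these assumptions the carry itself is small; with `Batch_Size=8` the logits tensor alone approaches 0.8 GB, before counting the activations kept for the passes with gradients, which is likely where the VRAM warning above matters most.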

In [None]:
# @title Math_TRM
class MathTRM(nn.Module):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        config_dict = get_math_config(tokenizer)

        # Make sure puzzle_emb_len is 0 in the internal configuration
        config_dict['puzzle_emb_len'] = 0

        self.model = TinyRecursiveReasoningModel_ACTV1(config_dict)

        # --- CAUSALITY PATCH (required for text generation) ---
        for layer in self.model.inner.L_level.layers:
            if hasattr(layer, 'self_attn'):
                layer.self_attn.causal = True

    def forward(self, input_ids, carry=None):
        device = input_ids.device
        # Actual size of the current input (e.g., 7 tokens)
        batch_size, current_seq_len = input_ids.shape

        batch = {
            "inputs": input_ids,
            "puzzle_identifiers": torch.zeros((batch_size, 1), dtype=torch.long).to(device)
        }

        if carry is None:
            # --- CRITICAL FIX ---
            # Instead of initial_carry(), which allocates memory for the full configured
            # length (seq_len + puzzle_emb_len), build memory of EXACTLY the input length (e.g., 7).

            dtype = self.model.inner.forward_dtype # Usually float32 or bfloat16
            hidden_size = self.model.config.hidden_size

            # 1. Create zeroed memory tensors of shape [batch, current_len, hidden]
            inner_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(
                z_H=torch.zeros(batch_size, current_seq_len, hidden_size, device=device, dtype=dtype),
                z_L=torch.zeros(batch_size, current_seq_len, hidden_size, device=device, dtype=dtype),
            )

            # 2. Wrap them in the carry object
            carry = TinyRecursiveReasoningModel_ACTV1Carry(
                inner_carry=inner_carry,
                steps=torch.zeros((batch_size,), dtype=torch.int32, device=device),
                halted=torch.ones((batch_size,), dtype=torch.bool, device=device), # True so reset_carry replaces the zeros with the model's initial states (H_init / L_init)
                current_data={k: v.to(device) for k, v in batch.items()}
            )

        # Recursive pass
        new_carry, outputs = self.model(carry, batch)

        return outputs["logits"], new_carry
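The causality patch matters because, without it, every position can attend to the tokens after it, so during training the model could read the very token it is asked to predict. The pure-Python sketch below uses a single-head, unprojected dot-product attention (a simplification, not the TRM `Attention` class) to check the defining property of a causal mask: the output at a position does not change when later tokens change.

```python
import math
from typing import List

Vector = List[float]


def attention(x: List[Vector], causal: bool) -> List[Vector]:
    """Single-head dot-product attention where queries, keys and values are all x."""
    scale = 1.0 / math.sqrt(len(x[0]))
    out = []
    for i, q in enumerate(x):
        # A causal mask lets position i attend only to positions 0..i
        visible = list(range(i + 1)) if causal else list(range(len(x)))
        scores = [sum(a * b for a, b in zip(q, x[j])) * scale for j in visible]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * x[j][d] for w, j in zip(weights, visible)) / total
            for d in range(len(q))
        ])
    return out


seq_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seq_b = [[1.0, 0.0], [0.0, 1.0], [-3.0, 2.0]]  # only the last token differs

# Causal: the first two positions ignore the changed future token
print(attention(seq_a, causal=True)[:2] == attention(seq_b, causal=True)[:2])  # True
# Bidirectional: position 0 already "sees" the future
print(attention(seq_a, causal=False)[0] == attention(seq_b, causal=False)[0])  # False
```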

In [None]:
# @title Inference loop
def solve_math_problem(prompt, model, max_new_tokens=100):
    model.eval()
    device = next(model.parameters()).device
    tokenizer = model.tokenizer

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]

    print(f"Question: {prompt}\nTRM answer: ", end="", flush=True)

    generated_ids = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # 1. Forward pass
            # The whole current sequence is sent; the model "thinks" recursively over it
            # (H_cycles outer cycles of L_cycles latent updates each) before returning logits.
            logits, _ = model(generated_ids, carry=None)

            # 2. Greedy prediction of the next token from the last position
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # [batch, 1]

            # 3. Decode and print (skip_special_tokens hides <|endoftext|>)
            word = tokenizer.decode(next_token[0], skip_special_tokens=True)
            print(word, end="", flush=True)

            # 4. Append the token to the sequence
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
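Because `carry=None` is passed at every step and there is no key/value cache, each new token re-runs the full recursion over the whole prefix. A back-of-the-envelope count of the positions processed, using a hypothetical helper and assuming one forward call per generated token as in the loop above:

```python
def positions_processed(prompt_len: int, new_tokens: int) -> int:
    """Total token positions fed to the model when the whole prefix is re-sent each step.

    Step t (0-based) processes prompt_len + t positions, so the total is
    new_tokens * prompt_len + new_tokens * (new_tokens - 1) / 2.
    """
    return sum(prompt_len + t for t in range(new_tokens))


total = positions_processed(prompt_len=20, new_tokens=100)
print(total)  # 6950
# With H_cycles=2 and L_cycles=6, each position goes through 2 * (6 + 1) = 14 L_level passes
print(total * 14)  # 97300
```

The quadratic term dominates for long answers, so raising `max_new_tokens` increases generation time faster than linearly.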

In [None]:
# @title Load the lightweight tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Initialize the model (on the GPU if one is available)
model = MathTRM(tokenizer).to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# @title Sanity check
# The model is untrained at this point, so expect random tokens: this only checks that generation runs.
pregunta = "What is 2 + 2?"
respuesta = solve_math_problem(pregunta, model)

## **Training stage ☝🤓**
* 📚 Load and process the dataset (GSM8K)
* Training function and safety check
* Run training safely 🛡️


In [None]:
# @title 📚 Load and process the dataset (GSM8K)
from datasets import load_dataset
from torch.utils.data import Dataset
import torch

# 1. Class definition
class GSM8KDataset(Dataset):
    def __init__(self, tokenizer, split="train", max_length=512):
        print(f"📥 Downloading/loading the GSM8K dataset ({split})...")

        # Use the official Hub ID 'openai/gsm8k'
        try:
            self.dataset = load_dataset("openai/gsm8k", "main", split=split)
        except Exception as e:
            # Fallback to the legacy ID, just in case
            print(f"⚠️ Error loading openai/gsm8k: {e}. Trying the alternative ID...")
            self.dataset = load_dataset("gsm8k", "main", split=split)

        print(f"✅ Dataset ready! {len(self.dataset)} examples loaded.")

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Prompt format: question -> answer (generate_answer relies on the literal "Respuesta:" marker)
        text = f"Pregunta: {item['question']}\nRespuesta: {item['answer']}<|endoftext|>"

        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encodings["input_ids"].squeeze(0)
        labels = input_ids.clone()

        # SMART MASKING
        # The pad token is the EOS token, so everything after the first EOS is ignored:
        # the model learns to emit EOS once and does not learn to generate padding.
        eos_mask = (input_ids == self.tokenizer.eos_token_id)
        if eos_mask.any():
            # Index of the first EOS (end of the real text)
            first_eos_idx = torch.where(eos_mask)[0][0]

            # Mark every position after it with -100 (ignored by CrossEntropyLoss)
            if first_eos_idx + 1 < len(labels):
                labels[first_eos_idx + 1:] = -100

        return input_ids, labels

# 2. Instantiation (data preparation only)
if 'tokenizer' not in globals():
    from transformers import AutoTokenizer
    # GPT-2 tokenizer, chosen because it is lightweight
    print("⚙️ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

# Create the Dataset object
train_dataset = GSM8KDataset(tokenizer, split="train")

# Integrity check
if len(train_dataset) == 0:
    raise ValueError("🔴 CRITICAL ERROR: the dataset appears to be empty.")
else:
    print("✨ All set. 'train_dataset' is ready for training.")
    # NOTE: the DataLoader is not created here;
    # 'batch_size' is set dynamically in the training form.
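Since `pad_token` is set to `eos_token`, padding and the end-of-answer marker share the same id (50256 in GPT-2). The masking step keeps the first EOS as a target, so the model learns when to stop, and hides everything after it from the loss. A list-based sketch of the same rule (hypothetical helper with arbitrary example ids; `-100` is the default `ignore_index` of `torch.nn.CrossEntropyLoss`):

```python
from typing import List

EOS = 50256      # GPT-2 <|endoftext|>, also used as the padding token
IGNORE = -100    # default ignore_index of torch.nn.CrossEntropyLoss


def smart_mask(input_ids: List[int], eos_id: int = EOS) -> List[int]:
    """Copy input_ids as labels, ignoring every position after the first EOS."""
    labels = list(input_ids)
    if eos_id in input_ids:
        first_eos = input_ids.index(eos_id)
        labels[first_eos + 1:] = [IGNORE] * (len(labels) - first_eos - 1)
    return labels


ids = [15496, 11, 995, EOS, EOS, EOS]  # 3 arbitrary text ids, the real EOS, 2 padding tokens
print(smart_mask(ids))  # [15496, 11, 995, 50256, -100, -100]
```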

In [None]:
# @title Training function and safety check
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import tqdm

# --- CHANGE 1: the checkpoint also stores the configuration ---
def save_checkpoint(path, model, optimizer, scheduler, epoch, loss, config):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config # <--- training parameters used
    }, path)
    print("   💾 Checkpoint saved.")

# --- CHANGE 2: configuration safety check ---
def safety_check(saved_config, current_config):
    """Compare configurations and ask the user to confirm if anything changed."""
    if saved_config is None:
        print("⚠️ Old checkpoint without configuration data. Using the new parameters.")
        return current_config, True # Continue

    diffs = []
    # Compare the relevant keys
    keys_to_check = ['lr', 'batch_size', 'accumulation_steps', 'total_epochs']

    for key in keys_to_check:
        old_val = saved_config.get(key)
        new_val = current_config.get(key)
        # Compare as strings (e.g. 1e-4 and 0.0001 both print as '0.0001')
        if str(old_val) != str(new_val):
            diffs.append((key, old_val, new_val))

    if not diffs:
        print("✅ Safety check: parameters are identical. Continuing...")
        return current_config, True

    # If anything differs, raise the ALERT
    print("\n" + "!"*60)
    print("🛑 SAFETY ALERT: the parameters have changed")
    print("!"*60)
    print(f"{'PARAMETER':<20} | {'SAVED':<15} | {'NEW (@param)':<15}")
    print("-" * 56)
    for key, old, new in diffs:
        print(f"{key:<20} | {str(old):<15} | {str(new):<15}")
    print("-" * 56)

    print("\nOptions:")
    print("  [Y]  Use the NEW parameters (overwrites the previous configuration)")
    print("  [N]  Use the SAVED parameters (ignores the @param values)")
    print("  [S]  STOP / abort (to fix things manually)")

    while True:
        choice = input("\nContinue with the NEW parameters? (Y/N/S): ").strip().upper()
        if choice == 'Y':
            print("👉 Selected: NEW parameters.")
            return current_config, True
        elif choice == 'N':
            print("👉 Selected: RESTORE the saved parameters.")
            return saved_config, True
        elif choice in ('S', 'STOP'):
            print("🛑 Run aborted by the user.")
            return None, False
        else:
            print("Invalid option. Type Y, N or S.")

def train_resumable_safe(model, train_dataset, current_config):
    # Unpack the current configuration
    total_epochs = current_config['total_epochs']
    batch_size = current_config['batch_size']
    accumulation_steps = current_config['accumulation_steps']
    lr = current_config['lr']

    # Paths
    checkpoint_path = os.path.join(CHECKPOINT_DIR, "trm_gsm8k_latest.pt")
    best_model_path = os.path.join(CHECKPOINT_DIR, "trm_gsm8k_best.pt")

    # Initial loader (rebuilt below if the safety check changes batch_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    device = next(model.parameters()).device
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.1)

    # The scheduler steps once per optimizer step (every accumulation_steps batches),
    # so T_max is counted in optimizer steps, not in batches
    steps_per_epoch = (len(train_loader) + accumulation_steps - 1) // accumulation_steps
    total_steps = total_epochs * steps_per_epoch
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
    loss_fct = torch.nn.CrossEntropyLoss()

    start_epoch = 0
    best_loss = float('inf')

    # --- LOADING AND SAFETY LOGIC ---
    if os.path.exists(checkpoint_path):
        print("🔄 Checkpoint found. Analyzing...")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        saved_config = checkpoint.get('config', None)

        # === RUN THE CHECK ===
        final_config, should_continue = safety_check(saved_config, current_config)

        if not should_continue:
            return None # Abort

        # Update the local variables with the chosen configuration
        batch_size = final_config['batch_size']
        accumulation_steps = final_config['accumulation_steps']
        lr = final_config['lr']
        total_epochs = final_config['total_epochs']

        # Restore weights and optimizer state (the optimizer state carries the last, decayed LR)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('loss', float('inf'))

        # Rebuild the loader with the chosen batch size
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        steps_per_epoch = (len(train_loader) + accumulation_steps - 1) // accumulation_steps

        # Restore the scheduler only if nothing that shapes the LR schedule changed
        schedule_keys = ('total_epochs', 'batch_size', 'accumulation_steps', 'lr')
        same_schedule = saved_config is not None and all(
            str(saved_config.get(k)) == str(final_config[k]) for k in schedule_keys
        )
        if same_schedule:
            scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs * steps_per_epoch)
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # New schedule over the remaining epochs, restarting from the chosen base LR.
            # The LR is not forced when the scheduler is restored: CosineAnnealingLR updates
            # the LR recursively from its current value, so resetting it would skew the curve.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
                param_group['initial_lr'] = lr
            remaining_steps = max(1, (total_epochs - start_epoch) * steps_per_epoch)
            scheduler = CosineAnnealingLR(optimizer, T_max=remaining_steps)

        print(f"✅ Resuming at epoch {start_epoch + 1} | batch {batch_size} | LR {lr}")
    else:
        print("🆕 Starting a new training run (no previous checkpoint).")

    if start_epoch >= total_epochs:
        print("🎉 Training complete! Increase 'Epochs' if you want to continue.")
        return model

    # --- TRAINING LOOP ---
    model.train()
    print("🚀 Running... (a checkpoint with the configuration is saved at the end of every epoch)")

    # Configuration stored with every checkpoint
    active_config = {
        'total_epochs': total_epochs,
        'batch_size': batch_size,
        'accumulation_steps': accumulation_steps,
        'lr': lr
    }

    for epoch in range(start_epoch, total_epochs):
        loop = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}")
        epoch_loss = 0
        optimizer.zero_grad()  # start every epoch with clean gradients

        for batch_idx, (input_ids, labels) in enumerate(loop):
            input_ids, labels = input_ids.to(device), labels.to(device)

            logits, _ = model(input_ids, carry=None)

            # Next-token objective: position t predicts token t+1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss = loss / accumulation_steps
            loss.backward()

            # Step every accumulation_steps batches, and on the last (possibly partial) group
            is_last_batch = (batch_idx + 1) == len(train_loader)
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            current_loss = loss.item() * accumulation_steps
            epoch_loss += current_loss
            loop.set_postfix(loss=f"{current_loss:.4f}")

        avg_loss = epoch_loss / len(train_loader)
        print(f"   📉 End of epoch {epoch+1} - Loss: {avg_loss:.4f}")

        # Save the active configuration with the checkpoint
        save_checkpoint(checkpoint_path, model, optimizer, scheduler, epoch, avg_loss, active_config)

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), best_model_path)
            print("   🏆 New best model saved.")

    return model
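`scheduler.step()` only runs once every `accumulation_steps` batches, so the cosine horizon `T_max` is best counted in optimizer steps rather than in batches. A worked example with the default form values (`Batch_Size=8`, `Accumulation_Steps=4`, `Epochs=50`) and the 7,473 problems of the GSM8K train split; `cosine_lr` is the closed form of the cosine annealing schedule with `eta_min=0`, and both helpers are hypothetical:

```python
import math


def optimizer_steps_per_epoch(num_examples: int, batch_size: int, accumulation_steps: int) -> int:
    """Batches per epoch (DataLoader's default drop_last=False keeps the last partial batch),
    then one optimizer step per `accumulation_steps` batches, counting a final partial group."""
    batches = math.ceil(num_examples / batch_size)
    return math.ceil(batches / accumulation_steps)


def cosine_lr(base_lr: float, step: int, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


NUM_TRAIN = 7473  # size of the GSM8K "main" train split
steps = optimizer_steps_per_epoch(NUM_TRAIN, batch_size=8, accumulation_steps=4)
total_steps = 50 * steps
print(steps, total_steps)  # 234 11700
print("effective batch size:", 8 * 4)  # 32

# T_max in optimizer steps: the LR decays all the way to eta_min by the last step
print(cosine_lr(1e-4, total_steps, t_max=total_steps))  # 0.0
# T_max in batches (50 * 935): the run ends with the LR still at roughly 85% of its base value
print(cosine_lr(1e-4, total_steps, t_max=50 * math.ceil(NUM_TRAIN / 8)))
```

`CosineAnnealingLR` itself updates the LR recursively from the optimizer's current value rather than from this closed form, but without outside changes to the LR the two follow the same curve.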

In [None]:
# @title Run training safely 🛡️
# @markdown ## **Control parameters**
# @markdown * Enter the parameters to train in stages.
# @markdown * Because of the runtime's usage limits and GPU availability, training is saved after every stage (epoch); with the current model and training settings, each stage takes about 35-45 minutes on a T4 GPU.
# @markdown * Re-run training with the same parameters to continue until all the epochs are finished.
# @markdown * Try your own model and training parameters: time one epoch to estimate roughly how long the whole run will take. Raising the model or training parameters too much is AT YOUR OWN RISK.

# Variables
Epochs = 50 # @param {type:"slider", min:1, max:100, step:1}
Batch_Size = 8 # @param {type:"slider", min:1, max:32, step:1}
Accumulation_Steps = 4 # @param {type:"slider", min:1, max:16, step:1}
Learning_Rate = 1e-4 # @param {type:"number"}

# Pack the user's current configuration
current_user_config = {
    'total_epochs': Epochs,
    'batch_size': Batch_Size,
    'accumulation_steps': Accumulation_Steps,
    'lr': Learning_Rate
}

# Run the safe training function
# (Note: make sure you ran the previous cell so that train_resumable_safe is defined)
if 'model' in globals() and 'train_dataset' in globals():
    trained_model = train_resumable_safe(model, train_dataset, current_user_config)
    if trained_model is not None:  # None means the user aborted; keep the current model
        model = trained_model
else:
    print("⚠️ Instantiate 'model' and load 'train_dataset' first.")
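The form's estimate of about 35-45 minutes per epoch on a T4 makes the total budget easy to work out before committing to a run. A small sketch with hypothetical helpers; the 240-minute session length is a placeholder assumption, since Colab session limits vary:

```python
import math


def total_training_hours(epochs: int, minutes_per_epoch: float) -> float:
    """Total GPU time for a full run, in hours."""
    return epochs * minutes_per_epoch / 60


def epochs_per_session(session_minutes: float, minutes_per_epoch: float) -> int:
    """Complete epochs (and therefore saved checkpoints) that fit in one runtime session."""
    return math.floor(session_minutes / minutes_per_epoch)


# 35-45 min/epoch comes from the form's note; 240 min is a hypothetical session length
for minutes in (35, 45):
    print(f"{minutes} min/epoch -> {total_training_hours(50, minutes):.1f} h total, "
          f"{epochs_per_session(240, minutes)} epochs per 240-min session")
```

Because a checkpoint is only written at the end of an epoch, any partial epoch still running when a session ends is lost.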

## **Trial by fire 🤔**
Run the model on a natural-language math problem written in English.

In [None]:
# @title Loading, inference and safety functions 🛡️
import torch
import os
from transformers import AutoTokenizer

# --- HELPER FUNCTIONS ---

def load_trained_model(device, use_latest=False):
    """
    Load the trained model.
    use_latest=True  -> latest checkpoint (useful to track progress).
    use_latest=False -> best model so far, by training loss (useful for demos).
    """
    best_path = os.path.join(CHECKPOINT_DIR, "trm_gsm8k_best.pt")
    latest_path = os.path.join(CHECKPOINT_DIR, "trm_gsm8k_latest.pt")

    # Pick the requested file, falling back to the other one if it does not exist
    if use_latest:
        path_to_load = latest_path if os.path.exists(latest_path) else best_path
        print("Loading the MOST RECENT version (latest)...")
    else:
        path_to_load = best_path if os.path.exists(best_path) else latest_path
        print("Loading the BEST version so far (best)...")

    if not os.path.exists(path_to_load):
        raise FileNotFoundError(f"❌ No model found at: {path_to_load}")

    checkpoint = torch.load(path_to_load, map_location=device)

    # "latest" is a full checkpoint dict; "best" is a raw state_dict (see train_resumable_safe)
    is_full_checkpoint = 'model_state_dict' in checkpoint
    state_dict = checkpoint['model_state_dict'] if is_full_checkpoint else checkpoint
    epoch = checkpoint.get('epoch') if is_full_checkpoint else None

    # 1. Recover the saved configuration
    saved_config = checkpoint.get('config') if is_full_checkpoint else None

    # --- Config injection ---
    # Make MathTRM use any saved *model* settings (e.g. hidden_size) instead of the defaults.
    # The checkpoint config also stores training keys ('batch_size', 'lr', ...), which must not
    # overwrite the model config, so they are filtered out. The original function is restored afterwards.
    training_keys = {'total_epochs', 'batch_size', 'accumulation_steps', 'lr'}
    original_get_config = globals().get('get_math_config')
    model_overrides = {k: v for k, v in (saved_config or {}).items() if k not in training_keys}

    if saved_config:
        print(f"⚙️ Saved configuration: {saved_config}")

    if model_overrides:
        # Temporary function: default config updated with the saved model settings
        def patched_get_config(tokenizer):
            base_config = original_get_config(tokenizer)
            base_config.update(model_overrides)
            return base_config

        # Replace it globally (MathTRM looks up get_math_config at call time)
        globals()['get_math_config'] = patched_get_config

    try:
        # Build the model (uses the patched config, if any)
        model = MathTRM(tokenizer)
    finally:
        # Always restore the original function
        if model_overrides and original_get_config:
            globals()['get_math_config'] = original_get_config

    model.to(device)

    # 2. Load the weights
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print(f"⚠️ Error loading weights: {e}")
        print("Likely cause: the architecture (layers, sizes) changed and no longer matches the checkpoint.")
        raise

    model.eval()
    return model, epoch

def generate_answer(model, prompt, max_tokens=100, temperature=0.7):
    model.eval()
    device = next(model.parameters()).device
    # Must match the training format used in GSM8KDataset ("Pregunta: ...\nRespuesta: ...")
    formatted_prompt = f"Pregunta: {prompt}\nRespuesta:"

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    generated_ids = input_ids

    with torch.no_grad():
        for _ in range(max_tokens):
            logits, _ = model(generated_ids, carry=None)
            next_token_logits = logits[:, -1, :]

            if temperature > 0:
                # Sample from softmax(logits / temperature)
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # [batch, 1]
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # [batch, 1]

            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode only the newly generated tokens (everything after the "Respuesta:" prompt)
    new_tokens = generated_ids[0, input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
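`generate_answer` divides the logits by `temperature` before the softmax and falls back to greedy argmax when `temperature` is 0. A dependency-free sketch of that rule on made-up logits shows what the consoles' default of 0.1 does: the distribution collapses onto the top-scoring token, so sampling behaves almost like greedy decoding.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature (temperature must be > 0)."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
for t in (1.0, 0.7, 0.1):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```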

In [None]:
# @title 🧪 Test console (improved inference): control prompt
# --- INTERFACE ---
# @markdown ### Test parameters
Prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?" # @param {type:"string"}
Max_Tokens = 100 # @param {type:"slider", min:10, max:512, step:10}
Temperature = 0.1 # @param {type:"slider", min:0.0, max:1.0, step:0.1}
Version_Modelo = "Latest (current progress)" # @param ["Best (best so far)", "Latest (current progress)"]

# Loading logic
use_latest = (Version_Modelo == "Latest (current progress)")
device = "cuda" if torch.cuda.is_available() else "cpu"

if 'tokenizer' not in globals():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

# Load the model (always reloaded so that the selected version is used)
# Note: this puts a second copy of the model on the GPU next to any model already in memory.
try:
    model_infer, epoch_loaded = load_trained_model(device, use_latest=use_latest)
    epoch_label = epoch_loaded + 1 if epoch_loaded is not None else "unknown"
    print(f"✅ Model loaded (epoch {epoch_label})")

    print(f"\n🧠 Generating answer...\n{'-'*30}")
    res = generate_answer(model_infer, Prompt, Max_Tokens, Temperature)
    print(f"📝 Question: {Prompt}")
    print(f"💡 Answer:\n{res}")
    print(f"{'-'*30}")

    # Free VRAM in case you want to train again
    del model_infer
    torch.cuda.empty_cache()

except Exception as e:
    print(f"❌ Error: {e}")

In [None]:

# @title 🧪 Test console
# --- INTERFACE ---
# @markdown ### Test parameters
Prompt = "A farmer has 10 cows and 5 horses. If 2 cows escape, how many horses are left?" # @param {type:"string"}
Max_Tokens = 100 # @param {type:"slider", min:10, max:512, step:10}
Temperature = 0.1 # @param {type:"slider", min:0.0, max:1.0, step:0.1}
Version_Modelo = "Latest (current progress)" # @param ["Best (best so far)", "Latest (current progress)"]

# Loading logic
use_latest = (Version_Modelo == "Latest (current progress)")
device = "cuda" if torch.cuda.is_available() else "cpu"

if 'tokenizer' not in globals():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

# Load the model (always reloaded so that the selected version is used)
# Note: this puts a second copy of the model on the GPU next to any model already in memory.
try:
    model_infer, epoch_loaded = load_trained_model(device, use_latest=use_latest)
    epoch_label = epoch_loaded + 1 if epoch_loaded is not None else "unknown"
    print(f"✅ Model loaded (epoch {epoch_label})")

    print(f"\n🧠 Generating answer...\n{'-'*30}")
    res = generate_answer(model_infer, Prompt, Max_Tokens, Temperature)
    print(f"📝 Question: {Prompt}")
    print(f"💡 Answer:\n{res}")
    print(f"{'-'*30}")

    # Free VRAM in case you want to train again
    del model_infer
    torch.cuda.empty_cache()

except Exception as e:
    print(f"❌ Error: {e}")