In [1]:
# Kernel smoke test: the cell only echoes a fixed string.
print("hi")

hi


In [2]:
# Import the package under a short alias and show its version string.
import exploretinyrm as m
m.__version__  # bare last expression, so the notebook echoes its value


'0.1.0'

In [3]:
# (Optional) make the notebook pick up code edits automatically without kernel restarts.
# Mode 2 asks IPython to reload all modules before each cell is executed.
%load_ext autoreload
%autoreload 2

import torch
from exploretinyrm.utils import compute_tensor_summary

# Seed the RNG right before sampling so the summary displayed by this cell is
# reproducible on Restart & Run All; unseeded, every run shows different
# statistics.
torch.manual_seed(0)
test_tensor = torch.randn(5, 3)
summary = compute_tensor_summary(test_tensor)
summary


{'mean': 0.10167226195335388,
 'standard_deviation': 0.8768734931945801,
 'minimum': -1.5688676834106445,
 'maximum': 1.3839704990386963}

In [4]:
import torch
import torch.nn.functional as F

import exploretinyrm as etm
from exploretinyrm import TRM, TRMConfig

# Report which package version the following smoke tests run against.
print("ExploreTinyRM version:", etm.__version__)

# Fix the RNG seed so the random tensors drawn in later cells are repeatable.
torch.manual_seed(0)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device


ExploreTinyRM version: 0.1.0


device(type='cuda')

In [5]:
# Build a small attention-path TRM and report its parameter count.
sequence_length = 32

config_attn = TRMConfig(
    # vocabulary sizes and sequence length
    input_vocab_size=256,
    output_vocab_size=256,
    seq_len=sequence_length,
    # network size
    d_model=64,
    n_layers=2,
    # self-attention path
    use_attention=True,
    n_heads=4,
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    # recursion settings
    n=2,              # inner loops
    T=2,              # deep-recursion loops per supervision step
    k_last_ops=None,  # backprop through all net-calls in final recursion
)

# Construct the model, then move it to the selected device.
model_attn = TRM(config_attn)
model_attn = model_attn.to(device)

total_parameters = sum(weight.numel() for weight in model_attn.parameters())
print(f"Parameters (attention TRM): {total_parameters:,}")


Parameters (attention TRM): 164,288


In [6]:
# Run one forward step of the attention model on random tokens and verify the
# shapes (and detachment) of everything it returns.
batch_size = 2
x_tokens = torch.randint(
    low=0,
    high=config_attn.input_vocab_size,
    size=(batch_size, sequence_length),
    device=device,
)

# Explicit state initialisation; optional, since (per the original note) the API
# initialises the states itself when None is passed.
y_state, z_state = model_attn.init_state(batch_size=batch_size, device=device)

y_next, z_next, logits, halt_logit = model_attn.forward_step(x_tokens, y_state, z_state)

print("y_next:", y_next.shape, "requires_grad:", y_next.requires_grad)
print("z_next:", z_next.shape, "requires_grad:", z_next.requires_grad)
print("logits:", logits.shape)
print("halt_logit:", halt_logit.shape)

# Both carried states share the (batch, sequence, d_model) shape.
assert tuple(y_next.shape) == (batch_size, sequence_length, config_attn.d_model)
assert tuple(z_next.shape) == (batch_size, sequence_length, config_attn.d_model)
# One logit vector per token position, one halting logit per batch element.
assert tuple(logits.shape) == (batch_size, sequence_length, config_attn.output_vocab_size)
assert tuple(halt_logit.shape) == (batch_size,)
# The returned states must be detached from the autograd graph.
assert (y_next.requires_grad, z_next.requires_grad) == (False, False)
print("Basic forward + shapes OK.")


y_next: torch.Size([2, 32, 64]) requires_grad: False
z_next: torch.Size([2, 32, 64]) requires_grad: False
logits: torch.Size([2, 32, 256])
halt_logit: torch.Size([2])
Basic forward + shapes OK.


In [7]:
# Simple supervised loss: cross-entropy of the step logits against random targets.
# NOTE: `logits` comes from the preceding forward-step cell. backward() frees the
# autograd graph by default, so re-run that cell before re-running this one.
target_tokens = torch.randint(
    low=0, high=config_attn.output_vocab_size,
    size=(batch_size, sequence_length), device=device
)
loss = F.cross_entropy(
    logits.view(-1, config_attn.output_vocab_size),
    target_tokens.view(-1)
)
loss.backward()

# Count the parameters that received a nonzero gradient. Each per-parameter check
# is converted to a Python bool so the total is always a plain int; summing the
# raw comparison results would instead produce a 0-dim tensor on `device`
# (and a Python int only when no parameter has a gradient at all).
num_with_grads = sum(
    1
    for p in model_attn.parameters()
    if p.grad is not None and bool(p.grad.abs().sum() > 0)
)
print(f"Parameters with nonzero grads: {num_with_grads}")
assert num_with_grads > 0, "no parameter received a nonzero gradient"

# Drop the gradients so later cells start from a clean state.
model_attn.zero_grad(set_to_none=True)
print("Backward pass OK.")


Parameters with nonzero grads: 15
Backward pass OK.


In [8]:
# Same smoke test for the token-MLP (mixer) variant of the model.
sequence_length_mlp = 16  # mixer is best for short, fixed L; keep it small here

config_mlp = TRMConfig(
    # vocabulary sizes and sequence length
    input_vocab_size=128,
    output_vocab_size=128,
    seq_len=sequence_length_mlp,
    # network size
    d_model=64,
    n_layers=2,
    # token-MLP path
    use_attention=False,
    n_heads=4,  # ignored when use_attention=False
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    # recursion settings
    n=2,
    T=2,
    k_last_ops=None,
)

# Construct the model, then move it to the selected device.
model_mlp = TRM(config_mlp)
model_mlp = model_mlp.to(device)

x_tokens_mlp = torch.randint(
    low=0,
    high=config_mlp.input_vocab_size,
    size=(batch_size, sequence_length_mlp),
    device=device,
)

# No states are passed here; only the tokens are supplied to forward_step.
y_next2, z_next2, logits2, halt_logit2 = model_mlp.forward_step(x_tokens_mlp)
print("MLP mixer run OK:", y_next2.shape, z_next2.shape, logits2.shape, halt_logit2.shape)


MLP mixer run OK: torch.Size([2, 16, 64]) torch.Size([2, 16, 64]) torch.Size([2, 16, 128]) torch.Size([2])


In [9]:
# Truncated backprop: keep gradients only through the last 2 "net calls" of the
# final recursion. (Per the original note, each inner loop performs 2 net calls:
# a z update followed by a y update.)
# The config is a copy of the attention config with only k_last_ops overridden.
config_trunc = TRMConfig(**dict(vars(config_attn), k_last_ops=2))
model_trunc = TRM(config_trunc)
model_trunc = model_trunc.to(device)

x_tokens_t = torch.randint(
    low=0,
    high=config_trunc.input_vocab_size,
    size=(batch_size, sequence_length),
    device=device,
)
y_t, z_t, logits_t, halt_t = model_trunc.forward_step(x_tokens_t, k_last_ops=2)

# Any scalar loss will do; this cell only checks that backward() runs.
loss_t = logits_t.pow(2).mean()
loss_t.backward()
print("Truncated-grad test OK (k_last_ops=2).")
model_trunc.zero_grad(set_to_none=True)


Truncated-grad test OK (k_last_ops=2).
