# Mixture-of-Recursions quickstart
This notebook mirrors the CLI demo in a lightweight, fully self-contained setting. It configures a tiny Mixture-of-Recursions (MoR) model, runs a few toy training iterations on synthetic data, performs greedy decoding from a short prompt, and finally inspects the router statistics (`router_avg_depth`, `router_active`, `router_exits`) collected during a forward pass.


In [None]:
import torch
from mixture_of_recursions import ModelConfig, RouterConfig, TrainConfig, KVConfig, train
from mixture_of_recursions.inference import generate
# Seed torch's random number generators so later random draws in this
# notebook are reproducible from run to run.
torch.manual_seed(42)

# Use the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device


In [None]:
# Deliberately tiny backbone: a single shared layer stack (n_layers_shared=1)
# that may be applied up to `max_recursions` times.
model_config = ModelConfig(
    vocab_size=128,
    d_model=128,
    n_heads=4,
    d_ff=256,
    n_layers_shared=1,
    max_recursions=3,
    dropout=0.1,
)

# Token-choice router bounded to [min_depth, max_depth]. The upper bound is
# read from the model config so the two limits cannot drift apart. The
# entropy/depth regularisation weights are interpreted by RouterConfig /
# train() (semantics not visible here).
router_config = RouterConfig(
    type='token_choice',
    min_depth=1,
    max_depth=model_config.max_recursions,
    entropy_reg=0.01,
    target_depth=2.0,
    depth_penalty=0.05,
)

# A handful of short steps on the device selected in the setup cell.
train_config = TrainConfig(
    seq_len=32,
    batch_size=2,
    steps=6,
    log_interval=2,
    device=str(device),
)

# KV-cache sharing strategy passed straight through to train().
kv_config = KVConfig(mode='share_first')

model, val_loss = train(model_config, router_config, train_config, kv_config)
val_loss


In [None]:
# Decode up to 16 new tokens from a short prompt. tokenizer=None is forwarded
# as-is; presumably generate() falls back to a built-in encoding in that case
# (TODO confirm in mixture_of_recursions.inference.generate).
prompt = 'Mixture-of-Recursions '
decode_result = generate(model, prompt, tokenizer=None, max_new_tokens=16)
generated_text, token_depths = decode_result
generated_text, token_depths


In [None]:
# Inspect router statistics on one random batch.
# Switch to evaluation mode first: the model was configured with dropout=0.1,
# and train() may leave it in training mode. In training mode the statistics
# below would vary from call to call. (Assumes `model` is a torch.nn.Module.)
model.eval()
with torch.no_grad():
    # seq_len + 1 random token ids, so the window can be split into inputs and
    # targets.
    batch = torch.randint(0, model_config.vocab_size, (1, train_config.seq_len + 1), device=device)
    # NOTE(review): the inputs hold seq_len tokens but `labels` holds seq_len + 1.
    # Confirm whether the model shifts labels internally (full window expected)
    # or expects pre-shifted targets, i.e. labels=batch[:, 1:].
    stats = model(batch[:, :-1], labels=batch)
stats['router_avg_depth'].item(), stats['router_active'], stats['router_exits']
