# NanoRL REINFORCE training loop (step-by-step)

This notebook mirrors the NanoRL training script, splitting each step into its own cell.


## Imports


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import SamplingParams

import gyllm
from gyllm.envs import AutoResetWrapper
from nanorl.agent import InstructAgent
from nanorl.rl import compute_reinforce_loss
from nanorl.rollout import NanoLLM
from nanorl.rollout.reporting import summarize_rollouts


## Configuration


In [None]:
# Policy checkpoint loaded with AutoTokenizer / AutoModelForCausalLM.
model_id = "Qwen/Qwen2.5-3B-Instruct"

# Rollout settings.
num_envs = 2  # passed as num_envs to gyllm.make
episodes = 4  # passed as max_episodes to each batched rollout
temperature = 1.0  # vLLM sampling temperature
max_tokens = 128  # vLLM per-generation token limit

# Optimization settings.
num_updates = 3  # number of rollout + gradient-step iterations
minibatch_size = 2  # rollouts per backward pass (gradients are accumulated)
lr = 1e-5  # AdamW learning rate
max_grad_norm = 1.0  # gradient-norm clipping threshold


## Model and NanoLLM


In [None]:
# Tokenizer; when the checkpoint defines no pad token, reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Policy model, loaded in bfloat16 and placed on the GPU.
model_load_kwargs = {"dtype": "bfloat16", "device_map": "cuda"}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

# Inference engine wrapping the policy for rollouts. Sleep mode is enabled so
# the engine can be put to sleep between rollouts; gpu_memory_utilization
# presumably caps the engine's share of GPU memory — confirm against NanoLLM.
llm = NanoLLM(
    model,
    tokenizer=tokenizer,
    enable_sleep_mode=True,
    gpu_memory_utilization=0.4,
)


## Environment and agent


In [None]:
# Batched "simple/reverse_echo" environment (2 turns per episode), wrapped in
# AutoResetWrapper (presumably resets finished sub-environments — confirm).
env = AutoResetWrapper(
    gyllm.make("simple/reverse_echo", num_envs=num_envs, env_kwargs={"num_turns": 2})
)

# Agent that generates with the NanoLLM engine using these sampling settings.
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)
agent = InstructAgent(
    sampling_params=sampling_params,
    tokenizer=tokenizer,
    llm=llm,
    model=model,
)


## Optimizer


In [None]:
# AdamW over all of the model's parameters at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)


## Training loop


In [None]:
# One REINFORCE update per iteration: collect rollouts with the inference
# engine, accumulate the policy-gradient loss over minibatches, clip, and step.
for update in range(num_updates):
    # --- Rollout phase (no autograd). ---
    model.eval()
    llm.wake_up()
    try:
        with torch.no_grad():
            rollouts = agent.rollout_autoreset_batched(
                env,
                max_episodes=episodes,
            )
    finally:
        # Put the engine back to sleep even if the rollout raised, so it is
        # not left awake holding its resources.
        llm.sleep(1)

    total_rollouts = len(rollouts)
    if total_rollouts == 0:
        # Checked before summarizing so the skip path does not depend on how
        # summarize_rollouts handles an empty batch.
        print(f"update={update} skipped (no rollouts)")
        continue

    # Return values are not used further in this cell.
    _tokens, mean_reward, _sample_text = summarize_rollouts(rollouts, tokenizer)

    # --- Training phase: gradient accumulation over minibatches. ---
    model.train()
    optimizer.zero_grad(set_to_none=True)
    device = next(model.parameters()).device  # loop-invariant

    total_loss_value = 0.0
    total_assistant_tokens = 0.0
    total_logprob = 0.0
    reward_sum = 0.0

    for start in range(0, total_rollouts, minibatch_size):
        minibatch = rollouts[start : start + minibatch_size]
        loss, mb_metrics = compute_reinforce_loss(
            minibatch,
            model,
            tokenizer,
            device=device,
        )
        mb_size = len(minibatch)
        mb_tokens = mb_metrics["assistant_tokens"]

        # Rewards are counted for every rollout, trainable or not.
        reward_sum += mb_metrics["avg_reward"] * mb_size

        if mb_tokens <= 0:
            # Nothing to train on. Also skip the logprob term: a mean over zero
            # tokens may be NaN, and 0 * NaN would poison the running total.
            continue

        total_assistant_tokens += mb_tokens
        total_logprob += mb_metrics["avg_logprob"] * mb_tokens

        # Weight each minibatch by its share of rollouts so the accumulated
        # gradient matches one loss averaged over the whole batch.
        scale = mb_size / total_rollouts
        (loss * scale).backward()
        total_loss_value += loss.item() * scale

    if total_assistant_tokens == 0:
        # No backward pass ran, so there is nothing to step.
        print(f"update={update} skipped (no assistant tokens)")
        continue

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    # NOTE(review): assumes NanoLLM sees the updated `model` weights on the
    # next wake_up() — confirm against NanoLLM's weight-sync behavior.

    avg_reward = reward_sum / total_rollouts
    avg_logprob = total_logprob / total_assistant_tokens  # token-weighted mean
    print(
        f"update={update} loss={total_loss_value:.4f} avg_reward={avg_reward:.3f} "
        f"avg_logprob={avg_logprob:.3f}"
    )
