From 46e64c694b808a7a764285c0d18627077589b8a3 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 19:38:05 +0000 Subject: [PATCH 01/48] =?UTF-8?q?feat:=20add=20Erd=C5=91s=20Discovery=20en?= =?UTF-8?q?vironment=20+=20entropic=20adaptive-=CE=B2=20advantage=20estima?= =?UTF-8?q?tor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TTT-Discover implementation (arXiv:2601.16175) for the Erdős Minimum Overlap Problem. New files: - nemo_rl/algorithms/entropic_advantage_estimator.py LOO entropic advantage with adaptive β via bisection. Solves for β such that KL(softmax_β(R) || uniform) = ln(2), then computes leave-one-out advantages w_i = exp(β·r_i)/Z_{-i} - 1. - nemo_rl/environments/erdos_discovery_environment.py Ray remote environment that calls the NeMo Gym Erdős resource server for sandboxed code execution and reward computation. Modified: - nemo_rl/algorithms/grpo.py: add entropic_adaptive_beta to advantage estimator factory + AdvEstimatorConfig TypedDict. - nemo_rl/environments/utils.py: register erdos_discovery in ENV_REGISTRY. --- .../entropic_advantage_estimator.py | 179 ++++++++++++ nemo_rl/algorithms/grpo.py | 15 +- .../erdos_discovery_environment.py | 256 ++++++++++++++++++ nemo_rl/environments/utils.py | 3 + 4 files changed, 451 insertions(+), 2 deletions(-) create mode 100644 nemo_rl/algorithms/entropic_advantage_estimator.py create mode 100644 nemo_rl/environments/erdos_discovery_environment.py diff --git a/nemo_rl/algorithms/entropic_advantage_estimator.py b/nemo_rl/algorithms/entropic_advantage_estimator.py new file mode 100644 index 0000000000..af78ae5637 --- /dev/null +++ b/nemo_rl/algorithms/entropic_advantage_estimator.py @@ -0,0 +1,179 @@ +"""Entropic Adaptive-Beta Advantage Estimator for TTT-Discover. + +Implements the Leave-One-Out (LOO) entropic advantage from +"Learning to Discover at Test Time" (arXiv:2601.16175). + +Instead of standard group-relative advantages (Adv = R - mean(R)), +this estimator: + 1. Solves for β such that KL(softmax_β(R) || uniform) = γ (default γ = ln(2)) + 2. Computes LOO advantages: w_i = exp(β·r_i) / Z_{-i} - 1 + where Z_{-i} is the normalizer excluding the i-th sample. + +Properties: + - Shift-invariant, approximately scale-invariant + - Monotone in reward + - Approximately mean-zero + - Adaptive scaling via β solves the reward-scale sensitivity of standard GRPO +""" + +import math +from typing import Optional + +import torch + + +def _solve_beta( + rewards: torch.Tensor, + gamma: float = math.log(2), + max_iter: int = 50, + tol: float = 1e-6, +) -> float: + """Solve for β such that KL(softmax_β(R) || uniform) = γ via bisection. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL divergence. Default ln(2) as in the paper. + max_iter: Maximum bisection iterations. + tol: Convergence tolerance on β. + + Returns: + Scalar β value. + """ + K = rewards.shape[0] + if K <= 1: + return 0.0 + + log_K = math.log(K) + r = rewards.double() + r_max = r.max() + + def kl_at_beta(b: float) -> float: + logits = b * (r - r_max) + log_Z = torch.logsumexp(logits, dim=0) + logq = logits - log_Z + q = logq.exp() + kl = (q * (logq + log_K)).sum().item() + return kl + + # Bisect: KL is monotonically increasing in |β| for non-constant rewards + # Find upper bound for β + lo, hi = 0.0, 1.0 + while kl_at_beta(hi) < gamma and hi < 1e8: + hi *= 2.0 + + # Edge case: all rewards identical → β = 0, KL = 0 for any β + if hi >= 1e8: + return 0.0 + + for _ in range(max_iter): + mid = (lo + hi) / 2.0 + kl = kl_at_beta(mid) + if abs(kl - gamma) < tol: + return mid + if kl < gamma: + lo = mid + else: + hi = mid + + return (lo + hi) / 2.0 + + +def compute_entropic_advantages( + rewards: torch.Tensor, + gamma: float = math.log(2), + eps: float = 1e-8, +) -> torch.Tensor: + """Compute LOO entropic advantages for a group of rewards. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL for adaptive β. + eps: Small constant for numerical stability. + + Returns: + [K] tensor of advantages. + """ + K = rewards.shape[0] + if K <= 1: + return torch.zeros_like(rewards) + + beta = _solve_beta(rewards, gamma=gamma) + if beta == 0.0: + return torch.zeros_like(rewards) + + r = rewards.double() + r_max = r.max() + e = torch.exp(beta * (r - r_max)) + + if K == 1: + Z_loo = e + else: + # Leave-one-out normalizer: Z_{-i} = (sum(e) - e_i) / (K - 1) + Z_loo = (e.sum() - e) / (K - 1) + + w = e / (Z_loo + eps) + advantages = (w - 1.0).to(rewards.dtype) + return advantages + + +class EntropicAdaptiveBetaAdvantageEstimator: + """Advantage estimator using entropic adaptive-β LOO weighting. + + Follows the same interface as GRPOAdvantageEstimator: + compute_advantage(prompt_ids, rewards, mask, **kwargs) -> [B, S] tensor + + Config keys (under grpo.adv_estimator): + gamma: Target KL for β search. Default ln(2) ≈ 0.693. + eps: Numerical stability constant. Default 1e-8. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.gamma = estimator_config.get("gamma", math.log(2)) + self.eps = estimator_config.get("eps", 1e-8) + + def compute_advantage( + self, + prompt_ids: torch.Tensor, + rewards: torch.Tensor, + mask: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """Compute per-token advantages using entropic adaptive-β LOO. + + Args: + prompt_ids: [B] or [B, S] tensor identifying which prompt each + sample belongs to (same prompt = same group). + rewards: [B] scalar rewards per sample. + mask: [B, S] response token mask (1 = generation token). + + Returns: + [B, S] advantages tensor. Each generation token gets the + sample-level advantage; non-generation tokens get 0. + """ + batch_size, seq_len = mask.shape + advantages = torch.zeros_like(mask, dtype=rewards.dtype) + + # Group by prompt (same as GRPO's per-prompt baseline) + if prompt_ids.dim() > 1: + # prompt_ids is [B, S] — use first token as group key + group_ids = prompt_ids[:, 0] + else: + group_ids = prompt_ids + + unique_prompts = group_ids.unique() + + for pid in unique_prompts: + group_mask = group_ids == pid + group_rewards = rewards[group_mask] + + group_adv = compute_entropic_advantages( + group_rewards, gamma=self.gamma, eps=self.eps + ) + + # Expand sample-level advantages to [group_size, seq_len] + # and mask to generation tokens only + group_indices = group_mask.nonzero(as_tuple=True)[0] + for i, idx in enumerate(group_indices): + advantages[idx] = group_adv[i] * mask[idx] + + return advantages diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e550429ce2..58f722c653 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -122,14 +122,17 @@ class AsyncGRPOConfig(TypedDict): class AdvEstimatorConfig(TypedDict): - """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++).""" + """Configuration for advantage estimator (GRPO, GDPO, Reinforce++, or Entropic).""" - name: str # "grpo", "gdpo", or "reinforce_plus_plus" + name: str # "grpo", "gdpo", "reinforce_plus_plus", or "entropic_adaptive_beta" # GRPO specific normalize_rewards: NotRequired[bool] use_leave_one_out_baseline: NotRequired[bool] # Reinforce++ specific minus_baseline: NotRequired[bool] + # Entropic Adaptive-Beta specific (TTT-Discover, arXiv:2601.16175) + gamma: NotRequired[float] # Target KL for beta search; default ln(2) + eps: NotRequired[float] # Numerical stability; default 1e-8 class GRPOConfig(TypedDict): @@ -1066,6 +1069,14 @@ def _create_advantage_estimator(master_config: MasterConfig): adv_estimator_config, loss_config ) print(" ✓ Using Reinforce++ advantage estimator") + elif adv_estimator_name == "entropic_adaptive_beta": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(" ✓ Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)") else: raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}") diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py new file mode 100644 index 0000000000..37d82541ee --- /dev/null +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -0,0 +1,256 @@ +"""Erdős Discovery Environment for NeMo RL. + +Implements EnvironmentInterface for TTT-Discover with the Erdős Minimum +Overlap Problem. Calls the NeMo Gym resource server for code execution +and reward computation. + +Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) + +The environment: + 1. Receives LLM-generated code from the GRPO rollout + 2. Sends it to the Erdős Gym resource server for sandboxed execution + scoring + 3. Returns reward = 1/bound (or 0 on failure) + 4. Tracks best constructions and buffer statistics via metrics +""" + +import logging +import math +from typing import Any, Optional + +import aiohttp +import ray +import torch + +from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn + +logger = logging.getLogger(__name__) + +# Type alias matching NeMo RL's convention +LLMMessageLogType = list[dict[str, Any]] +ErdosMetadata = dict[str, Any] + + +@ray.remote(max_restarts=-1, max_task_retries=-1) +class ErdosDiscoveryEnvironment(EnvironmentInterface[ErdosMetadata]): + """Erdős Minimum Overlap Problem environment for GRPO training. + + Communicates with the NeMo Gym Erdős resource server via HTTP for: + - /verify: code execution + reward computation + - /select_state: PUCT state selection for prompts + - /seed_session: buffer initialization + - /compute_entropic_advantages: LOO entropic advantages + - /update_buffer: add new discoveries to PUCT tree + + Config (under env.erdos_discovery): + resource_server_url: Base URL of the Erdős Gym resource server. + seed: Random seed for PUCT buffer initialization. + num_initial_states: States to seed the buffer with. + sandbox_timeout: Code execution timeout in seconds. + """ + + def __init__(self, config: dict): + self.config = config + self.resource_server_url = config.get( + "resource_server_url", "http://localhost:8080" + ) + self.seed = config.get("seed", None) + self.num_initial_states = config.get("num_initial_states", 16) + self.sandbox_timeout = config.get("sandbox_timeout", 600) + self.request_timeout = config.get("request_timeout", 660) + + self.best_reward = 0.0 + self.best_bound = float("inf") + self.total_verified = 0 + self.total_valid = 0 + self._session_initialized = False + + async def _ensure_session(self): + """Initialize the PUCT buffer on the resource server if not done.""" + if self._session_initialized: + return + try: + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + f"{self.resource_server_url}/seed_session", + json={ + "num_initial_states": self.num_initial_states, + "seed": self.seed, + }, + ) as resp: + data = await resp.json() + self.best_reward = data.get("best_initial_reward", 0.0) + self.best_bound = data.get( + "best_initial_bound", float("inf") + ) + logger.info( + "ErdosDiscovery: seeded buffer with %d states, " + "best_reward=%.4f, best_bound=%.6f", + data.get("num_states", 0), + self.best_reward, + self.best_bound, + ) + self._session_initialized = True + except Exception as e: + logger.error("ErdosDiscovery: seed_session failed: %s", e) + + async def _verify_single( + self, + session: aiohttp.ClientSession, + response_text: str, + parent_state: Optional[list[float]] = None, + ) -> dict: + """Call /verify on the resource server for one response.""" + # Build a minimal NeMoGymResponse-like payload + # The resource server extracts output_text from response.output_text + body = { + "responses_create_params": { + "input": [{"role": "user", "content": ""}], + }, + "response": { + "id": "verify", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": response_text}], + } + ], + "output_text": response_text, + }, + "parent_state": parent_state, + } + try: + timeout = aiohttp.ClientTimeout(total=self.request_timeout) + async with session.post( + f"{self.resource_server_url}/verify", + json=body, + timeout=timeout, + ) as resp: + return await resp.json() + except Exception as e: + logger.warning("ErdosDiscovery: verify failed: %s", e) + return {"reward": 0.0, "bound": None, "error_msg": str(e)} + + def step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[ErdosMetadata], + ) -> EnvironmentReturn[ErdosMetadata]: + """Evaluate a batch of LLM responses. + + Extracts the assistant's last message from each conversation, + sends it to the resource server for code execution + scoring, + returns rewards. + """ + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._async_step(message_log_batch, metadata) + ) + + async def _async_step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[ErdosMetadata], + ) -> EnvironmentReturn[ErdosMetadata]: + await self._ensure_session() + + batch_size = len(message_log_batch) + rewards = torch.zeros(batch_size) + terminateds = torch.ones(batch_size) # Always single-turn + observations = [{}] * batch_size + answers = [None] * batch_size + updated_metadata = list(metadata) + + timeout = aiohttp.ClientTimeout(total=self.request_timeout) + async with aiohttp.ClientSession(timeout=timeout) as session: + import asyncio + + tasks = [] + for i, message_log in enumerate(message_log_batch): + # Extract the last assistant message + response_text = "" + for msg in reversed(message_log): + if msg.get("role") == "assistant": + response_text = msg.get("content", "") + break + + # Get parent_state from metadata if available + parent_state = None + if metadata and i < len(metadata): + parent_state = metadata[i].get("parent_state", None) + + tasks.append( + self._verify_single(session, response_text, parent_state) + ) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.warning( + "ErdosDiscovery: verify exception for sample %d: %s", + i, + result, + ) + continue + + reward = result.get("reward", 0.0) + rewards[i] = reward + self.total_verified += 1 + + if reward > 0: + self.total_valid += 1 + bound = result.get("bound", None) + if reward > self.best_reward: + self.best_reward = reward + self.best_bound = bound or ( + 1.0 / reward if reward > 0 else float("inf") + ) + + answers[i] = ( + f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" + ) + + # Update metadata with verification results + if i < len(updated_metadata): + updated_metadata[i] = { + **updated_metadata[i], + "reward": reward, + "bound": result.get("bound"), + "error_msg": result.get("error_msg", ""), + "best_reward_ever": result.get( + "best_reward_ever", self.best_reward + ), + } + + return EnvironmentReturn( + observations=observations, + metadata=updated_metadata, + next_stop_strings=[None] * batch_size, + rewards=rewards, + terminateds=terminateds, + answers=answers, + ) + + def global_post_process_and_metrics( + self, batch: dict + ) -> tuple[dict, dict]: + """Compute and return environment-level metrics.""" + valid_rate = ( + self.total_valid / max(self.total_verified, 1) + ) + metrics = { + "env/best_reward": self.best_reward, + "env/best_bound": self.best_bound + if self.best_bound < float("inf") + else 0.0, + "env/total_verified": self.total_verified, + "env/valid_rate": valid_rate, + } + return batch, metrics + + def shutdown(self): + """Cleanup.""" + pass diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index df82c7d1af..a1bc6ace3f 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -50,6 +50,9 @@ class EnvRegistryEntry(TypedDict, total=False): "vlm": { "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", }, + "erdos_discovery": { + "actor_class_fqn": "nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment", + }, "nemo_gym": { "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", }, From bb43f23ba17d754c480821217bb0cb8850956054 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 19:54:29 +0000 Subject: [PATCH 02/48] =?UTF-8?q?docs:=20add=20PUCT=20buffer=20utility,=20?= =?UTF-8?q?Erd=C5=91s=20discovery=20README,=20update=20main=20README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - nemo_rl/utils/puct_buffer.py: General-purpose PUCT tree buffer for iterative optimization environments. Reusable across any task that needs exploration/exploitation state selection. - nemo_rl/environments/ERDOS_DISCOVERY.md: Full documentation for the TTT-Discover integration — architecture diagram, component locations, config examples, hyperparameters from the paper. - README.md: Added Advantage Estimators (including entropic adaptive-β) and PUCT Buffer to the Features section. --- README.md | 2 + nemo_rl/environments/ERDOS_DISCOVERY.md | 156 +++++++ nemo_rl/utils/puct_buffer.py | 561 ++++++++++++++++++++++++ 3 files changed, 719 insertions(+) create mode 100644 nemo_rl/environments/ERDOS_DISCOVERY.md create mode 100644 nemo_rl/utils/puct_buffer.py diff --git a/README.md b/README.md index f90091e5af..e6431e6136 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,8 @@ For detailed information on backend selection, configuration, and examples, see - ✅ **Environment Support and Isolation** - Support for multi-environment training and dependency isolation between components. - ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state). - ✅ **Learning Algorithms** - GRPO/GSPO/DAPO, SFT(with LoRA), DPO, and On-policy distillation. +- ✅ **Advantage Estimators** - Group-relative (GRPO), multi-reward (GDPO), Reinforce++, and [Entropic Adaptive-β](nemo_rl/algorithms/entropic_advantage_estimator.py) (LOO entropic weighting from [TTT-Discover](https://arxiv.org/abs/2601.16175)). +- ✅ **PUCT Buffer** - [Tree-structured state selection](nemo_rl/utils/puct_buffer.py) for iterative optimization environments (exploration/exploitation via Upper Confidence bounds). - ✅ **Multi-Turn RL** - Multi-turn generation and training for RL with tool use, games, etc. - ✅ **Advanced Parallelism with DTensor** - PyTorch FSDP2, TP, CP, and SP for efficient training (through NeMo AutoModel). - ✅ **Larger Model Support with Longer Sequences** - Performant parallelisms with Megatron Core (TP/PP/CP/SP/EP/FSDP) (through NeMo Megatron Bridge). diff --git a/nemo_rl/environments/ERDOS_DISCOVERY.md b/nemo_rl/environments/ERDOS_DISCOVERY.md new file mode 100644 index 0000000000..938d7b9d9e --- /dev/null +++ b/nemo_rl/environments/ERDOS_DISCOVERY.md @@ -0,0 +1,156 @@ +# TTT-Discover: Learning to Discover at Test Time + +Implementation of [TTT-Discover](https://arxiv.org/abs/2601.16175) using NeMo RL +for GRPO training and NeMo Gym for environment scoring. + +## Overview + +TTT-Discover is a framework for using LLMs to make mathematical discoveries +through iterative refinement. The model generates candidate solutions (as code), +which are scored and organized in a tree structure (PUCT buffer). Training uses +entropic adaptive-β advantages instead of standard GRPO group-relative baselines. + +The first application is the **Erdős Minimum Overlap Problem**: finding step +functions that minimize the upper bound on the Erdős constant `c` (known bounds: +`0.379005 < c < 0.380927`). + +## Architecture + +``` +┌─────────────────────────────────────────────────┐ +│ NeMo RL (training) │ +│ │ +│ GRPO Loop: │ +│ 1. Dataloader → prompts from PUCT buffer │ +│ 2. vLLM generates code completions │ +│ 3. Environment returns rewards │ +│ 4. Entropic adaptive-β advantage estimator │ +│ 5. Policy gradient update │ +│ 6. Weight sync to vLLM │ +│ │ +│ Components in nemo_rl/: │ +│ algorithms/entropic_advantage_estimator.py │ +│ environments/erdos_discovery_environment.py │ +│ utils/puct_buffer.py │ +└─────────────┬───────────────────────────────────┘ + │ HTTP (POST /verify, /select_state, ...) + │ +┌─────────────▼───────────────────────────────────┐ +│ NeMo Gym (environment) │ +│ │ +│ Erdős Resource Server: │ +│ - Sandboxed Python code execution │ +│ - FFT-based bound computation │ +│ - Constraint validation (len, range, mean) │ +│ - reward = 1 / bound │ +│ - PUCT buffer state management │ +│ - Prompt formatting from tree context │ +│ │ +│ Location: Gym/resources_servers/erdos_discovery │ +└─────────────────────────────────────────────────┘ +``` + +## Key Components + +### Entropic Adaptive-β Advantage Estimator + +Location: `nemo_rl/algorithms/entropic_advantage_estimator.py` + +Replaces standard GRPO group-relative advantages with Leave-One-Out (LOO) +entropic weighting: + +1. **Solve for β**: Find β such that `KL(softmax_β(R) || uniform) = ln(2)` + via bisection search. +2. **LOO advantages**: `w_i = exp(β·r_i) / Z_{-i} - 1` where `Z_{-i}` is the + leave-one-out normalizer (excludes sample `i`). + +Properties: shift-invariant, approximately scale-invariant, monotone, ~zero-mean. + +Config: +```yaml +grpo: + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931 # ln(2), target KL divergence +``` + +### PUCT Buffer + +Location: `nemo_rl/utils/puct_buffer.py` + +Tree-structured state selection using Predictor + Upper Confidence bounds for +Trees (PUCT). Balances exploitation (high-reward states) with exploration +(under-visited branches). + +Score: `Q(s) + c · P(s) · √(1+T) / (1+n(s))` + +- `Q(s)`: best reward reachable from state `s` +- `P(s)`: rank-based prior +- `n(s)`: visit count +- `T`: total visits +- `c`: exploration constant (default 1.0) + +This is a **general utility** — usable by any iterative optimization environment, +not just Erdős. + +### Erdős Discovery Environment + +Location: `nemo_rl/environments/erdos_discovery_environment.py` + +Ray remote actor implementing `EnvironmentInterface`. Calls the NeMo Gym Erdős +resource server for sandboxed code execution and reward computation. + +Config: +```yaml +env: + erdos_discovery: + resource_server_url: http://localhost:8080 + num_initial_states: 16 + sandbox_timeout: 600 +``` + +### Erdős Gym Resource Server + +Location: `Gym/resources_servers/erdos_discovery/` + +Standalone FastAPI server handling: +- `/verify`: execute code, validate `f`, compute `reward = 1/bound` +- `/seed_session`: initialize PUCT buffer with random states +- `/select_state`: PUCT-select states for next training batch +- `/update_buffer`: add discoveries to the tree + +## Hyperparameters (from the paper) + +| Parameter | Value | Description | +|-----------|-------|-------------| +| model | 120B MoE | gpt-oss-120b-bf16 with LoRA r=32 | +| group_size | 64 | Rollouts per initial state | +| groups_per_batch | 8 | PUCT-selected states per step | +| epochs | 50 | Training steps | +| lr | 4e-5 | Learning rate | +| context_window | 32768 | Max tokens | +| kl_penalty | 0.1 | KL penalty coefficient | +| puct_c | 1.0 | PUCT exploration constant | +| entropic γ | ln(2) | Target KL for adaptive β | +| sandbox_timeout | 600s | Code execution limit | + +## Running + +_Run script coming soon. Will follow the `research/template_project/` pattern._ + +```bash +# 1. Start the Gym resource server +cd ~/Gym +ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" + +# 2. Run GRPO training with TTT-Discover +cd ~/RL +python research/ttt_discover/run_discover.py \ + --config research/ttt_discover/configs/erdos_120b.yaml +``` + +## References + +- Yu Sun et al., "Learning to Discover at Test Time" (arXiv:2601.16175), 2026. +- Reference implementation: https://github.com/test-time-training/discover +- Haugland (2016) for prior SOTA bound 0.380927. diff --git a/nemo_rl/utils/puct_buffer.py b/nemo_rl/utils/puct_buffer.py new file mode 100644 index 0000000000..53f808d783 --- /dev/null +++ b/nemo_rl/utils/puct_buffer.py @@ -0,0 +1,561 @@ +""" +PUCT buffer for TTT-Discover state reuse. + +Reference: "Learning to Discover at Test Time" (arXiv:2601.04116) + +The buffer maintains a tree of (state, reward) nodes. At each training step, +PUCT scoring selects which states to warm-start rollouts from, balancing: + - Exploitation: states whose children have achieved high rewards (high Q) + - Exploration: states that haven't been visited much yet (low n) + +Pure data structure — no ML framework dependencies. +""" + +import math +import dataclasses +from typing import Any, Optional + +import numpy as np + + +# --------------------------------------------------------------------------- +# Internal node +# --------------------------------------------------------------------------- + +@dataclasses.dataclass +class _Node: + state: Any + reward: float # reward of THIS state (from its own evaluation) + parent_key: Any # key of parent node, or None for roots + children_keys: list # keys of direct children + n: int # visit count (number of times selected for expansion) + Q: float # max reward among all descendants (or own reward if leaf) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_key(state: Any) -> Any: + """Convert state to a hashable key. + + Supports: str, int, float, tuple, list, np.ndarray, and arbitrary objects + (fallback: id-based, so two different objects with equal content are + treated as distinct — acceptable for LLM response strings). + """ + if isinstance(state, (str, int, float, bool)): + return state + if isinstance(state, np.ndarray): + return (state.dtype, state.shape, state.tobytes()) + if isinstance(state, (list, tuple)): + return tuple(_make_key(x) for x in state) + # Fallback: identity-based key — wrap id so it doesn't collide with ints + return ("__id__", id(state)) + + +# --------------------------------------------------------------------------- +# PUCTBuffer +# --------------------------------------------------------------------------- + +class PUCTBuffer: + """ + Tree-structured buffer with PUCT selection. + + PUCT score for node s: + score(s) = Q(s) + c · P(s) · sqrt(1 + T) / (1 + n(s)) + + Where: + Q(s) = max reward among all descendants of s (own reward if leaf) + P(s) = rank-based prior: rank states by reward, normalize by total rank + n(s) = visit count of s + T = total visit count across all nodes + c = exploration constant (default 1.0) + """ + + def __init__(self, c: float = 1.0) -> None: + self.c = c + self._nodes: dict[Any, _Node] = {} # key → _Node + self._T: int = 0 # total expansions so far + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add(self, state: Any, reward: float, parent_state: Any = None) -> None: + """Insert a new node into the buffer. + + If the state is already present, this is a no-op (deduplication). + If parent_state is given and present in the buffer, the new node is + linked as a child and Q values are propagated upward. + + Args: + state: The state to insert (any type with a consistent identity). + reward: Scalar reward associated with this state. + parent_state: Parent state, or None for a root node. + """ + key = _make_key(state) + if key in self._nodes: + return # already present — deduplicate + + parent_key = _make_key(parent_state) if parent_state is not None else None + node = _Node( + state=state, + reward=float(reward), + parent_key=parent_key, + children_keys=[], + n=0, + Q=float(reward), # leaf: Q = own reward + ) + self._nodes[key] = node + + if parent_key is not None and parent_key in self._nodes: + self._nodes[parent_key].children_keys.append(key) + self._propagate_Q(parent_key) + + def select( + self, batch_size: int, num_groups: int = 8 + ) -> list[tuple[Any, list]]: + """Select states to warm-start rollouts from. + + Scores each node with PUCT, picks the top `num_groups` distinct states, + and returns `batch_size` (state, context) pairs grouped so that each + group of `batch_size // num_groups` entries shares the same state. + + Context is the ancestry path from root to the selected node: + [(ancestor_state, ancestor_reward), ..., (selected_state, selected_reward)] + The env uses this to build the prompt (previous attempts / warm start). + + Visit counts are incremented for the selected nodes, and T is updated. + + Args: + batch_size: Total number of (state, context) pairs to return. + Must be divisible by num_groups. + num_groups: Number of distinct initial states to select. + + Returns: + List of (state, context) tuples, length == batch_size. + """ + if not self._nodes: + raise ValueError("Buffer is empty — call add() before select()") + if batch_size % num_groups != 0: + raise ValueError( + f"batch_size ({batch_size}) must be divisible by num_groups ({num_groups})" + ) + rollouts_per_group = batch_size // num_groups + + priors = self._rank_priors() + scores = { + key: self._puct_score(node, priors[key]) + for key, node in self._nodes.items() + } + + # Top num_groups keys by PUCT score (at most len(nodes) if buffer is small) + k = min(num_groups, len(self._nodes)) + top_keys = sorted(scores, key=lambda x: scores[x], reverse=True)[:k] + + result: list[tuple[Any, list]] = [] + for key in top_keys: + node = self._nodes[key] + context = self._ancestry(key) + pair = (node.state, context) + result.extend([pair] * rollouts_per_group) + # Increment visit count for this selection + node.n += 1 + self._T += 1 + + return result + + def update( + self, parent_state: Any, child_state: Any, reward: float + ) -> None: + """Add a child node and update Q values up the tree. + + Convenience wrapper around add() that makes the parent/child + relationship explicit. + + Args: + parent_state: The state that was selected and rolled out from. + child_state: The resulting new state produced by the rollout. + reward: Reward of the new child state. + """ + self.add(child_state, reward, parent_state=parent_state) + + def best(self) -> tuple[Any, float]: + """Return the (state, reward) with the highest reward ever seen. + + Returns: + (state, reward) tuple. + """ + if not self._nodes: + raise ValueError("Buffer is empty") + best_key = max(self._nodes, key=lambda k: self._nodes[k].reward) + node = self._nodes[best_key] + return node.state, node.reward + + def __len__(self) -> int: + return len(self._nodes) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _puct_score(self, node: _Node, prior: float) -> float: + return node.Q + self.c * prior * math.sqrt(1 + self._T) / (1 + node.n) + + def _rank_priors(self) -> dict[Any, float]: + """Rank-based prior: rank by node reward, normalize by sum of ranks. + + Rank 1 = lowest reward, rank N = highest. Ties get the same rank + (average of tied ranks), consistent with scipy.stats.rankdata. + """ + keys = list(self._nodes.keys()) + rewards = np.array([self._nodes[k].reward for k in keys], dtype=float) + + # argsort twice gives rank (0-indexed); add 1 to make 1-indexed + order = np.argsort(rewards) + ranks = np.empty_like(order, dtype=float) + ranks[order] = np.arange(1, len(rewards) + 1, dtype=float) + + # Handle ties: assign average rank to tied rewards. + # Use ranks[tied].mean() — not tied.mean()+1, which would use array + # indices instead of the already-assigned rank values. + # (simple O(N²) loop is fine for buffer sizes we care about) + for i, r in enumerate(rewards): + tied = np.where(rewards == r)[0] + if len(tied) > 1: + ranks[tied] = ranks[tied].mean() + + total = ranks.sum() + return {k: float(ranks[i] / total) for i, k in enumerate(keys)} + + def _propagate_Q(self, key: Any) -> None: + """Propagate max-Q upward from `key` to the root.""" + node = self._nodes[key] + if node.children_keys: + child_rewards = [ + self._nodes[ck].Q + for ck in node.children_keys + if ck in self._nodes + ] + new_Q = max(node.reward, max(child_rewards)) if child_rewards else node.reward + else: + new_Q = node.reward + + if new_Q == node.Q: + return # no change — stop propagation + + node.Q = new_Q + if node.parent_key is not None and node.parent_key in self._nodes: + self._propagate_Q(node.parent_key) + + def _ancestry(self, key: Any) -> list[tuple[Any, float]]: + """Return the path from root to `key` as [(state, reward), ...].""" + path = [] + cur = key + while cur is not None: + node = self._nodes[cur] + path.append((node.state, node.reward)) + cur = node.parent_key + path.reverse() + return path + + +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- + +def _run_tests() -> None: + import sys + + failures: list[str] = [] + + def check(name: str, cond: bool, msg: str = "") -> None: + if not cond: + failures.append(f"FAIL [{name}]: {msg}") + else: + print(f" PASS [{name}]") + + print("=== puct_buffer unit tests ===\n") + + # ------------------------------------------------------------------ + # Basic add / best + # ------------------------------------------------------------------ + print("-- add / best --") + + buf = PUCTBuffer(c=1.0) + buf.add("s0", 0.5) + buf.add("s1", 0.8) + buf.add("s2", 0.3) + + state, reward = buf.best() + check("best_returns_max_reward_state", reward == 0.8, f"reward={reward}") + check("best_returns_correct_state", state == "s1", f"state={state!r}") + check("len_after_adds", len(buf) == 3, f"len={len(buf)}") + + # Duplicate add is a no-op + buf.add("s0", 99.0) + check("duplicate_add_noop", len(buf) == 3, "duplicate changed buffer size") + check("duplicate_reward_unchanged", buf._nodes[_make_key("s0")].reward == 0.5) + + # ------------------------------------------------------------------ + # Q uses MAX not mean + # ------------------------------------------------------------------ + print("\n-- Q = MAX not mean --") + + buf2 = PUCTBuffer() + buf2.add("root", 0.0) + buf2.add("child_low", 0.1, parent_state="root") + buf2.add("child_high", 0.9, parent_state="root") + + root_node = buf2._nodes[_make_key("root")] + check( + "Q_is_max_not_mean", + root_node.Q == 0.9, + f"root.Q={root_node.Q}, expected 0.9 (max), mean would be 0.5", + ) + + # Add another child with even higher reward — Q should update + buf2.add("child_best", 0.95, parent_state="root") + check( + "Q_updates_when_better_child_added", + root_node.Q == 0.95, + f"root.Q={root_node.Q}, expected 0.95", + ) + + # ------------------------------------------------------------------ + # Q propagates through grandchildren (MAX of all descendants) + # ------------------------------------------------------------------ + print("\n-- Q propagation --") + + buf3 = PUCTBuffer() + buf3.add("r", 0.0) + buf3.add("c1", 0.3, parent_state="r") + buf3.add("gc", 0.99, parent_state="c1") # grandchild + + r_node = buf3._nodes[_make_key("r")] + c1_node = buf3._nodes[_make_key("c1")] + check("grandchild_Q_propagates_to_child", c1_node.Q == 0.99, f"c1.Q={c1_node.Q}") + check("grandchild_Q_propagates_to_root", r_node.Q == 0.99, f"r.Q={r_node.Q}") + + # Parent with high own reward should NOT lose Q when children underperform + buf3b = PUCTBuffer() + buf3b.add("great_parent", 0.9) + buf3b.add("weak_child", 0.2, parent_state="great_parent") + gp_node = buf3b._nodes[_make_key("great_parent")] + check( + "parent_Q_not_lowered_by_weak_child", + gp_node.Q == 0.9, + f"great_parent.Q={gp_node.Q}, expected 0.9 (own reward dominates)", + ) + + # ------------------------------------------------------------------ + # Rank priors: ties get correct average rank (not index-based) + # ------------------------------------------------------------------ + print("\n-- rank prior tie handling --") + + buf_ties = PUCTBuffer() + # rewards: s0=0.1 (rank 1), s1=0.5 (tied), s2=0.3 (rank 2), s3=0.5 (tied) + # After tie-averaging: s0→1, s2→2, s1&s3→(3+4)/2=3.5 + buf_ties.add("s0", 0.1) + buf_ties.add("s1", 0.5) + buf_ties.add("s2", 0.3) + buf_ties.add("s3", 0.5) + priors_ties = buf_ties._rank_priors() + p1 = priors_ties[_make_key("s1")] + p3 = priors_ties[_make_key("s3")] + p2 = priors_ties[_make_key("s2")] + check("tied_states_equal_prior", abs(p1 - p3) < 1e-9, f"p1={p1:.6f} p3={p3:.6f}") + check("tied_states_outrank_lower", p1 > p2, f"tied={p1:.4f} vs s2={p2:.4f}") + + # ------------------------------------------------------------------ + # update() convenience wrapper + # ------------------------------------------------------------------ + print("\n-- update() --") + + buf4 = PUCTBuffer() + buf4.add("p", 0.5) + buf4.update("p", "child_via_update", 0.7) + check("update_adds_child", len(buf4) == 2, f"len={len(buf4)}") + check("update_links_child", "child_via_update" in [ + buf4._nodes[ck].state for ck in buf4._nodes[_make_key("p")].children_keys + ]) + + # ------------------------------------------------------------------ + # Exploration: unvisited high-reward states get selected + # ------------------------------------------------------------------ + print("\n-- exploration: unvisited high-reward states --") + + buf5 = PUCTBuffer(c=1.0) + # Old state, visited many times + buf5.add("visited", 0.6) + buf5._nodes[_make_key("visited")].n = 100 + # New high-reward state, never visited + buf5.add("fresh_high", 0.9) + + selected = buf5.select(batch_size=2, num_groups=2) + selected_states = [s for s, _ in selected] + check( + "unvisited_high_reward_selected", + "fresh_high" in selected_states, + f"selected states: {selected_states}", + ) + + # ------------------------------------------------------------------ + # Exploitation: Q(parent) rises after adding a high-reward child, making + # the parent score higher than a sibling with no children. + # We verify PUCT scores directly — not via select() — because select() + # would correctly pick the child itself (even better warm-start). + # ------------------------------------------------------------------ + print("\n-- exploitation: high-Q parent outscores peer --") + + buf6 = PUCTBuffer(c=0.01) # low exploration → scores dominated by Q + buf6.add("peer_no_children", 0.5) + buf6.add("parent_explored", 0.5) + # Give parent_explored a great child: Q should propagate to 0.99 + buf6.add("great_child_2", 0.99, parent_state="parent_explored") + + priors6 = buf6._rank_priors() + pk_peer = _make_key("peer_no_children") + pk_parent = _make_key("parent_explored") + score_peer = buf6._puct_score(buf6._nodes[pk_peer], priors6[pk_peer]) + score_parent = buf6._puct_score(buf6._nodes[pk_parent], priors6[pk_parent]) + + check( + "parent_Q_raised_by_great_child", + buf6._nodes[pk_parent].Q == 0.99, + f"parent.Q={buf6._nodes[pk_parent].Q}", + ) + check( + "high_Q_parent_outscores_peer", + score_parent > score_peer, + f"score_parent={score_parent:.4f}, score_peer={score_peer:.4f}", + ) + + # ------------------------------------------------------------------ + # select() group structure + # ------------------------------------------------------------------ + print("\n-- select() group structure --") + + buf7 = PUCTBuffer() + for i in range(10): + buf7.add(f"s{i}", float(i) / 10) + + result = buf7.select(batch_size=16, num_groups=4) + check("select_total_length", len(result) == 16, f"len={len(result)}") + + # Each group of 4 should share the same state + groups_of_4 = [result[i*4:(i+1)*4] for i in range(4)] + for gi, group in enumerate(groups_of_4): + states_in_group = [s for s, _ in group] + check( + f"group_{gi}_same_state", + len(set(states_in_group)) == 1, + f"group {gi} has mixed states: {states_in_group}", + ) + + # Each group should have a DIFFERENT initial state from the others + group_states = [group[0][0] for group in groups_of_4] + check( + "groups_have_distinct_states", + len(set(group_states)) == 4, + f"group states: {group_states}", + ) + + # ------------------------------------------------------------------ + # select() raises on batch_size not divisible by num_groups + # ------------------------------------------------------------------ + print("\n-- select() error handling --") + + buf8 = PUCTBuffer() + buf8.add("x", 1.0) + try: + buf8.select(batch_size=7, num_groups=3) + check("indivisible_batch_raises", False, "should have raised ValueError") + except ValueError: + check("indivisible_batch_raises", True) + + # select() on empty buffer raises + buf_empty = PUCTBuffer() + try: + buf_empty.select(batch_size=4, num_groups=2) + check("empty_buffer_select_raises", False, "should have raised ValueError") + except ValueError: + check("empty_buffer_select_raises", True) + + # ------------------------------------------------------------------ + # Context (ancestry path) + # ------------------------------------------------------------------ + print("\n-- context / ancestry path --") + + buf9 = PUCTBuffer() + buf9.add("root", 0.1) + buf9.add("child", 0.5, parent_state="root") + buf9.add("grand", 0.9, parent_state="child") + + # Force select to pick "grand" by making it best by far + buf9._nodes[_make_key("grand")].reward = 10.0 + buf9._propagate_Q(_make_key("child")) + buf9._propagate_Q(_make_key("root")) + + result9 = buf9.select(batch_size=1, num_groups=1) + state9, context9 = result9[0] + check("context_is_list", isinstance(context9, list)) + check( + "context_starts_at_root", + context9[0][0] == "root", + f"context[0]={context9[0]}", + ) + check( + "context_ends_at_selected", + context9[-1][0] == state9, + f"context[-1]={context9[-1]}, state={state9!r}", + ) + check( + "context_length_equals_depth", + len(context9) == 3, + f"len={len(context9)}, expected 3", + ) + + # ------------------------------------------------------------------ + # Visit count increments on select + # ------------------------------------------------------------------ + print("\n-- visit count tracking --") + + buf10 = PUCTBuffer() + buf10.add("a", 0.5) + buf10.add("b", 0.6) + n_before_a = buf10._nodes[_make_key("a")].n + buf10.select(batch_size=4, num_groups=2) + T_after = buf10._T + check("T_incremented_by_num_groups", T_after == 2, f"T={T_after}") + total_n = sum(n.n for n in buf10._nodes.values()) + check("total_n_equals_T", total_n == T_after, f"sum(n)={total_n}, T={T_after}") + + # ------------------------------------------------------------------ + # numpy array states + # ------------------------------------------------------------------ + print("\n-- numpy array states --") + + buf11 = PUCTBuffer() + arr_a = np.array([0.1, 0.5, 0.4]) + arr_b = np.array([0.3, 0.3, 0.4]) + buf11.add(arr_a, 0.7) + buf11.add(arr_b, 0.9) + check("numpy_states_len", len(buf11) == 2, f"len={len(buf11)}") + best_s, best_r = buf11.best() + check("numpy_best_reward", best_r == 0.9, f"best_r={best_r}") + check("numpy_best_state", np.array_equal(best_s, arr_b), f"best_s={best_s}") + + # ------------------------------------------------------------------ + print() + if failures: + for f in failures: + print(f) + print(f"\n{len(failures)} test(s) FAILED") + import sys; sys.exit(1) + else: + print("All tests passed.") + + +if __name__ == "__main__": + _run_tests() From 224d76e2cf5746451847d2bfaa4ae24f4f1f8406 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 20:04:16 +0000 Subject: [PATCH 03/48] =?UTF-8?q?feat:=20add=20TTT-Discover=20run=20script?= =?UTF-8?q?=20and=20config=20for=20Erd=C5=91s=20GRPO=20training?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run script (examples/run_discover.py): - DiscoverDataset: IterableDataset backed by PUCT buffer via HTTP. Calls /select_state on the Gym resource server each step to get dynamically selected states, generates DatumSpecs with tokenized prompts and parent_state metadata. - setup_discover_data(): Wires dataset + ErdosDiscoveryEnvironment Ray actor, returns the (dataset, env) tuple for grpo_train(). - Follows the sliding_puzzle pattern for custom env integration. Config (examples/configs/grpo_erdos_discover.yaml): - entropic_adaptive_beta advantage estimator (gamma=ln(2)) - 8 groups × 64 rollouts = 512 trajectories per step - LoRA r=32, lr=4e-5, 50 training steps - KL penalty 0.1, importance sampling loss - 32K context window for long code generation --- examples/configs/grpo_erdos_discover.yaml | 87 ++++++ examples/run_discover.py | 329 ++++++++++++++++++++++ 2 files changed, 416 insertions(+) create mode 100644 examples/configs/grpo_erdos_discover.yaml create mode 100644 examples/run_discover.py diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml new file mode 100644 index 0000000000..59be71b1b5 --- /dev/null +++ b/examples/configs/grpo_erdos_discover.yaml @@ -0,0 +1,87 @@ +# TTT-Discover GRPO config for Erdős Minimum Overlap Problem. +# Reference: arXiv:2601.16175 +# +# Usage: +# uv run python examples/run_discover.py --config examples/configs/grpo_erdos_discover.yaml +# +# Requires the Gym Erdős resource server running separately: +# cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" + +seed: 42 + +grpo: + num_prompts_per_step: 8 # PUCT selects 8 states per step + num_generations_per_prompt: 64 # 64 rollouts per state = 512 total + max_num_epochs: 1 + max_num_steps: 50 # 50 training steps (paper default) + max_rollout_turns: 1 # Single-turn: generate code, get reward + remove_constant_reward_groups: true + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931471805599453 # ln(2) + +loss_fn: + kl_penalty_coef: 0.1 # KL penalty (paper uses 0.1) + ratio_clip: 0.2 + token_level_loss: false # Sequence-level policy ratio + importance_sampling: true + +policy: + model_name: "gpt-oss-120b-bf16" + tokenizer: "gpt-oss-120b-bf16" + max_total_sequence_length: 32768 + train_global_batch_size: 512 # 8 groups × 64 rollouts + train_micro_batch_size: 4 + + dtensor_cfg: + enabled: true + tensor_parallel_size: 1 + sequence_parallel: false + + lora_cfg: + enabled: true + rank: 32 + alpha: 1.0 + dropout: 0.0 + +optimizer: + name: adamw + lr: 4.0e-5 + betas: [0.9, 0.999] + weight_decay: 0.01 + scheduler: + name: cosine + warmup_steps: 2 + min_lr_ratio: 0.1 + +data: + shuffle: false # PUCT handles selection order + # No static dataset — DiscoverDataset generates prompts dynamically + +env: + erdos_discovery: + resource_server_url: "http://localhost:8080" + num_initial_states: 16 + num_groups_per_step: 8 + sandbox_timeout: 600 + request_timeout: 660 + +cluster: + gpus_per_node: 8 + num_nodes: 2 # 2 training nodes + +generation: + backend: vllm + colocated: false # Inference on separate nodes + temperature: 1.0 + top_p: 1.0 + max_new_tokens: 16384 # Long context for code generation + +checkpointing: + enabled: true + checkpoint_dir: "results/ttt-discover-erdos" + save_frequency: 5 # Every 5 steps + +wandb: + enabled: true + project: "ttt-discover-erdos" diff --git a/examples/run_discover.py b/examples/run_discover.py new file mode 100644 index 0000000000..f0e19a5bd3 --- /dev/null +++ b/examples/run_discover.py @@ -0,0 +1,329 @@ +"""Run script for TTT-Discover GRPO training on the Erdős Minimum Overlap Problem. + +This follows the sliding_puzzle pattern: custom IterableDataset that generates +prompts dynamically from a PUCT buffer, wired into the standard GRPO loop. + +Usage: + # Start the Gym resource server first (separate process/node): + cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" + + # Then run training: + cd ~/RL && uv run python examples/run_discover.py [--config examples/configs/grpo_erdos_discover.yaml] + +Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) +""" + +import itertools +import logging +import sys +from typing import Optional + +import aiohttp +import asyncio +import numpy as np +import ray +import torch +from torch.utils.data import IterableDataset + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer, set_seed +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.erdos_discovery_environment import ( + ErdosDiscoveryEnvironment, +) +from nemo_rl.models.generation import configure_generation_config + +logger = logging.getLogger(__name__) + +# ═══════════════════════════════════════════════════════════════════ +# Problem description (same as in the Gym resource server) +# ═══════════════════════════════════════════════════════════════════ + +PROBLEM_DESCRIPTION = """\ +Erdos Minimum Overlap Problem +============================== + +Goal: Find a step function f (Python list or NumPy array) giving the +tightest possible upper bound on the Erdos minimum overlap constant c. + +Background: + For integer n, partition {1,...,2n} into equal sets A, B. + M_k = #{(a,b) : a in A, b in B, a-b=k}. + c = lim_{n->inf} min_{A,B} max_k M_k / n. + +Known bounds: 0.379005 < c < 0.380927 (Haugland 2016) +Current best upper bound: 0.380876 (2026) + +Upper Bound via Step Functions: + f : [0,1] -> [0,1] with mean(f) = 0.5 gives: + bound = 2*n*max(autocorr(f)) / sum(f)^2 + where autocorr is computed via FFT. + Smaller bound -> higher reward (reward = 1/bound). + +Constraints: 1 <= len(f) <= 1000, 0 <= f[i] <= 1, mean(f) ~ 0.5 (tol 1e-3). + +Output: Python code defining variable `f` in a ```python block. +Allowed: numpy, math, random, itertools, functools, collections. +Execution limit: 600 seconds. Target: bound < 0.380876.\ +""" + + +# ═══════════════════════════════════════════════════════════════════ +# Datum generation +# ═══════════════════════════════════════════════════════════════════ + + +def generate_discover_datum( + tokenizer, + state_info: dict, + idx: int, + task_name: str = "erdos_discovery", +) -> DatumSpec: + """Create a DatumSpec from a PUCT-selected state. + + Args: + tokenizer: HuggingFace tokenizer. + state_info: Dict from /select_state with keys: + state, context, reward, system_prompt, user_prompt. + idx: Datum index. + task_name: Task name for env routing. + + Returns: + DatumSpec ready for the GRPO training loop. + """ + system_prompt = state_info.get("system_prompt", PROBLEM_DESCRIPTION) + user_prompt = state_info["user_prompt"] + + messages: LLMMessageLogType = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # Tokenize the prompt + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long) + + # Attach token_ids to messages for NeMo RL's message_log format + for msg in messages: + msg_text = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + msg_ids = tokenizer.encode(msg_text, add_special_tokens=False) + msg["token_ids"] = torch.tensor(msg_ids, dtype=torch.long) + + return DatumSpec( + message_log=messages, + length=len(prompt_ids), + extra_env_info={ + "parent_state": state_info.get("state"), + "context": state_info.get("context"), + "reward": state_info.get("reward", 0.0), + }, + loss_multiplier=1.0, + idx=idx, + task_name=task_name, + ) + + +# ═══════════════════════════════════════════════════════════════════ +# Dynamic dataset backed by PUCT buffer +# ═══════════════════════════════════════════════════════════════════ + + +class DiscoverDataset(IterableDataset): + """Iterable dataset that fetches prompts from the PUCT buffer each step. + + Each iteration fetches `num_groups_per_step` states from the Gym resource + server's /select_state endpoint and yields them as DatumSpecs. + + The dataset loops indefinitely — the training loop controls termination + via max_num_steps in the GRPO config. + """ + + def __init__( + self, + tokenizer, + resource_server_url: str, + num_groups_per_step: int = 8, + task_name: str = "erdos_discovery", + length: int = 1000, # Nominal length for dataloader + ): + self.tokenizer = tokenizer + self.resource_server_url = resource_server_url + self.num_groups_per_step = num_groups_per_step + self.task_name = task_name + self.length = length + self._idx_counter = itertools.count() + + def _fetch_states_sync(self) -> list[dict]: + """Synchronously fetch states from the PUCT buffer.""" + import requests + + try: + resp = requests.post( + f"{self.resource_server_url}/select_state", + json={ + "batch_size": self.num_groups_per_step, + "num_groups": self.num_groups_per_step, + }, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + return data.get("states", []) + except Exception as e: + logger.error("Failed to fetch states from PUCT buffer: %s", e) + # Return fallback: single default prompt + return [ + { + "state": [0.5] * 50, + "context": [], + "reward": 0.5, + "system_prompt": PROBLEM_DESCRIPTION, + "user_prompt": ( + "Starting construction (bound=2.000000, 50 pieces):\n" + "[0.5000, 0.5000, ..., 0.5000]\n\n" + "Improve on this construction. Write Python code that " + "defines a better step function `f`. Think carefully." + ), + } + ] + + def __iter__(self): + for _ in itertools.count(): + states = self._fetch_states_sync() + for state_info in states: + idx = next(self._idx_counter) + yield generate_discover_datum( + self.tokenizer, + state_info, + idx=idx, + task_name=self.task_name, + ) + + def __len__(self): + return self.length + + +# ═══════════════════════════════════════════════════════════════════ +# Setup +# ═══════════════════════════════════════════════════════════════════ + + +def setup_discover_data(config: MasterConfig, tokenizer): + """Create dataset, environment, and wire them together. + + Returns: + (train_dataset, val_dataset, task_to_env, val_task_to_env) + """ + env_config = config.get("env", {}).get("erdos_discovery", {}) + resource_server_url = env_config.get( + "resource_server_url", "http://localhost:8080" + ) + num_groups_per_step = env_config.get("num_groups_per_step", 8) + task_name = "erdos_discovery" + + # Create the dynamic dataset + train_dataset = DiscoverDataset( + tokenizer=tokenizer, + resource_server_url=resource_server_url, + num_groups_per_step=num_groups_per_step, + task_name=task_name, + length=config["grpo"]["max_num_steps"] * num_groups_per_step, + ) + + # Validation dataset: same thing (could be a fixed set, but for discovery + # we just re-sample from the buffer) + val_dataset = DiscoverDataset( + tokenizer=tokenizer, + resource_server_url=resource_server_url, + num_groups_per_step=num_groups_per_step, + task_name=task_name, + length=num_groups_per_step, + ) + + # Create the environment as a Ray actor + env = ErdosDiscoveryEnvironment.options( + num_gpus=0, + max_restarts=-1, + max_task_retries=-1, + ).remote(config=env_config) + + task_to_env = {task_name: env} + val_task_to_env = {task_name: env} + + return train_dataset, val_dataset, task_to_env, val_task_to_env + + +# ═══════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════ + + +def main(): + import yaml + from pathlib import Path + + # Load config + config_path = sys.argv[1] if len(sys.argv) > 1 else str( + Path(__file__).parent / "configs" / "grpo_erdos_discover.yaml" + ) + + if config_path.startswith("--config="): + config_path = config_path.split("=", 1)[1] + elif config_path == "--config" and len(sys.argv) > 2: + config_path = sys.argv[2] + + print(f"Loading config from: {config_path}") + with open(config_path) as f: + config: MasterConfig = yaml.safe_load(f) + + # Initialize Ray + init_ray(config) + set_seed(config.get("seed", 42)) + + # Tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Generation config + configure_generation_config(config) + + # Setup data + environment + train_dataset, val_dataset, task_to_env, val_task_to_env = ( + setup_discover_data(config, tokenizer) + ) + + # Setup policy, generation backend, cluster, etc. + ( + cluster, + policy, + generation, + train_dataloader, + val_dataloader, + ) = setup( + config=config, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) + + # Run GRPO training + grpo_train( + master_config=config, + policy=policy, + generation=generation, + cluster=cluster, + wrapped_dataloader=train_dataloader, + val_wrapped_dataloader=val_dataloader, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() From c2f9e7529d540cf553e7ecca1aef7fc1fd232a03 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 22:35:08 +0000 Subject: [PATCH 04/48] =?UTF-8?q?feat:=20add=20debug=20config=20and=20SLUR?= =?UTF-8?q?M=20launcher=20for=20Erd=C5=91s=20TTT-Discover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - grpo_erdos_discover_debug.yaml: Qwen3-1.7B, single node, 4×8=32 trajectories, 5 steps, 4K seq len, colocated vLLM. For testing the full pipeline before scaling up. - erdos_debug.slurm: Starts Gym resource server in background on the same node, then runs NeMo RL GRPO training. --- erdos_debug.slurm | 56 +++++++++++++ .../configs/grpo_erdos_discover_debug.yaml | 83 +++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 erdos_debug.slurm create mode 100644 examples/configs/grpo_erdos_discover_debug.yaml diff --git a/erdos_debug.slurm b/erdos_debug.slurm new file mode 100644 index 0000000000..39adf3845e --- /dev/null +++ b/erdos_debug.slurm @@ -0,0 +1,56 @@ +#!/bin/bash +#SBATCH --job-name=erdos-debug +#SBATCH --output=logs/erdos-debug-%j.out +#SBATCH --error=logs/erdos-debug-%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-task=64 +#SBATCH --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 + +# TTT-Discover debug run — single node, Qwen3-1.7B, 5 training steps. +# Tests the full pipeline: PUCT → prompts → vLLM → sandbox → reward → entropic advantages → gradient + +set -euo pipefail +mkdir -p logs + +echo "Node: $(hostname)" +echo "GPUs: $(nvidia-smi -L | wc -l)" + +# ── Gym resource server (background) ───────────────────────────── +# Start the Erdős Gym server on this node (CPU-only, no GPU needed). +# It handles code execution, reward computation, and PUCT state management. +cd /home/mormio/Gym + +echo "Starting Gym Erdős resource server on port 8080..." +# TODO: adjust this to your Gym environment setup +# For now, run the FastAPI server directly: +PYTHONPATH=/home/mormio/Gym:$PYTHONPATH \ +nohup python -m uvicorn \ + resources_servers.erdos_discovery.app:app \ + --host 0.0.0.0 --port 8080 \ + > logs/erdos-gym-${SLURM_JOB_ID}.log 2>&1 & +GYM_PID=$! +echo "Gym server PID: $GYM_PID" + +# Wait for Gym server to be ready +for i in $(seq 1 30); do + if curl -s http://localhost:8080/docs > /dev/null 2>&1; then + echo "Gym server ready" + break + fi + sleep 2 +done + +# ── NeMo RL training ───────────────────────────────────────────── +cd /home/mormio/RL + +echo "Starting TTT-Discover GRPO debug training..." +uv run python examples/run_discover.py \ + examples/configs/grpo_erdos_discover_debug.yaml + +echo "Training complete" + +# Cleanup +kill $GYM_PID 2>/dev/null || true diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml new file mode 100644 index 0000000000..5f95b94915 --- /dev/null +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -0,0 +1,83 @@ +# TTT-Discover DEBUG config — small model, short runs, verify plumbing. +# +# Usage: +# # Terminal 1: Start Gym resource server +# cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" +# +# # Terminal 2: Run debug training +# cd ~/RL && uv run python examples/run_discover.py examples/configs/grpo_erdos_discover_debug.yaml + +seed: 42 + +grpo: + num_prompts_per_step: 4 # 4 states per step (small batch) + num_generations_per_prompt: 8 # 8 rollouts per state = 32 total + max_num_epochs: 1 + max_num_steps: 5 # Just 5 steps to verify pipeline + max_rollout_turns: 1 + remove_constant_reward_groups: true + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931471805599453 + +loss_fn: + kl_penalty_coef: 0.1 + ratio_clip: 0.2 + token_level_loss: false + importance_sampling: true + +policy: + model_name: "/home/shared/models/Qwen3-1.7B" + tokenizer: "/home/shared/models/Qwen3-1.7B" + max_total_sequence_length: 4096 + train_global_batch_size: 32 # 4 groups × 8 rollouts + train_micro_batch_size: 4 + + dtensor_cfg: + enabled: true + tensor_parallel_size: 1 + sequence_parallel: false + + lora_cfg: + enabled: true + rank: 16 + alpha: 1.0 + dropout: 0.0 + +optimizer: + name: adamw + lr: 1.0e-4 + betas: [0.9, 0.999] + weight_decay: 0.01 + scheduler: + name: cosine + warmup_steps: 1 + min_lr_ratio: 0.1 + +data: + shuffle: false + +env: + erdos_discovery: + resource_server_url: "http://localhost:8080" + num_initial_states: 8 + num_groups_per_step: 4 + sandbox_timeout: 60 # Shorter timeout for debug + request_timeout: 120 + +cluster: + gpus_per_node: 8 + num_nodes: 1 # Single node for debug + +generation: + backend: vllm + colocated: true # Colocated on same GPUs for debug + temperature: 1.0 + top_p: 1.0 + max_new_tokens: 2048 # Shorter for debug + +checkpointing: + enabled: false # No checkpointing for debug + +wandb: + enabled: false # No wandb for debug From a6c624607ad55c9bc439cd7aa5db734a8b03fc2e Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 22:46:55 +0000 Subject: [PATCH 05/48] feat: add inline mode for ErdosDiscoveryEnvironment, simplify debug launch - Environment supports resource_server_url="inline" which runs the code sandbox and reward computation directly in-process, no Gym server dependency needed. - Debug config updated to use inline mode. - SLURM script simplified: no Gym server startup, just uv run. --- erdos_debug.slurm | 38 +----- .../configs/grpo_erdos_discover_debug.yaml | 28 ++--- .../erdos_discovery_environment.py | 116 +++++++++++++++++- 3 files changed, 129 insertions(+), 53 deletions(-) diff --git a/erdos_debug.slurm b/erdos_debug.slurm index 39adf3845e..5248781469 100644 --- a/erdos_debug.slurm +++ b/erdos_debug.slurm @@ -9,48 +9,22 @@ #SBATCH --cpus-per-task=64 #SBATCH --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 -# TTT-Discover debug run — single node, Qwen3-1.7B, 5 training steps. -# Tests the full pipeline: PUCT → prompts → vLLM → sandbox → reward → entropic advantages → gradient +# TTT-Discover debug run — single node, Qwen3-1.7B, inline reward, 5 steps. +set -eo pipefail -set -euo pipefail mkdir -p logs echo "Node: $(hostname)" echo "GPUs: $(nvidia-smi -L | wc -l)" +echo "Job ID: $SLURM_JOB_ID" -# ── Gym resource server (background) ───────────────────────────── -# Start the Erdős Gym server on this node (CPU-only, no GPU needed). -# It handles code execution, reward computation, and PUCT state management. -cd /home/mormio/Gym +# No Gym server needed — using inline mode for debug. +# The environment computes rewards directly in-process. -echo "Starting Gym Erdős resource server on port 8080..." -# TODO: adjust this to your Gym environment setup -# For now, run the FastAPI server directly: -PYTHONPATH=/home/mormio/Gym:$PYTHONPATH \ -nohup python -m uvicorn \ - resources_servers.erdos_discovery.app:app \ - --host 0.0.0.0 --port 8080 \ - > logs/erdos-gym-${SLURM_JOB_ID}.log 2>&1 & -GYM_PID=$! -echo "Gym server PID: $GYM_PID" - -# Wait for Gym server to be ready -for i in $(seq 1 30); do - if curl -s http://localhost:8080/docs > /dev/null 2>&1; then - echo "Gym server ready" - break - fi - sleep 2 -done - -# ── NeMo RL training ───────────────────────────────────────────── cd /home/mormio/RL -echo "Starting TTT-Discover GRPO debug training..." +echo "Starting TTT-Discover GRPO debug training (inline mode)..." uv run python examples/run_discover.py \ examples/configs/grpo_erdos_discover_debug.yaml echo "Training complete" - -# Cleanup -kill $GYM_PID 2>/dev/null || true diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index 5f95b94915..e8473217ff 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -1,19 +1,15 @@ -# TTT-Discover DEBUG config — small model, short runs, verify plumbing. +# TTT-Discover DEBUG config — small model, short runs, inline reward (no Gym server). # # Usage: -# # Terminal 1: Start Gym resource server -# cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" -# -# # Terminal 2: Run debug training # cd ~/RL && uv run python examples/run_discover.py examples/configs/grpo_erdos_discover_debug.yaml seed: 42 grpo: - num_prompts_per_step: 4 # 4 states per step (small batch) - num_generations_per_prompt: 8 # 8 rollouts per state = 32 total + num_prompts_per_step: 4 + num_generations_per_prompt: 8 max_num_epochs: 1 - max_num_steps: 5 # Just 5 steps to verify pipeline + max_num_steps: 5 max_rollout_turns: 1 remove_constant_reward_groups: true adv_estimator: @@ -30,7 +26,7 @@ policy: model_name: "/home/shared/models/Qwen3-1.7B" tokenizer: "/home/shared/models/Qwen3-1.7B" max_total_sequence_length: 4096 - train_global_batch_size: 32 # 4 groups × 8 rollouts + train_global_batch_size: 32 train_micro_batch_size: 4 dtensor_cfg: @@ -59,25 +55,25 @@ data: env: erdos_discovery: - resource_server_url: "http://localhost:8080" + resource_server_url: "inline" # No Gym server — compute reward in-process num_initial_states: 8 num_groups_per_step: 4 - sandbox_timeout: 60 # Shorter timeout for debug + sandbox_timeout: 60 request_timeout: 120 cluster: gpus_per_node: 8 - num_nodes: 1 # Single node for debug + num_nodes: 1 generation: backend: vllm - colocated: true # Colocated on same GPUs for debug + colocated: true temperature: 1.0 top_p: 1.0 - max_new_tokens: 2048 # Shorter for debug + max_new_tokens: 2048 checkpointing: - enabled: false # No checkpointing for debug + enabled: false wandb: - enabled: false # No wandb for debug + enabled: false diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 37d82541ee..70c4710fca 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -25,6 +25,97 @@ logger = logging.getLogger(__name__) + +# ═══════════════════════════════════════════════════════════════════ +# Inline reward computation (no Gym server needed for debug/testing) +# ═══════════════════════════════════════════════════════════════════ + +def _inline_compute_reward(response_text: str, timeout: int = 60) -> dict: + """Compute reward directly in-process. No HTTP call needed.""" + import re + import signal + import builtins + import math as _math + import itertools as _itertools + import functools as _functools + import collections as _collections + + import numpy as _np + from numpy.fft import rfft, irfft + + _ALLOWED_MODULES = frozenset({ + "numpy", "np", "math", "cmath", "random", + "itertools", "functools", "collections", "fractions", "decimal", + }) + _SAFE_BUILTIN_NAMES = [ + "abs", "all", "any", "bool", "dict", "divmod", "enumerate", + "filter", "float", "format", "int", "isinstance", "issubclass", + "iter", "len", "list", "map", "max", "min", "next", "object", + "print", "range", "repr", "reversed", "round", "set", "slice", + "sorted", "str", "sum", "tuple", "type", "zip", + "Exception", "ValueError", "TypeError", "KeyError", "IndexError", + "StopIteration", "RuntimeError", "NotImplementedError", + "OverflowError", "ZeroDivisionError", "AttributeError", + ] + + # Extract code + code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) + blocks = code_re.findall(response_text) + code = blocks[-1].strip() if blocks else response_text.strip() + + # Build sandbox + import random as _random + safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES if hasattr(builtins, k)} + def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): + if name.split(".")[0] not in _ALLOWED_MODULES: + raise ImportError(f"Module '{name}' not allowed") + return builtins.__import__(name, globals, locals, fromlist, level) + safe_builtins["__import__"] = _safe_import + namespace = { + "__builtins__": safe_builtins, + "np": _np, "numpy": _np, "math": _math, "random": _random, + "itertools": _itertools, "functools": _functools, "collections": _collections, + } + + try: + class _Timeout(Exception): + pass + def _handler(s, f): + raise _Timeout() + old = signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout) + try: + exec(compile(code, "", "exec"), namespace) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old) + + if "f" not in namespace: + return {"reward": 0.0, "bound": None, "error_msg": "no variable f"} + + f = _np.asarray(namespace["f"], dtype=float).flatten() + + # Validate + if len(f) < 1 or len(f) > 1000: + return {"reward": 0.0, "bound": None, "error_msg": "bad length"} + if _np.any(~_np.isfinite(f)) or _np.any(f < 0) or _np.any(f > 1): + return {"reward": 0.0, "bound": None, "error_msg": "bad values"} + if abs(float(_np.mean(f)) - 0.5) > 1e-3: + return {"reward": 0.0, "bound": None, "error_msg": "bad mean"} + + # Compute bound + n = len(f) + F = rfft(f, n=2*n) + autocorr = irfft(F * _np.conj(F), n=2*n) + bound = float(2 * n * _np.max(autocorr.real) / (_np.sum(f)**2)) + + if bound <= 0 or not _math.isfinite(bound): + return {"reward": 0.0, "bound": None, "error_msg": "bad bound"} + return {"reward": 1.0 / bound, "bound": bound, "error_msg": ""} + + except Exception as e: + return {"reward": 0.0, "bound": None, "error_msg": str(e)[:200]} + # Type alias matching NeMo RL's convention LLMMessageLogType = list[dict[str, Any]] ErdosMetadata = dict[str, Any] @@ -63,6 +154,10 @@ def __init__(self, config: dict): self.total_verified = 0 self.total_valid = 0 self._session_initialized = False + self._inline_mode = (self.resource_server_url == "inline") + if self._inline_mode: + logger.info("ErdosDiscovery: running in INLINE mode (no Gym server)") + self._session_initialized = True # No server to init async def _ensure_session(self): """Initialize the PUCT buffer on the resource server if not done.""" @@ -96,11 +191,15 @@ async def _ensure_session(self): async def _verify_single( self, - session: aiohttp.ClientSession, + session: Optional[aiohttp.ClientSession], response_text: str, parent_state: Optional[list[float]] = None, ) -> dict: - """Call /verify on the resource server for one response.""" + """Call /verify on the resource server, or compute inline.""" + if self._inline_mode: + return _inline_compute_reward( + response_text, timeout=self.sandbox_timeout + ) # Build a minimal NeMoGymResponse-like payload # The resource server extracts output_text from response.output_text body = { @@ -163,10 +262,14 @@ async def _async_step( answers = [None] * batch_size updated_metadata = list(metadata) - timeout = aiohttp.ClientTimeout(total=self.request_timeout) - async with aiohttp.ClientSession(timeout=timeout) as session: - import asyncio + if self._inline_mode: + session = None + else: + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.request_timeout) + ) + try: tasks = [] for i, message_log in enumerate(message_log_batch): # Extract the last assistant message @@ -186,6 +289,9 @@ async def _async_step( ) results = await asyncio.gather(*tasks, return_exceptions=True) + finally: + if session is not None: + await session.close() for i, result in enumerate(results): if isinstance(result, Exception): From 05c7e95f17f19dd4cf5e858a9a2701699536e361 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 31 Mar 2026 23:59:44 +0000 Subject: [PATCH 06/48] fix: run_discover.py config loading, setup/grpo_train signatures, debug iterations - Use load_config + OmegaConf resolve instead of yaml.safe_load - Match setup() and grpo_train() calling convention from sliding_puzzle - Fix SLURM script: PATH for uv, PYTHONPATH unbound var - Debug config uses defaults: grpo_math_1B.yaml for proper field inheritance --- 3rdparty/Gym-workspace/Gym | 2 +- erdos_debug.slurm | 2 +- .../configs/grpo_erdos_discover_debug.yaml | 59 ++++++------ examples/run_discover.py | 90 +++++++++++-------- pyproject.toml | 2 +- uv.lock | 2 + 6 files changed, 89 insertions(+), 68 deletions(-) mode change 160000 => 120000 3rdparty/Gym-workspace/Gym diff --git a/3rdparty/Gym-workspace/Gym b/3rdparty/Gym-workspace/Gym deleted file mode 160000 index 23cdeb3807..0000000000 --- a/3rdparty/Gym-workspace/Gym +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 23cdeb38077d7b72a5fbae0927a2e1a74bfc15f7 diff --git a/3rdparty/Gym-workspace/Gym b/3rdparty/Gym-workspace/Gym new file mode 120000 index 0000000000..0d1f8dba9b --- /dev/null +++ b/3rdparty/Gym-workspace/Gym @@ -0,0 +1 @@ +/home/mormio/Gym \ No newline at end of file diff --git a/erdos_debug.slurm b/erdos_debug.slurm index 5248781469..439fcb1e7f 100644 --- a/erdos_debug.slurm +++ b/erdos_debug.slurm @@ -24,7 +24,7 @@ echo "Job ID: $SLURM_JOB_ID" cd /home/mormio/RL echo "Starting TTT-Discover GRPO debug training (inline mode)..." -uv run python examples/run_discover.py \ +PATH=$HOME/.local/bin:$PATH uv run python examples/run_discover.py --config \ examples/configs/grpo_erdos_discover_debug.yaml echo "Training complete" diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index e8473217ff..e34cd2d030 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -1,9 +1,7 @@ -# TTT-Discover DEBUG config — small model, short runs, inline reward (no Gym server). -# -# Usage: -# cd ~/RL && uv run python examples/run_discover.py examples/configs/grpo_erdos_discover_debug.yaml - -seed: 42 +# TTT-Discover DEBUG config. +# Inherits from grpo_math_1B.yaml for all defaults. +# Overrides: entropic advantages, inline reward, small batch, 5 steps. +defaults: "grpo_math_1B.yaml" grpo: num_prompts_per_step: 4 @@ -20,18 +18,17 @@ loss_fn: kl_penalty_coef: 0.1 ratio_clip: 0.2 token_level_loss: false - importance_sampling: true policy: - model_name: "/home/shared/models/Qwen3-1.7B" - tokenizer: "/home/shared/models/Qwen3-1.7B" + model_name: "Qwen/Qwen2.5-1.5B-Instruct" max_total_sequence_length: 4096 train_global_batch_size: 32 train_micro_batch_size: 4 dtensor_cfg: enabled: true - tensor_parallel_size: 1 + cpu_offload: true + activation_checkpointing: true sequence_parallel: false lora_cfg: @@ -40,40 +37,42 @@ policy: alpha: 1.0 dropout: 0.0 + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 1.0 + top_p: 1.0 + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + optimizer: name: adamw lr: 1.0e-4 - betas: [0.9, 0.999] - weight_decay: 0.01 - scheduler: - name: cosine - warmup_steps: 1 - min_lr_ratio: 0.1 data: shuffle: false env: erdos_discovery: - resource_server_url: "inline" # No Gym server — compute reward in-process + resource_server_url: "inline" num_initial_states: 8 num_groups_per_step: 4 sandbox_timeout: 60 request_timeout: 120 -cluster: - gpus_per_node: 8 - num_nodes: 1 - -generation: - backend: vllm - colocated: true - temperature: 1.0 - top_p: 1.0 - max_new_tokens: 2048 - checkpointing: enabled: false -wandb: - enabled: false +logger: + log_dir: "logs/erdos-debug" + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false diff --git a/examples/run_discover.py b/examples/run_discover.py index f0e19a5bd3..6b2d7bc80a 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -13,8 +13,11 @@ Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) """ +import itertools +import argparse import itertools import logging +import os import sys from typing import Optional @@ -33,6 +36,7 @@ ErdosDiscoveryEnvironment, ) from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, register_omegaconf_resolvers logger = logging.getLogger(__name__) @@ -265,63 +269,79 @@ def setup_discover_data(config: MasterConfig, tokenizer): def main(): - import yaml - from pathlib import Path - - # Load config - config_path = sys.argv[1] if len(sys.argv) > 1 else str( - Path(__file__).parent / "configs" / "grpo_erdos_discover.yaml" - ) - - if config_path.startswith("--config="): - config_path = config_path.split("=", 1)[1] - elif config_path == "--config" and len(sys.argv) > 2: - config_path = sys.argv[2] + import os + from omegaconf import OmegaConf + from nemo_rl.utils.config import load_config, register_omegaconf_resolvers + + register_omegaconf_resolvers() + + # Parse --config argument + config_path = None + for i, arg in enumerate(sys.argv[1:], 1): + if arg.startswith("--config="): + config_path = arg.split("=", 1)[1] + elif arg == "--config" and i < len(sys.argv) - 1: + config_path = sys.argv[i + 1] + elif not arg.startswith("--") and config_path is None: + config_path = arg + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "grpo_erdos_discover_debug.yaml" + ) print(f"Loading config from: {config_path}") - with open(config_path) as f: - config: MasterConfig = yaml.safe_load(f) + config = load_config(config_path) + + # Resolve OmegaConf interpolations (e.g. ${policy.model_name}) + oc = OmegaConf.create(config) + config = OmegaConf.to_container(oc, resolve=True) # Initialize Ray - init_ray(config) + init_ray() set_seed(config.get("seed", 42)) # Tokenizer tokenizer = get_tokenizer(config["policy"]["tokenizer"]) # Generation config - configure_generation_config(config) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) # Setup data + environment train_dataset, val_dataset, task_to_env, val_task_to_env = ( setup_discover_data(config, tokenizer) ) - # Setup policy, generation backend, cluster, etc. + # Setup policy, generation, cluster, dataloader, etc. ( - cluster, policy, - generation, - train_dataloader, + policy_generation, + clusters, + dataloader, val_dataloader, - ) = setup( - config=config, - tokenizer=tokenizer, - train_dataset=train_dataset, - val_dataset=val_dataset, - ) + loss_fn, + nemo_logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, train_dataset, val_dataset) # Run GRPO training grpo_train( - master_config=config, - policy=policy, - generation=generation, - cluster=cluster, - wrapped_dataloader=train_dataloader, - val_wrapped_dataloader=val_dataloader, - task_to_env=task_to_env, - val_task_to_env=val_task_to_env, - tokenizer=tokenizer, + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + nemo_logger, + checkpointer, + grpo_state, + master_config, ) diff --git a/pyproject.toml b/pyproject.toml index 17830c391c..22bc8a1ba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,11 @@ dependencies = [ "swanlab", "pyzmq", "decord2", - "nccl4py", # for non-colocated refit "cuda-bindings", # for non-colocated refit "pybase64", # for sglang refit "nvidia-cudnn-cu12==9.19.0.56", # for transformer-engine no build isolation + "sentencepiece>=0.2.1", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 27ddd02b2a..e110fa526d 100644 --- a/uv.lock +++ b/uv.lock @@ -4482,6 +4482,7 @@ dependencies = [ { name = "pyzmq" }, { name = "ray", extra = ["default"] }, { name = "rich" }, + { name = "sentencepiece" }, { name = "setuptools" }, { name = "swanlab" }, { name = "sympy" }, @@ -4628,6 +4629,7 @@ requires-dist = [ { name = "pyzmq" }, { name = "ray", extras = ["default"], specifier = "==2.54.0" }, { name = "rich" }, + { name = "sentencepiece", specifier = ">=0.2.1" }, { name = "setuptools" }, { name = "sgl-kernel", marker = "extra == 'sglang'", git = "https://github.com/JustinTong0323/sglang.git?subdirectory=sgl-kernel&rev=70aa688742dd2b75bf9e8e980249303f39295b0d" }, { name = "sglang", marker = "extra == 'sglang'", git = "https://github.com/JustinTong0323/sglang.git?subdirectory=python&rev=70aa688742dd2b75bf9e8e980249303f39295b0d" }, From b634fec2835d82029e75d93c1938c98c96e66801 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 05:36:31 +0000 Subject: [PATCH 07/48] fix: v0.5.0 container compatibility - handle missing register_omegaconf_resolvers, add launch scripts and shim for container deployment --- erdos_debug_container.slurm | 37 + examples/run_discover.py | 11 +- launch_erdos_debug.sh | 94 + message (3).md | 207 ++ ray.sub | 142 +- ray.sub.bak | 487 +++ shim/grpo_erdos_discover_debug.yaml | 78 + shim/nemo_rl/__init__.py | 0 shim/nemo_rl/algorithms/__init__.py | 0 .../entropic_advantage_estimator.py | 179 + shim/nemo_rl/algorithms/grpo.py | 3205 +++++++++++++++++ shim/nemo_rl/environments/__init__.py | 0 .../erdos_discovery_environment.py | 362 ++ shim/nemo_rl/environments/utils.py | 139 + shim/nemo_rl/utils/__init__.py | 0 shim/nemo_rl/utils/puct_buffer.py | 561 +++ shim/run_discover.py | 349 ++ 17 files changed, 5828 insertions(+), 23 deletions(-) create mode 100644 erdos_debug_container.slurm create mode 100755 launch_erdos_debug.sh create mode 100644 message (3).md create mode 100644 ray.sub.bak create mode 100644 shim/grpo_erdos_discover_debug.yaml create mode 100644 shim/nemo_rl/__init__.py create mode 100644 shim/nemo_rl/algorithms/__init__.py create mode 100644 shim/nemo_rl/algorithms/entropic_advantage_estimator.py create mode 100644 shim/nemo_rl/algorithms/grpo.py create mode 100644 shim/nemo_rl/environments/__init__.py create mode 100644 shim/nemo_rl/environments/erdos_discovery_environment.py create mode 100644 shim/nemo_rl/environments/utils.py create mode 100644 shim/nemo_rl/utils/__init__.py create mode 100644 shim/nemo_rl/utils/puct_buffer.py create mode 100644 shim/run_discover.py diff --git a/erdos_debug_container.slurm b/erdos_debug_container.slurm new file mode 100644 index 0000000000..f077d7b529 --- /dev/null +++ b/erdos_debug_container.slurm @@ -0,0 +1,37 @@ +#!/bin/bash +#SBATCH --job-name=erdos-debug +#SBATCH --output=logs/erdos-debug-%j.out +#SBATCH --error=logs/erdos-debug-%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=64 +#SBATCH --container-image=nvcr.io#nvidia/nemo-rl:v0.5.0 +#SBATCH --container-writable +#SBATCH --container-mounts=/home/mormio/RL:/home/mormio/RL,/home/shared/models:/home/shared/models +#SBATCH --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 + +# TTT-Discover debug — single node, inside NeMo RL container. +# Skips ray.sub complexity. Starts Ray inline and runs training. + +set -eo pipefail + +echo "Node: $(hostname)" +echo "GPUs: $(nvidia-smi -L | wc -l)" +echo "Python: $(python --version)" +echo "Job: $SLURM_JOB_ID" + +# NCCL config +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 +export NCCL_SOCKET_IFNAME=bond0 +export UCX_NET_DEVICES=bond0 +export HF_HUB_ENABLE_HF_TRANSFER=0 + +cd /home/mormio/RL + +# Run training directly — Ray will start automatically via init_ray() +uv run python examples/run_discover.py \ + --config examples/configs/grpo_erdos_discover_debug.yaml + +echo "Training complete" diff --git a/examples/run_discover.py b/examples/run_discover.py index 6b2d7bc80a..e391a8e8c0 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -36,7 +36,7 @@ ErdosDiscoveryEnvironment, ) from nemo_rl.models.generation import configure_generation_config -from nemo_rl.utils.config import load_config, register_omegaconf_resolvers +from nemo_rl.utils.config import load_config logger = logging.getLogger(__name__) @@ -271,9 +271,12 @@ def setup_discover_data(config: MasterConfig, tokenizer): def main(): import os from omegaconf import OmegaConf - from nemo_rl.utils.config import load_config, register_omegaconf_resolvers - - register_omegaconf_resolvers() + from nemo_rl.utils.config import load_config + try: + from nemo_rl.utils.config import register_omegaconf_resolvers + register_omegaconf_resolvers() + except ImportError: + pass # v0.5.0 container doesn't have this # Parse --config argument config_path = None diff --git a/launch_erdos_debug.sh b/launch_erdos_debug.sh new file mode 100755 index 0000000000..10e98e65e8 --- /dev/null +++ b/launch_erdos_debug.sh @@ -0,0 +1,94 @@ +#!/bin/bash +# TTT-Discover Erdős GRPO Debug — FINAL +# Only adds our new files to the container's /opt/nemo-rl, does NOT overwrite existing code +set -euo pipefail +cd /home/mormio/RL + +CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" +EXP="results/erdos-debug-$(date +%Y%m%d_%H%M)" +mkdir -p "$EXP" + +MOUNTS="$PWD:/home/mormio/RL,/home/shared/models:/home/shared/models" + +# Only add NEW files to the container's /opt/nemo-rl +# Do NOT overwrite existing files (they match the container's deps) +COMMAND=' +export HF_HUB_ENABLE_HF_TRANSFER=0 +export TORCH_CUDA_ARCH_LIST="9.0 10.0" +export NRL_IGNORE_VERSION_MISMATCH=1 + +SRC=/home/mormio/RL + +# Add only our NEW files (not overwriting anything) +cp $SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ +cp $SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ +cp $SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ +cp $SRC/examples/run_discover.py /opt/nemo-rl/examples/ +cp $SRC/examples/configs/grpo_erdos_discover_debug.yaml /opt/nemo-rl/examples/configs/ + +# Patch the container grpo.py to register our entropic estimator +# (append the elif branch to _create_advantage_estimator) +python -c " +path = \"/opt/nemo-rl/nemo_rl/algorithms/grpo.py\" +with open(path) as f: + content = f.read() +if \"entropic_adaptive_beta\" not in content: + old = \" else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator\" + new = \"\"\" elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") + else: + raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\") + + return adv_estimator\"\"\" + content = content.replace(old, new) + with open(path, \"w\") as f: + f.write(content) + print(\"Patched grpo.py with entropic_adaptive_beta\") +else: + print(\"grpo.py already patched\") +" + +# Patch environments/utils.py to register erdos_discovery +python -c " +path = \"/opt/nemo-rl/nemo_rl/environments/utils.py\" +with open(path) as f: + content = f.read() +if \"erdos_discovery\" not in content: + content = content.replace( + \"\\\"nemo_gym\\\": {\", + \"\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {\" + ) + with open(path, \"w\") as f: + f.write(content) + print(\"Patched utils.py with erdos_discovery\") +else: + print(\"utils.py already patched\") +" + +cd /opt/nemo-rl +python examples/run_discover.py \ + --config examples/configs/grpo_erdos_discover_debug.yaml +' + +echo "Submitting Erdős TTT-Discover debug..." +echo "Experiment dir: $EXP" + +COMMAND="$COMMAND" \ +CONTAINER="$CONTAINER" \ +MOUNTS="$MOUNTS" \ +GPUS_PER_NODE=8 \ +sbatch \ + --nodes=2 --partition=batch --exclusive \ + --job-name=erdos-debug --time=01:00:00 \ + --output="$EXP/slurm-%j.out" \ + --error="$EXP/slurm-%j.err" \ + --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ + ray.sub + +echo "Logs: $EXP/" diff --git a/message (3).md b/message (3).md new file mode 100644 index 0000000000..2e435dd893 --- /dev/null +++ b/message (3).md @@ -0,0 +1,207 @@ +# NeMo RL GRPO on Our Cluster — What It Actually Took + +## The Ask +Run the NVIDIA NeMo RL GRPO tutorial on our B200 cluster. + +## The Cluster +- 32 nodes, 8x B200 (183GB) per node +- Slurm + Pyxis/Enroot for containers +- Shared home directory (NFS), no /scratch +- InfiniBand networking (mlx5 HCAs, bond0) + +## What Worked Immediately +- Cloning the repo, downloading models from HuggingFace +- The `ray.sub` script for orchestrating Ray clusters via Slurm (with patches) + +## What Needed Fixing + +### Cluster-Specific Patches to ray.sub +Every run needed these two fixes to `ray.sub`: +```bash +# 1. MPI plugin: cluster has pmi2, not pmix +sed -i 's/--mpi=pmix/--mpi=pmi2/' ray.sub + +# 2. Container filesystem must be writable (ray writes launch scripts to /) +sed -i '/--no-container-mount-home/a COMMON_SRUN_ARGS+=" --container-writable"' ray.sub +``` + +Also needed to remove `--no-container-mount-home` so the shared home dir +(with model conversion cache) is visible across all nodes. + +### NGC Container Authentication +Pyxis/enroot needs NGC credentials to pull containers: +```bash +# ~/.config/enroot/.credentials +machine nvcr.io login $oauthtoken password +``` +The image URI format for Pyxis is `nvcr.io#nvidia/nemo-rl:v0.5.0` (note the `#`). + +### NCCL / InfiniBand Configuration +Multi-node training was getting `Network is unreachable` NCCL errors until +we added the cluster's IB config (copied from our existing torchtitan slurm scripts): +```bash +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 +export NCCL_SOCKET_IFNAME=bond0 +export UCX_NET_DEVICES=bond0 +export NCCL_BUFFSIZE=33554432 +export NCCL_IB_AR_THRESHOLD=0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +# ... etc +``` +These must be set in `ray.sub` (not just the COMMAND) so every node gets them. + +### HF Transfer +The v0.5.0 container sets `HF_HUB_ENABLE_HF_TRANSFER=1` but the package +isn't installed. Set `HF_HUB_ENABLE_HF_TRANSFER=0` in the launch command. + +## The Llama 8B Math Run (worked quickly) +**Container**: `nvcr.io/nvidia/nemo-rl:v0.5.0` +**Branch**: `v0.5.0` tag +**Config**: `examples/configs/grpo_math_8B_megatron.yaml` +**Script**: `examples/run_grpo_math.py` + +This used OpenMathInstruct-2 (auto-downloads), single node, colocated generation. +Worked after fixing the ray.sub patches above and enabling W&B +(`++logger.wandb_enabled=true`). The base config has `wandb_enabled: false`. + +## The Workplace Assistant Tutorial (abandoned) +The original tutorial targets Nemotron Nano 9B v2 with the Workplace Assistant +NeMo Gym environment. This was a nightmare: + +- **v0.5.0 container + v0.5.0 code**: Chat template tokenization assertion errors + (`non-monotonically increasing trajectory`). The Nemotron Nano v2 tokenizer + handles multi-turn tool-calling conversations in a way that breaks the + `_replace_prefix_tokens` function during multi-step rollouts. +- **nano-v3 branch + v0.4.0.nemotron_3_nano container**: The `nemotron_json` + tool parser wasn't registered in the container's vLLM. +- The tutorial's `sed` commands to patch the chat template are insufficient. + The real fix exists only on the `nano-v3` branch which removes the assertion entirely. + +**Lesson**: The Workplace Assistant environment is tightly coupled to specific +branch/container combos. Use any other environment instead. + +## The Nemotron 3 Super 120B Run (what finally worked) + +### Container Build +The `super-v3` branch requires a custom container build because it uses a +patched vLLM for the NemotronH MoE architecture: + +```bash +# On a compute node (docker access required): +docker buildx build \ + --build-context nemo-rl=. \ + --build-arg SKIP_SGLANG_BUILD=1 \ + --build-arg BUILD_CUSTOM_VLLM=1 \ + -f docker/Dockerfile \ + --tag nemo-rl-super:v3 --load . + +# Convert to sqsh for Pyxis: +sudo enroot import -o nemo-rl-super-v3.sqsh "dockerd://nemo-rl-super:v3" +``` + +We had to install `docker-buildx` first (not available on the cluster by default). + +### HF→Megatron Model Conversion +First run converts the HuggingFace checkpoint to Megatron format (~231GB). +This is cached at `~/.cache/huggingface/nemo_rl/model__/`. +**The home dir must be mounted in the container** for this cache to be shared +across nodes. Previous runs with `--no-container-mount-home` caused the +conversion to succeed on the head node but be invisible to training nodes. + +### Chat Template +The base model (`NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16`) has no chat +template. The `math_hf_data_processor` calls `tokenizer.apply_chat_template()`, +which crashes. We added a minimal one: + +```python +data["chat_template"] = "{% for message in messages %}..." +``` + +### Data +The internal NVIDIA data paths in the configs (`/lustre/fsw/...`) don't exist. +We downloaded DAPO-Math-17k from HuggingFace and converted it, but ultimately +used the built-in `OpenMathInstruct-2` dataset with `math_hf_data_processor` +and `env.math` (rule-based math verification, no LLM judge needed). + +The base config also sets `data.max_input_seq_length: null` which causes a +`TypeError: '>' not supported between instances of 'int' and 'NoneType'`. +Override with `++data.max_input_seq_length=4096`. + +### NeMo Gym +The base `grpo_superv3.yaml` config has `env.nemo_gym.num_gpu_nodes: 4` which +reserves 4 GPU nodes for Gym environment servers (genrm judges, etc.). With +only 6 total nodes this left negative nodes for training. Either: +- Set `++env.nemo_gym.num_gpu_nodes=0` if using simple environments +- Set `++env.should_use_nemo_gym=false` to skip Gym entirely + +The container's built-in `nemo_gym` package also had an `ImportError` +(`cannot import name 'PARENT_DIR'`) when the Gym submodule from our repo +checkout was mounted over the container's version. Don't mount the Gym dir. + +### Parallelism (the hard part) +The 120B MoE model needs careful parallelism to fit in memory and divide evenly: + +**What didn't work:** +- TP=4, CP=4, PP=2: PP=2 from base config made TP×CP×PP=32 > 16 training GPUs +- TP=4, CP=4, PP=1, EP=8, 13 nodes: `World size (40) not divisible by 32` +- TP=4, CP=1, PP=1, EP=2: OOM (105GB allocation, only 32GB free per GPU) +- TP=4, CP=4, PP=1, EP=8, 4 training nodes: Worked for generation+logprobs, + but CUDA illegal memory access during backward (cross-node NCCL before IB fix) +- TP=4, CP=1, PP=1, EP=8 (after IB fix): OOM during training backward pass + +**What worked:** +```yaml +# 6 nodes: 2 inference, 4 training (32 GPUs) +tensor_model_parallel_size: 4 # within-node +pipeline_model_parallel_size: 1 +context_parallel_size: 1 # no cross-node context parallel +expert_model_parallel_size: 8 # MoE experts sharded across all 32 GPUs +# TP×PP×CP×EP = 4×1×1×8 = 32 = world_size, DP=1 + +# Memory optimizations required: +activation_checkpointing: true +empty_unused_memory_level: 2 +optimizer_cpu_offload: true +max_total_sequence_length: 4096 # reduced from 16384 +train_micro_batch_size: 1 +logprob_batch_size: 1 +num_prompts_per_step: 16 # reduced from 128 +num_generations_per_prompt: 8 # reduced from 16 +train_global_batch_size: 128 # reduced from 2048 +``` + +### Final Working Command +```bash +cd /opt/nemo-rl && \ +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,... && \ +export NCCL_SOCKET_IFNAME=bond0 && \ +uv run python examples/run_grpo.py \ + --config=examples/configs/grpo_superv3.yaml \ + ++env.should_use_nemo_gym=false \ + ++data.train.dataset_name=OpenMathInstruct-2 \ + ++data.default.processor=math_hf_data_processor \ + ++data.default.env_name=math \ + ++env.math.num_workers=8 \ + ++env.math.math_verify_impl=hf_math_verify \ + ++policy.model_name=/path/to/model \ + ++cluster.num_nodes=6 \ + ++policy.generation.colocated.enabled=false \ + ++policy.generation.colocated.resources.num_nodes=2 \ + ++policy.megatron_cfg.tensor_model_parallel_size=4 \ + ++policy.megatron_cfg.pipeline_model_parallel_size=1 \ + ++policy.megatron_cfg.context_parallel_size=1 \ + ++policy.megatron_cfg.expert_model_parallel_size=8 \ + ++policy.megatron_cfg.activation_checkpointing=true \ + ++policy.megatron_cfg.optimizer.optimizer_cpu_offload=true \ + ++policy.max_total_sequence_length=4096 \ + # ... etc +``` + +~120 seconds per step on 6x B200 nodes (48 GPUs). + +## TL;DR +1. Fix `ray.sub` for your cluster (MPI plugin, container writable, home mount) +2. Set NCCL IB env vars for multi-node +3. Don't use the Workplace Assistant tutorial — use math environments instead +4. For Nemotron Super: build container from `super-v3` branch, use EP=8 to shard MoE experts, offload optimizer to CPU, reduce batch sizes +5. The model conversion cache needs to be on a shared filesystem visible to all nodes \ No newline at end of file diff --git a/ray.sub b/ray.sub index e6e3e07af7..3a9b367599 100644 --- a/ray.sub +++ b/ray.sub @@ -32,7 +32,8 @@ maybe_gres_arg() { # Assumes a homogeneous allocation (not a heterogeneous job) if sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep -q "gpu:"; then # Do a quick assert here that gpus:8 == gpus:$GPUS_PER_NODE. It is probably a user error if someone isn't using GPUS_PER_NODE=8 on our clusters if it supports --gres=gpu:8 or gpu:a100:8 - if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | awk -F: '{print $NF}') ]]; then + # Note: cut -d'(' -f1 removes the socket spec like "(S:0-3)" that some clusters append + if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | cut -d'(' -f1 | awk -F: '{print $NF}') ]]; then echo "Error: GPUS_PER_NODE=$GPUS_PER_NODE but GRES detected is $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:") meaning GPUS_PER_NODE is not set to fully claim the GPUs on the nodes." >&2 exit 1 fi @@ -47,9 +48,11 @@ maybe_gres_arg() { ######################################################## # User defined variables ######################################################## +export OMP_NUM_THREADS=16 CONTAINER=$CONTAINER MOUNTS=$MOUNTS COMMAND=${COMMAND:-} # This is a script relative to the SLURM_SUBMIT_DIR. If left empty, it will leave the cluster idle after it's brought up. +SETUP_COMMAND=${SETUP_COMMAND:-} # Setup commands to run on all nodes before starting Ray ######################################################## # Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} @@ -59,8 +62,9 @@ DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} # Ports for the head node -PORT=${PORT:-54514} -RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} +# GCS head port and client server port -- kept below ephemeral range (32768-60999). +PORT=${PORT:-9900} +RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-9901} #REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} @@ -84,16 +88,36 @@ elif [[ $(ulimit -Hn) != "unlimited" ]] && [[ $(ulimit -Hn) -lt 65535 ]]; then echo "[WARNING]: Cannot increase ulimit on file descriptors to 65535 according ray recommendation: https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html. Speak to cluster admins to increase, otherwise ray may crash unexpectedly." fi -# On our clusters, the largest port range on an idle worker appeared between 52369-64607 -# (not including the other ports set by this script). So this range is chosen to be -# somewhere in the middle -MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} -MAX_WORKER_PORT=${MAX_WORKER_PORT:-54513} +# Worker port range must NOT overlap with the OS ephemeral range (32768-60999) +# to prevent TOCTOU collisions. Ray's Raylet uses a probe-and-release pattern +# (CheckPortFree) that can leave a window where the port is unguarded. +# If the range overlaps with the ephemeral range, the kernel can assign the same +# port as an ephemeral source port for outgoing TCP traffic during that window, +# causing EADDRINUSE when the worker tries to bind. +# +# Port layout (all below ephemeral range): +# 10002-11000 Ray worker gRPC (this setting) +# 11001-15000 NeMo RL HTTP servers / TCPStore (policy.generation.port_range_low/high) +# 15001-20000 NeMo Gym HTTP servers (Gym global config port_range_low/high) +MIN_WORKER_PORT=${MIN_WORKER_PORT:-10002} +MAX_WORKER_PORT=${MAX_WORKER_PORT:-11000} ######################################################## # Number seconds to sync logs from /tmp/ray/session_*/logs to $LOG_DIR/ray/ RAY_LOG_SYNC_FREQUENCY=${RAY_LOG_SYNC_FREQUENCY:-} ######################################################## +# NCCL IB config for multi-node training +export NCCL_BUFFSIZE=33554432 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_IB_AR_THRESHOLD=0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_IB_QPS_PER_CONNECTION=2 +export NCCL_IB_SPLIT_DATA_ON_QPS=0 +export NCCL_IGNORE_CPU_AFFINITY=1 +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 +export NCCL_SOCKET_IFNAME=bond0 +export UCX_NET_DEVICES=bond0 + # Unset UV_CACHE_DIR to avoid local cache directory interferring with the container cache unset UV_CACHE_DIR @@ -111,6 +135,14 @@ BASE_LOG_DIR=${BASE_LOG_DIR:-$SLURM_SUBMIT_DIR} LOG_DIR="$BASE_LOG_DIR/$SLURM_JOB_ID-logs" mkdir -p $LOG_DIR +# Write SETUP_COMMAND to a file to avoid heredoc escaping issues +SETUP_COMMAND_FILE="" +if [[ -n "$SETUP_COMMAND" ]]; then + SETUP_COMMAND_FILE="$LOG_DIR/setup_command.sh" + echo "$SETUP_COMMAND" > "$SETUP_COMMAND_FILE" + chmod +x "$SETUP_COMMAND_FILE" +fi + # Number of GPUs per worker node GPUS_PER_NODE=${GPUS_PER_NODE:-8} @@ -123,8 +155,9 @@ else fi COMMON_SRUN_ARGS="$GRES_ARG" -COMMON_SRUN_ARGS+=" --no-container-mount-home" -COMMON_SRUN_ARGS+=" --mpi=pmix" +# COMMON_SRUN_ARGS+=" --no-container-mount-home" # Disabled: need shared home for model conversion cache +COMMON_SRUN_ARGS+=" --container-writable" +COMMON_SRUN_ARGS+=" --mpi=pmi2" COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS" COMMON_SRUN_ARGS+=" --container-image=$CONTAINER" COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" @@ -132,7 +165,7 @@ COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION" COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT" # Number of CPUs per worker node -CPUS_PER_WORKER=${CPUS_PER_WORKER:-$((GPUS_PER_NODE * 16))} +CPUS_PER_WORKER=${CPUS_PER_WORKER:-112} num_retries=3 @@ -295,6 +328,10 @@ chmod +x /launch-head.sh count=0 while [[ \$count -lt $num_retries ]]; do + if [[ -n "$SETUP_COMMAND_FILE" ]] && [[ -f "$SETUP_COMMAND_FILE" ]]; then + echo "[INFO] Running setup command from $SETUP_COMMAND_FILE..." + bash "$SETUP_COMMAND_FILE" + fi bash /launch-head.sh count=\$((count+1)) echo "Head node failed \$count/$num_retries times, restarting in 5 seconds..." @@ -304,9 +341,10 @@ touch $LOG_DIR/ENDED exit 1 EOF ) -srun $COMMON_SRUN_ARGS --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" & +srun $COMMON_SRUN_ARGS --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" & SRUN_PIDS["ray-head"]=$! +sleep 120s NUM_ACTORS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES)) # Start Ray worker nodes @@ -394,6 +432,10 @@ EOFINNER count=0 while [[ \$count -lt $num_retries ]]; do + if [[ -n "$SETUP_COMMAND_FILE" ]] && [[ -f "$SETUP_COMMAND_FILE" ]]; then + echo "[INFO] Running setup command from $SETUP_COMMAND_FILE..." + bash "$SETUP_COMMAND_FILE" + fi bash /launch-worker.sh count=\$((count+1)) echo "Worker failed \$count/$num_retries times, restarting in 5 seconds..." @@ -403,7 +445,7 @@ touch $LOG_DIR/ENDED exit 1 EOF ) - srun $COMMON_SRUN_ARGS --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/ray-worker-$i.log bash -x -c "$worker_cmd" & + srun $COMMON_SRUN_ARGS --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/ray-worker-$i.log bash -x -c "$worker_cmd" & SRUN_PIDS["ray-worker-$i"]=$! sleep 3 done @@ -414,11 +456,60 @@ while check_srun_processes && ! srun --overlap --nodes=1 --ntasks=1 -w $head_nod sleep 2 done +######################################################## +# Run sandbox in parallel on the head node (overlap) +######################################################## +export SLURM_MASTER_NODE=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1) + +# Check if SANDBOX_CONTAINER and SANDBOX_COMMAND are defined +if [[ -n "${SANDBOX_CONTAINER:-}" ]] && [[ -n "${SANDBOX_COMMAND:-}" ]]; then + SANDBOX_PORTS_DIR="$LOG_DIR/sandbox" + mkdir -p "$SANDBOX_PORTS_DIR" + echo "[INFO] Starting sandbox on head node in parallel (ports_dir=$SANDBOX_PORTS_DIR)..." + srun --output "$LOG_DIR/sandbox.log" \ + --error "$LOG_DIR/sandbox.log" \ + --container-image="$SANDBOX_CONTAINER" \ + --container-mounts="$SANDBOX_PORTS_DIR:$SANDBOX_PORTS_DIR" \ + --no-container-mount-home \ + --mpi=pmi2 \ + -A "$SLURM_JOB_ACCOUNT" \ + -p "$SLURM_JOB_PARTITION" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + --overlap \ + --nodes="$SLURM_JOB_NUM_NODES" \ + --ntasks-per-node=1 \ + -w "$head_node" \ + --export=ALL,SANDBOX_PORTS_DIR=$SANDBOX_PORTS_DIR \ + bash -c "$SANDBOX_COMMAND" & + SRUN_PIDS["sandbox"]=$! + echo "[INFO] Sandbox started in background (PID: ${SRUN_PIDS["sandbox"]})" +else + echo "[INFO] SANDBOX_CONTAINER or SANDBOX_COMMAND not defined, skipping sandbox startup" +fi + # At this stage the Ray cluster bringup has started on the physical nodes in the allocation # Before we launch a job on this cluster we need to make sure that the bringup is complete # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS + +# Helper function to get container PID using enroot +# Since we use -w to target specific nodes and only run one container per node, +# we can just grab the first (and only) container PID on that node +get_container_pid() { + local node=$1 + srun --overlap --nodes=1 -w "$node" bash -c "enroot list -f | awk 'NR>1 && \$2 ~ /^[0-9]+\$/ {print \$2; exit}'" +} + extract_worker_units() { - status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status) + # Get the container PID for ray-head + head_container_pid=$(get_container_pid "$head_node") + if [[ -z "$head_container_pid" ]]; then + echo 0 + return + fi + + # Execute ray status inside the container using enroot exec + status_output=$(srun --overlap --nodes=1 -w "$head_node" enroot exec "$head_container_pid" ray status) if echo "$status_output" | grep -q "worker_units"; then worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. ]' '{print $4}') echo $worker_units @@ -448,20 +539,30 @@ echo "All workers connected!" # This driver process is responsible for launching a job on the Ray cluster CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID | grep -oP 'WorkDir=\K[^ ]+' | head -1) if [[ -n "$COMMAND" ]]; then - srun --no-container-mount-home --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND" + # Get container PID and execute command inside it + head_container_pid=$(get_container_pid "$head_node") + srun --overlap --nodes=1 -w "$head_node" -o $LOG_DIR/ray-driver.log enroot exec "$head_container_pid" bash -c "cd $CONTAINER_CWD && $COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # No args launches on the head node (node 0) # Args 1-N launch on worker nodes (nodes 1 through N-1) # Optional: set COMMAND='...' to run non-interactively instead of opening an interactive shell + +# Helper to get container PID +get_container_pid() { + local node=\$1 + srun --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID bash -c "enroot list -f | awk 'NR>1 && \\\$2 ~ /^[0-9]+\\\$/ {print \\\$2; exit}'" +} + WORKER_NUM=\${1:-} if [[ -z "\$WORKER_NUM" ]]; then # Empty means we are on the head node + HEAD_PID=\$(get_container_pid "$head_node") if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "$head_node" --jobid $SLURM_JOB_ID enroot exec "\$HEAD_PID" bash -c "cd $CONTAINER_CWD && \$COMMAND" else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty enroot exec "\$HEAD_PID" bash -c "cd $CONTAINER_CWD && exec bash" fi else # Worker numbers 1 through N-1 correspond to ray-worker-1 through ray-worker-(N-1) @@ -471,10 +572,12 @@ else exit 1 fi nodes_array=($nodes) + node="\${nodes_array[\$WORKER_NUM]}" + WORKER_PID=\$(get_container_pid "\$node") if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID enroot exec "\$WORKER_PID" bash -c "cd $CONTAINER_CWD && \$COMMAND" else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash + srun $GRES_ARG -A SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID --pty enroot exec "\$WORKER_PID" bash -c "cd $CONTAINER_CWD && exec bash" fi fi EOF @@ -485,3 +588,4 @@ EOF echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 2 # to attach to worker 2, etc." sleep infinity fi + diff --git a/ray.sub.bak b/ray.sub.bak new file mode 100644 index 0000000000..e6e3e07af7 --- /dev/null +++ b/ray.sub.bak @@ -0,0 +1,487 @@ +#!/bin/bash +#SBATCH --nodes=2 +#SBATCH --exclusive +#SBATCH --account=ACCOUNT +#SBATCH --job-name=JOB_NAME +#SBATCH --partition=PARTITION +#SBATCH --time=1:0:0 +#SBATCH --dependency=singleton + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -eoux pipefail + +######################################################## +# Function to detect if SLURM cluster uses GRES +######################################################## +maybe_gres_arg() { + # Check if any nodes in the partition have GRES configured + # Assumes a homogeneous allocation (not a heterogeneous job) + if sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep -q "gpu:"; then + # Do a quick assert here that gpus:8 == gpus:$GPUS_PER_NODE. It is probably a user error if someone isn't using GPUS_PER_NODE=8 on our clusters if it supports --gres=gpu:8 or gpu:a100:8 + if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | awk -F: '{print $NF}') ]]; then + echo "Error: GPUS_PER_NODE=$GPUS_PER_NODE but GRES detected is $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:") meaning GPUS_PER_NODE is not set to fully claim the GPUs on the nodes." >&2 + exit 1 + fi + echo "--gres=gpu:${GPUS_PER_NODE}" + return + fi + + # No GRES support detected + echo "" +} + +######################################################## +# User defined variables +######################################################## +CONTAINER=$CONTAINER +MOUNTS=$MOUNTS +COMMAND=${COMMAND:-} # This is a script relative to the SLURM_SUBMIT_DIR. If left empty, it will leave the cluster idle after it's brought up. +######################################################## +# Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports +NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} +OBJECT_MANAGER_PORT=${OBJECT_MANAGER_PORT:-53003} +RUNTIME_ENV_AGENT_PORT=${RUNTIME_ENV_AGENT_PORT:-53005} +DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} +METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} + +# Ports for the head node +PORT=${PORT:-54514} +RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} +#REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? +DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger +DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} +RAY_DEBUGGER_ARGS= +if [ "${RAY_DEBUG:-}" = "legacy" ]; then + RAY_DEBUGGER_ARGS="--ray-debugger-external" +fi + +# After ray>=2.47, this feature is enabled by default which creates uv venvs for any py_executable starting with `uv run`. +# There is severe contention and performance issues with this enabled considering our dependencies are so large and occasionally +# need to be compiled, so NeMo RL has an implementation in nemo_rl/utils/venv.py that does it once per node as opposed to once per task. +export RAY_ENABLE_UV_RUN_RUNTIME_ENV=0 + +# Setting ulimit is recommended by ray best practices page +# @ https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html +# It's session based and won't affect the system outside the script +# Ensure that the soft limit isn't above the hard limit +if [[ $(ulimit -Hn) == "unlimited" ]] || [[ 65535 -lt $(ulimit -Hn) ]]; then + ulimit -Sn 65535 +elif [[ $(ulimit -Hn) != "unlimited" ]] && [[ $(ulimit -Hn) -lt 65535 ]]; then + echo "[WARNING]: Cannot increase ulimit on file descriptors to 65535 according ray recommendation: https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html. Speak to cluster admins to increase, otherwise ray may crash unexpectedly." +fi + +# On our clusters, the largest port range on an idle worker appeared between 52369-64607 +# (not including the other ports set by this script). So this range is chosen to be +# somewhere in the middle +MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} +MAX_WORKER_PORT=${MAX_WORKER_PORT:-54513} +######################################################## +# Number seconds to sync logs from /tmp/ray/session_*/logs to $LOG_DIR/ray/ +RAY_LOG_SYNC_FREQUENCY=${RAY_LOG_SYNC_FREQUENCY:-} +######################################################## + +# Unset UV_CACHE_DIR to avoid local cache directory interferring with the container cache +unset UV_CACHE_DIR + +if [[ -n "${UV_CACHE_DIR_OVERRIDE:-}" ]]; then + mkdir -p "$UV_CACHE_DIR_OVERRIDE" + if [[ -n $MOUNTS ]]; then + MOUNTS+=",$UV_CACHE_DIR_OVERRIDE:/root/.cache/uv" + else + MOUNTS="$UV_CACHE_DIR_OVERRIDE:/root/.cache/uv" + fi +fi + +# Create logs directory +BASE_LOG_DIR=${BASE_LOG_DIR:-$SLURM_SUBMIT_DIR} +LOG_DIR="$BASE_LOG_DIR/$SLURM_JOB_ID-logs" +mkdir -p $LOG_DIR + +# Number of GPUs per worker node +GPUS_PER_NODE=${GPUS_PER_NODE:-8} + +# Detect GRES support and set GRES_ARG +GRES_ARG=$(maybe_gres_arg) +if [[ -n "$GRES_ARG" ]]; then + echo "[INFO] GRES support detected. Using: $GRES_ARG" +else + echo "[INFO] No GRES support detected. Running without --gres flag." +fi + +COMMON_SRUN_ARGS="$GRES_ARG" +COMMON_SRUN_ARGS+=" --no-container-mount-home" +COMMON_SRUN_ARGS+=" --mpi=pmix" +COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS" +COMMON_SRUN_ARGS+=" --container-image=$CONTAINER" +COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" +# TODO: delete these (just for debugging) +COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION" +COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT" +# Number of CPUs per worker node +CPUS_PER_WORKER=${CPUS_PER_WORKER:-$((GPUS_PER_NODE * 16))} + +num_retries=3 + +# Track backgrounded srun client PIDs for head and workers +declare -A SRUN_PIDS + +# Verify all backgrounded srun client processes are still alive; exit fast if any died +check_srun_processes() { + for name in "${!SRUN_PIDS[@]}"; do + pid="${SRUN_PIDS[$name]}" + # Check if the process is still running + if ! kill -0 "$pid" 2>/dev/null; then + echo "[ERROR] Background srun '$name' died (pid=$pid). Could be a failure in startup or an issue with the node preventing the srun to start. Attempting to exit." >&2 + # Signal sidecars inside containers to terminate ASAP + touch "$LOG_DIR/ENDED" + exit 1 + fi + done +} + +# Getting the node names and IP addresses in the SLURM allocation +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +ip_addresses_array=() + +for node in $nodes; do + # Try multiple methods to get IP address - ENHANCED VERSION v2.0 + echo "[DEBUG] Resolving hostname: $node using enhanced resolution methods" + ip_address="" + + # Method 1: Try host command + echo "[DEBUG] Method 1: host command" + ip_address=$(host $node 2>/dev/null | awk '/has address/ { print $4 }' | head -1 || true) + echo "[DEBUG] host result: '$ip_address'" + + # Method 2: If host fails, try getent + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 2: getent hosts" + ip_address=$(getent hosts $node 2>/dev/null | awk '{ print $1 }' | head -1 || true) + echo "[DEBUG] getent result: '$ip_address'" + fi + + # Method 3: If getent fails, try nslookup + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 3: nslookup" + ip_address=$(nslookup $node 2>/dev/null | awk '/^Address: / { print $2 }' | head -1 || true) + echo "[DEBUG] nslookup result: '$ip_address'" + fi + + # Method 4: If all DNS methods fail, try ping to extract IP + if [[ -z "$ip_address" ]]; then + echo "[DEBUG] Method 4: ping" + ip_address=$(ping -c 1 $node 2>/dev/null | grep "PING" | sed 's/.*(\([^)]*\)).*/\1/' || true) + echo "[DEBUG] ping result: '$ip_address'" + fi + + # If still no IP, use the hostname itself (might work if it's already an IP or resolvable) + if [[ -z "$ip_address" ]]; then + echo "[WARNING] Could not resolve IP for $node, using hostname as fallback" + ip_address=$node + fi + + echo "[INFO] Node: $node -> IP: $ip_address" + # Add the IP address to the array + ip_addresses_array+=("$ip_address") +done + +head_node=${nodes_array[0]} +head_node_ip=${ip_addresses_array[0]} + +ip_head=$head_node_ip:$PORT + +# First we start the head of the ray cluster on one of the physical nodes +# Give the head node actual resources to make it schedulable + +head_cmd=$(cat < /dev/null 2>&1; then + for session_dir in /tmp/ray/session_[0-9]*/; do + if [[ -d "\$session_dir/logs" ]]; then + session_name=\$(basename "\$session_dir") + mkdir -p "$LOG_DIR/ray/\$session_name" + if command -v rsync > /dev/null 2>&1; then + rsync -ahP "\$session_dir/logs/" "$LOG_DIR/ray/\$session_name/logs/" 2>/dev/null || true + else + cp -r "\$session_dir/logs" "$LOG_DIR/ray/\$session_name/" + fi + fi + done + fi + if [[ -f "$LOG_DIR/ENDED" ]]; then + echo "Log sync sidecar terminating..." + break + fi + done +} +log-sync-sidecar & + +# Patch nsight.py before starting Ray head +sed -i 's/context\.py_executable = " "\.join(self\.nsight_cmd) + " python"/context.py_executable = " ".join(self.nsight_cmd) + f" {context.py_executable}"/g' /opt/nemo_rl_venv/lib64/python*/site-packages/ray/_private/runtime_env/nsight.py + +cat < /dev/null 2>&1; then + for session_dir in /tmp/ray/session_[0-9]*/; do + if [[ -d "\$session_dir/logs" ]]; then + session_name=\$(basename "\$session_dir") + mkdir -p "$LOG_DIR/ray/$node_i/\$session_name" + if command -v rsync > /dev/null 2>&1; then + rsync -ahP "\$session_dir/logs/" $LOG_DIR/ray/$node_i/\$session_name/logs/ 2>/dev/null || true + else + cp -r "\$session_dir/logs" $LOG_DIR/ray/$node_i/\$session_name/ + fi + fi + done + fi + if [[ -f "$LOG_DIR/ENDED" ]]; then + echo "Log sync sidecar terminating..." + break + fi + done +} +log-sync-sidecar & + +# Patch nsight.py before starting Ray worker +sed -i 's/context\.py_executable = " "\.join(self\.nsight_cmd) + " python"/context.py_executable = " ".join(self.nsight_cmd) + f" {context.py_executable}"/g' /opt/nemo_rl_venv/lib64/python*/site-packages/ray/_private/runtime_env/nsight.py + +cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh +# No args launches on the head node (node 0) +# Args 1-N launch on worker nodes (nodes 1 through N-1) +# Optional: set COMMAND='...' to run non-interactively instead of opening an interactive shell +WORKER_NUM=\${1:-} +if [[ -z "\$WORKER_NUM" ]]; then + # Empty means we are on the head node + if [[ -n "\${COMMAND:-}" ]]; then + srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + else + srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash + fi +else + # Worker numbers 1 through N-1 correspond to ray-worker-1 through ray-worker-(N-1) + # and use nodes_array[1] through nodes_array[N-1] + if [[ \$WORKER_NUM -lt 1 || \$WORKER_NUM -ge $SLURM_JOB_NUM_NODES ]]; then + echo "Error: WORKER_NUM must be between 1 and $((SLURM_JOB_NUM_NODES-1))" + exit 1 + fi + nodes_array=($nodes) + if [[ -n "\${COMMAND:-}" ]]; then + srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + else + srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash + fi +fi +EOF + chmod +x $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh + echo " COMMAND='echo hello' bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # run a non-interactive command on head node" + echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # to attach to head node (i.e., 'worker 0')" + echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 1 # to attach to worker 1" + echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 2 # to attach to worker 2, etc." + sleep infinity +fi diff --git a/shim/grpo_erdos_discover_debug.yaml b/shim/grpo_erdos_discover_debug.yaml new file mode 100644 index 0000000000..e34cd2d030 --- /dev/null +++ b/shim/grpo_erdos_discover_debug.yaml @@ -0,0 +1,78 @@ +# TTT-Discover DEBUG config. +# Inherits from grpo_math_1B.yaml for all defaults. +# Overrides: entropic advantages, inline reward, small batch, 5 steps. +defaults: "grpo_math_1B.yaml" + +grpo: + num_prompts_per_step: 4 + num_generations_per_prompt: 8 + max_num_epochs: 1 + max_num_steps: 5 + max_rollout_turns: 1 + remove_constant_reward_groups: true + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931471805599453 + +loss_fn: + kl_penalty_coef: 0.1 + ratio_clip: 0.2 + token_level_loss: false + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + max_total_sequence_length: 4096 + train_global_batch_size: 32 + train_micro_batch_size: 4 + + dtensor_cfg: + enabled: true + cpu_offload: true + activation_checkpointing: true + sequence_parallel: false + + lora_cfg: + enabled: true + rank: 16 + alpha: 1.0 + dropout: 0.0 + + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 1.0 + top_p: 1.0 + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + +optimizer: + name: adamw + lr: 1.0e-4 + +data: + shuffle: false + +env: + erdos_discovery: + resource_server_url: "inline" + num_initial_states: 8 + num_groups_per_step: 4 + sandbox_timeout: 60 + request_timeout: 120 + +checkpointing: + enabled: false + +logger: + log_dir: "logs/erdos-debug" + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false diff --git a/shim/nemo_rl/__init__.py b/shim/nemo_rl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/shim/nemo_rl/algorithms/__init__.py b/shim/nemo_rl/algorithms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/shim/nemo_rl/algorithms/entropic_advantage_estimator.py b/shim/nemo_rl/algorithms/entropic_advantage_estimator.py new file mode 100644 index 0000000000..af78ae5637 --- /dev/null +++ b/shim/nemo_rl/algorithms/entropic_advantage_estimator.py @@ -0,0 +1,179 @@ +"""Entropic Adaptive-Beta Advantage Estimator for TTT-Discover. + +Implements the Leave-One-Out (LOO) entropic advantage from +"Learning to Discover at Test Time" (arXiv:2601.16175). + +Instead of standard group-relative advantages (Adv = R - mean(R)), +this estimator: + 1. Solves for β such that KL(softmax_β(R) || uniform) = γ (default γ = ln(2)) + 2. Computes LOO advantages: w_i = exp(β·r_i) / Z_{-i} - 1 + where Z_{-i} is the normalizer excluding the i-th sample. + +Properties: + - Shift-invariant, approximately scale-invariant + - Monotone in reward + - Approximately mean-zero + - Adaptive scaling via β solves the reward-scale sensitivity of standard GRPO +""" + +import math +from typing import Optional + +import torch + + +def _solve_beta( + rewards: torch.Tensor, + gamma: float = math.log(2), + max_iter: int = 50, + tol: float = 1e-6, +) -> float: + """Solve for β such that KL(softmax_β(R) || uniform) = γ via bisection. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL divergence. Default ln(2) as in the paper. + max_iter: Maximum bisection iterations. + tol: Convergence tolerance on β. + + Returns: + Scalar β value. + """ + K = rewards.shape[0] + if K <= 1: + return 0.0 + + log_K = math.log(K) + r = rewards.double() + r_max = r.max() + + def kl_at_beta(b: float) -> float: + logits = b * (r - r_max) + log_Z = torch.logsumexp(logits, dim=0) + logq = logits - log_Z + q = logq.exp() + kl = (q * (logq + log_K)).sum().item() + return kl + + # Bisect: KL is monotonically increasing in |β| for non-constant rewards + # Find upper bound for β + lo, hi = 0.0, 1.0 + while kl_at_beta(hi) < gamma and hi < 1e8: + hi *= 2.0 + + # Edge case: all rewards identical → β = 0, KL = 0 for any β + if hi >= 1e8: + return 0.0 + + for _ in range(max_iter): + mid = (lo + hi) / 2.0 + kl = kl_at_beta(mid) + if abs(kl - gamma) < tol: + return mid + if kl < gamma: + lo = mid + else: + hi = mid + + return (lo + hi) / 2.0 + + +def compute_entropic_advantages( + rewards: torch.Tensor, + gamma: float = math.log(2), + eps: float = 1e-8, +) -> torch.Tensor: + """Compute LOO entropic advantages for a group of rewards. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL for adaptive β. + eps: Small constant for numerical stability. + + Returns: + [K] tensor of advantages. + """ + K = rewards.shape[0] + if K <= 1: + return torch.zeros_like(rewards) + + beta = _solve_beta(rewards, gamma=gamma) + if beta == 0.0: + return torch.zeros_like(rewards) + + r = rewards.double() + r_max = r.max() + e = torch.exp(beta * (r - r_max)) + + if K == 1: + Z_loo = e + else: + # Leave-one-out normalizer: Z_{-i} = (sum(e) - e_i) / (K - 1) + Z_loo = (e.sum() - e) / (K - 1) + + w = e / (Z_loo + eps) + advantages = (w - 1.0).to(rewards.dtype) + return advantages + + +class EntropicAdaptiveBetaAdvantageEstimator: + """Advantage estimator using entropic adaptive-β LOO weighting. + + Follows the same interface as GRPOAdvantageEstimator: + compute_advantage(prompt_ids, rewards, mask, **kwargs) -> [B, S] tensor + + Config keys (under grpo.adv_estimator): + gamma: Target KL for β search. Default ln(2) ≈ 0.693. + eps: Numerical stability constant. Default 1e-8. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.gamma = estimator_config.get("gamma", math.log(2)) + self.eps = estimator_config.get("eps", 1e-8) + + def compute_advantage( + self, + prompt_ids: torch.Tensor, + rewards: torch.Tensor, + mask: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """Compute per-token advantages using entropic adaptive-β LOO. + + Args: + prompt_ids: [B] or [B, S] tensor identifying which prompt each + sample belongs to (same prompt = same group). + rewards: [B] scalar rewards per sample. + mask: [B, S] response token mask (1 = generation token). + + Returns: + [B, S] advantages tensor. Each generation token gets the + sample-level advantage; non-generation tokens get 0. + """ + batch_size, seq_len = mask.shape + advantages = torch.zeros_like(mask, dtype=rewards.dtype) + + # Group by prompt (same as GRPO's per-prompt baseline) + if prompt_ids.dim() > 1: + # prompt_ids is [B, S] — use first token as group key + group_ids = prompt_ids[:, 0] + else: + group_ids = prompt_ids + + unique_prompts = group_ids.unique() + + for pid in unique_prompts: + group_mask = group_ids == pid + group_rewards = rewards[group_mask] + + group_adv = compute_entropic_advantages( + group_rewards, gamma=self.gamma, eps=self.eps + ) + + # Expand sample-level advantages to [group_size, seq_len] + # and mask to generation tokens only + group_indices = group_mask.nonzero(as_tuple=True)[0] + for i, idx in enumerate(group_indices): + advantages[idx] = group_adv[i] * mask[idx] + + return advantages diff --git a/shim/nemo_rl/algorithms/grpo.py b/shim/nemo_rl/algorithms/grpo.py new file mode 100644 index 0000000000..58f722c653 --- /dev/null +++ b/shim/nemo_rl/algorithms/grpo.py @@ -0,0 +1,3205 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import os +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast + +import numpy as np +import ray +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoProcessor +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.advantage_estimator import ( + GDPOAdvantageEstimator, + GRPOAdvantageEstimator, + ReinforcePlusPlusAdvantageEstimator, +) +from nemo_rl.algorithms.loss import ( + ClippedPGLossConfig, + ClippedPGLossDataDict, + ClippedPGLossFn, +) +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.reward_functions import ( + RewardShapingConfig, + apply_reward_shaping, +) +from nemo_rl.algorithms.utils import ( + calculate_baseline_and_std_per_prompt, + get_gdpo_reward_component_keys, + log_generation_metrics_to_wandb, + print_performance_metrics, + set_seed, +) +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.dataloader import MultipleDataloaderWrapper +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import ( + batched_message_log_to_flat_message, + get_keys_from_message_log, +) +from nemo_rl.data.utils import extract_necessary_env_names +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.experience.rollouts import ( + run_async_multi_turn_rollout, + run_async_nemo_gym_rollout, + run_multi_turn_rollout, +) +from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import ( + Logger, + LoggerConfig, + print_message_log_samples, +) +from nemo_rl.utils.memory_tracker import MemoryTracker +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer +from nemo_rl.utils.venvs import create_local_venv_on_each_node + +# =============================================================================== +# Configuration +# =============================================================================== +TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) + + +class RewardScalingConfig(TypedDict): + """Configure linear reward scaling with clamping. + + When `enabled` is True, each reward is clamped to the source interval + [source_min, source_max] and linearly mapped to the target interval + [target_min, target_max]. Refer to the scale_rewards function for the implementation. + + Defaults: + source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0 + """ + + enabled: bool + source_min: NotRequired[float] + source_max: NotRequired[float] + target_min: NotRequired[float] + target_max: NotRequired[float] + + +class AsyncGRPOConfig(TypedDict): + enabled: bool + # Maximum trajectory age in training steps for samples drawn from the + # async replay buffer. Trajectories older than this are excluded during + # sampling; buffer sizing also scales with this value. + max_trajectory_age_steps: int + # Does the weight synchronization as soon as the training is done + # without waiting for the pending generations to finish. + in_flight_weight_updates: NotRequired[bool] + # Recomputes the KV cache after the in-flight weight updates. + recompute_kv_cache_after_weight_updates: NotRequired[bool] + + +class AdvEstimatorConfig(TypedDict): + """Configuration for advantage estimator (GRPO, GDPO, Reinforce++, or Entropic).""" + + name: str # "grpo", "gdpo", "reinforce_plus_plus", or "entropic_adaptive_beta" + # GRPO specific + normalize_rewards: NotRequired[bool] + use_leave_one_out_baseline: NotRequired[bool] + # Reinforce++ specific + minus_baseline: NotRequired[bool] + # Entropic Adaptive-Beta specific (TTT-Discover, arXiv:2601.16175) + gamma: NotRequired[float] # Target KL for beta search; default ln(2) + eps: NotRequired[float] # Numerical stability; default 1e-8 + + +class GRPOConfig(TypedDict): + num_prompts_per_step: int + num_generations_per_prompt: int + max_num_epochs: int + max_num_steps: int + max_rollout_turns: int + normalize_rewards: bool + use_leave_one_out_baseline: bool + val_period: int + val_batch_size: int + val_at_start: bool + # Whether to run validation on the last training step. Setting this to True ensures the + # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). + val_at_end: bool + max_val_samples: int + skip_reference_policy_logprobs_calculation: NotRequired[bool] + seed: int + async_grpo: NotRequired[AsyncGRPOConfig] + overlong_filtering: NotRequired[bool] + # whether to enable dynamic sampling, i.e. + # whether to discard prompts whose rewards have zero standard deviation + use_dynamic_sampling: bool + # When using dynamic sampling, the maximum number of batches to generate + # before throwing an error + dynamic_sampling_max_gen_batches: NotRequired[int] + # When using dynamic sampling, generation prompt batch size will equal + # num_prompts_per_step * batch_multiplier + batch_multiplier: NotRequired[float] + reward_shaping: RewardShapingConfig + reward_scaling: RewardScalingConfig + # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation. + calculate_advantages_on_gpu: NotRequired[bool] + # Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as token_mult_prob_error metric, e.g., 1.5) + # Note that this is slightly different than Masked Importance Sampling (MIS) because this uses the absolute value of the difference between the training and generation logprobs, whereas MIS just uses the difference between the training and generation logprobs. + seq_logprob_error_threshold: float | None + # Advantage estimator configuration (grpo or reinforce_plus_plus) + adv_estimator: NotRequired[AdvEstimatorConfig] + + +class GRPOSaveState(TypedDict): + consumed_samples: int + current_step: int + current_epoch: int + total_steps: int + total_valid_tokens: int # Track total number of non-padding tokens during training + val_reward: NotRequired[ + float + ] # Optional field - may not be present during training + + +def _default_grpo_save_state() -> GRPOSaveState: + return { + "consumed_samples": 0, + "current_step": 0, + "current_epoch": 0, + "total_steps": 0, + "total_valid_tokens": 0, + "val_reward": -99999999.0, + } + + +class GRPOLoggerConfig(LoggerConfig): + num_val_samples_to_print: int # number of val samples to print to stdout + + +class MasterConfig(TypedDict): + policy: PolicyConfig + loss_fn: ClippedPGLossConfig + env: dict[str, Any] + data: DataConfig + grpo: GRPOConfig + logger: GRPOLoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + + +# =============================================================================== +# Setup & Initialization +# =============================================================================== + + +def setup( + master_config: MasterConfig, + tokenizer: TokenizerType, + dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset], + val_dataset: Optional[AllTaskProcessedDataset], + processor: Optional[AutoProcessor] = None, +) -> tuple[ + ColocatablePolicyInterface, + Optional[GenerationInterface], + tuple[RayVirtualCluster, RayVirtualCluster], + StatefulDataLoader | MultipleDataloaderWrapper, + Optional[StatefulDataLoader], + ClippedPGLossFn, + Logger, + CheckpointManager, + GRPOSaveState, + MasterConfig, +]: + """Main entry point for running GRPO algorithm. + + Returns: + tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader + """ + # Start timing the entire setup process + setup_start_time = time.perf_counter() + + # Extract individual configs for easier access + policy_config = master_config["policy"] + generation_config = master_config["policy"]["generation"] + env_configs = master_config["env"] + loss_config = master_config["loss_fn"] + grpo_config = master_config["grpo"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + assert generation_config is not None, ( + "A generation config in the PolicyConfig is required for GRPO" + ) + + # Set seed for all random number generators + set_seed(grpo_config["seed"]) + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + grpo_save_state: Optional[GRPOSaveState] = cast( + Optional[GRPOSaveState], checkpointer.load_training_info(last_checkpoint_path) + ) + if grpo_save_state is None: + grpo_save_state = _default_grpo_save_state() + + # ========================== + # Data + # ========================== + # num_prompts_per_step and dataloader_batch_size will be different when using multiple dataloaders + num_prompts_per_step = grpo_config["num_prompts_per_step"] + if data_config["use_multiple_dataloader"]: + dataloader_batch_size = data_config["num_prompts_per_dataloader"] + else: + dataloader_batch_size = num_prompts_per_step + + # Validate batch_multiplier + batch_multiplier = grpo_config["batch_multiplier"] + if grpo_config["use_dynamic_sampling"]: + num_prompts_per_step = int(num_prompts_per_step * batch_multiplier) + dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) + else: + assert batch_multiplier == 1, ( + "batch_multiplier>1 can only be used if use_dynamic_sampling=True" + ) + + # Validate number of prompts per step + if data_config["use_multiple_dataloader"]: + assert num_prompts_per_step % dataloader_batch_size == 0, ( + "Expected num_prompts_per_step to be a multiple of num_prompts_per_dataloader, " + f"but got {num_prompts_per_step} and {dataloader_batch_size}. " + "Please check the configuration of num_prompts_per_step and num_prompts_per_dataloader. " + "If use_dynamic_sampling is enabled and batch_multiplier is used, please also check the configuration of batch_multiplier." + ) + + # Load train dataset + def init_train_dataloader(dataset, suffix: str = ""): + dataloader = StatefulDataLoader( + dataset, + batch_size=dataloader_batch_size, + shuffle=data_config["shuffle"], + collate_fn=rl_collate_fn, + drop_last=True, + num_workers=data_config["num_workers"], + ) + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, f"train_dataloader{suffix}.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + return dataloader + + if data_config["use_multiple_dataloader"]: + # Initialize dataloaders + dataloaders = {} + for task_name, task_dataset in dataset.items(): + dataloaders[task_name] = init_train_dataloader( + task_dataset, f"_{task_name}" + ) + print( + f" ✓ Training dataloader {task_name} loaded with {len(task_dataset)} samples", + flush=True, + ) + + train_sample_count = sum( + len(task_dataloader) for task_dataloader in dataloaders.values() + ) + + # Wrap dataloader + dataloader = MultipleDataloaderWrapper( + expected_num_prompts=num_prompts_per_step, + data_config=data_config, + dataloaders=dataloaders, + ) + else: + dataloader = init_train_dataloader(dataset) + train_sample_count = len(dataloader) + print( + f" ✓ Training dataloader loaded with {train_sample_count} samples", + flush=True, + ) + + # Load validation dataset if provided + val_dataloader: Optional[StatefulDataLoader] = None + # If validation is enabled, load the validation dataloader + if ( + grpo_config["val_period"] > 0 + or grpo_config["val_at_start"] + or grpo_config["val_at_end"] + ): + assert val_dataset is not None, ( + "Validation dataset is required if validation is enabled" + ) + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=grpo_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + num_workers=data_config["num_workers"], + ) + print( + f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) + + # ========================== + # Loss Function + # ========================== + loss_fn = ClippedPGLossFn(loss_config) + + # Validate force_on_policy_ratio + if loss_config.get("force_on_policy_ratio", False): + assert ( + grpo_config["num_prompts_per_step"] + * grpo_config["num_generations_per_prompt"] + == policy_config["train_global_batch_size"] + ), ( + "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt" + ) + os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1" + print(" ✓ force_on_policy_ratio enabled") + + # ========================== + # Cluster + # ========================== + print("\n▶ Setting up compute cluster...", flush=True) + colocated_inference = generation_config["colocated"]["enabled"] + + env_name_list = extract_necessary_env_names(data_config) + rm_env_enabled = "reward_model" in env_name_list + + total_nodes = cluster_config["num_nodes"] + if rm_env_enabled: + rm_resource = env_configs["reward_model"]["resources"] + rm_nodes = rm_resource["num_nodes"] + rm_gpus_per_node = rm_resource["gpus_per_node"] + else: + rm_nodes = 0 + rm_gpus_per_node = 0 + + if total_nodes == 1: + policy_nodes = total_nodes + else: + policy_nodes = total_nodes - rm_nodes + assert policy_nodes > 0, ( + "policy_nodes must be > 0, but got " + f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}" + ) + + if colocated_inference: + if total_nodes == 1: + policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node + assert policy_gpus_per_node > 0, ( + "policy.generation.colocated.resources.gpus_per_node must be > 0 " + "when cluster.num_nodes = 1, " + f"but got {policy_gpus_per_node}." + ) + else: + policy_gpus_per_node = cluster_config["gpus_per_node"] + + cluster = RayVirtualCluster( + name="grpo_policy_cluster", + bundle_ct_per_node_list=[policy_gpus_per_node] * policy_nodes, + use_gpus=True, + num_gpus_per_node=policy_gpus_per_node, + max_colocated_worker_groups=1 + if generation_config["backend"] == "megatron" + else 2, + ) + train_cluster = cluster + inference_cluster = cluster + print( + f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes", + flush=True, + ) + + else: + assert generation_config["backend"] != "megatron", ( + "Non-colocated inference is not supported for Megatron generation backends. " + "Please use vLLM backend for generation." + ) + + # train resources will be updated through overall and inference resources below + train_gpus_per_node = cluster_config["gpus_per_node"] + train_nodes = policy_nodes + + inference_resources = generation_config["colocated"]["resources"] + inference_gpus_per_node = inference_resources["gpus_per_node"] + inference_nodes = inference_resources["num_nodes"] + + # validate and configure resources + if policy_nodes == 1: + # When policy_nodes == 1, train and inference are on the same node + assert ( + inference_gpus_per_node is not None and inference_gpus_per_node > 0 + ), ( + "policy.generation.colocated.resources.gpus_per_node must be explicitly set to a value > 0 " + "when policy_nodes = 1 and inference is non-colocated, " + f"but got {inference_gpus_per_node}." + ) + assert inference_nodes is None or inference_nodes == 1, ( + "policy.generation.colocated.resources.num_nodes must be 1 or set to null " + "when policy_nodes = 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + + inference_nodes = 1 + # If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node + reward_gpus_to_subtract = ( + rm_gpus_per_node if total_nodes == 1 and rm_env_enabled else 0 + ) + train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract + assert train_gpus_per_node > 0, ( + "No enough GPUs for training, " + f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}" + + ( + f" - rm_gpus_per_node:{rm_gpus_per_node}" + if total_nodes == 1 and rm_env_enabled + else "" + ) + ) + else: + # train, inference, and reward model are all on different nodes + assert inference_nodes > 0, ( + "policy.generation.colocated.resources.num_nodes must be > 0 " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got {inference_nodes}." + ) + assert ( + inference_gpus_per_node is not None + and inference_gpus_per_node == cluster_config["gpus_per_node"] + ), ( + "policy.generation.colocated.resources.gpus_per_node must be explicitly set and equal to cluster.gpus_per_node " + "when cluster.num_nodes > 1 and inference is non-colocated, " + f"but got inference_gpus_per_node={inference_gpus_per_node}, cluster.gpus_per_node={cluster_config['gpus_per_node']}." + ) + train_nodes -= inference_nodes + + # initialize train cluster + train_cluster = RayVirtualCluster( + name="grpo_train_cluster", + bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, + use_gpus=True, + num_gpus_per_node=train_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node", + flush=True, + ) + + # initialize inference cluster + inference_cluster = RayVirtualCluster( + name="grpo_inference_cluster", + bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, + use_gpus=True, + num_gpus_per_node=inference_gpus_per_node, + max_colocated_worker_groups=1, + ) + print( + f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node", + flush=True, + ) + + # ========================== + # Training and Inference + # ========================== + print("\n▶ Setting up model and training...", flush=True) + + # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode + backend = generation_config["backend"] + generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM + + # Dictionary to store worker initialization timing stats for logging + worker_init_timing_metrics = {} + + weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path) + + if policy_config.get("megatron_cfg", {}).get("enabled", False): + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + grpo_config["max_num_steps"], + grpo_config["max_num_epochs"] * train_sample_count, + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + # Define initialization functions that will be used in all paths + def init_policy(): + """Initialize policy training workers.""" + t0 = time.perf_counter() + p = Policy( + cluster=train_cluster, + config=policy_config, + tokenizer=tokenizer, + processor=processor, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + ) + return p, time.perf_counter() - t0 + + def init_vllm(): + """Initialize vLLM generation workers.""" + t0 = time.perf_counter() + pg = VllmGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + + def init_sglang(): + """Initialize SGLang generation workers.""" + t0 = time.perf_counter() + pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + + def initialize_generation_with_policy( + init_generation_fn, + generation_name: str, + init_time_key: str, + colocated_inference: bool, + worker_init_timing_metrics: dict, + ): + """Generic function to initialize a generation engine (vLLM or SGLang) along with policy. + + Args: + init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang) + generation_name: Name of the generation engine ("vLLM" or "SGLang") + init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s") + colocated_inference: Whether inference is colocated with training + worker_init_timing_metrics: Dictionary to store timing metrics + + Returns: + Tuple of (policy_generation, policy) + """ + # Determine if parallel initialization is possible (non-colocated mode) + use_parallel_init = not colocated_inference + + if use_parallel_init: + # Parallel initialization: Generation engine and Policy can initialize simultaneously + print( + " ⚡ Using parallel worker initialization (non-colocated mode)", + flush=True, + ) + + # Execute both initializations in parallel + parallel_start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=2) as executor: + generation_future = executor.submit(init_generation_fn) + policy_future = executor.submit(init_policy) + policy_generation, generation_time = generation_future.result() + policy, policy_time = policy_future.result() + parallel_wall_time = time.perf_counter() - parallel_start_time + + # Store timing metrics + worker_init_timing_metrics[init_time_key] = generation_time + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time + worker_init_timing_metrics["parallel_init_enabled"] = True + + else: + # Sequential initialization: colocated mode (GPU memory requires generation engine first) + print( + " ⚙️ Using sequential worker initialization (colocated mode)", + flush=True, + ) + + # Initialize generation engine first (clean GPU memory), then policy + policy_generation, generation_time = init_generation_fn() + worker_init_timing_metrics[init_time_key] = generation_time + + policy, policy_time = init_policy() + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + + return policy_generation, policy + + # Handle generation-specific setup + if backend == "megatron": + # Megatron generation: policy_generation is None, only initialize policy + policy_generation = None + print( + f" ✓ Using {backend} backend for generation with {policy_config['model_name']}", + flush=True, + ) + + policy, policy_time = init_policy() + worker_init_timing_metrics["policy_init_time_s"] = policy_time + + elif backend == "vllm": + # vLLM generation: setup config, then initialize with policy + generation_config = cast(VllmConfig, generation_config) + if generation_config["vllm_cfg"]["precision"] == "fp8": + assert loss_config["use_importance_sampling_correction"] is True, ( + "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" + ) + if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): + # FP8 KV cache requires FP8 model precision + assert generation_config["vllm_cfg"]["precision"] == "fp8", ( + f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. " + "FP8 KV cache can only be used together with FP8 model weights." + ) + # FP8 KV cache compatibility checks + assert policy_config["dtensor_cfg"]["enabled"] == False, ( + "DTensor backend is not supported with kv cache fp8 enabled." + ) + assert not _should_use_async_rollouts(master_config), ( + "Async rollouts is not supported with kv cache fp8 enabled." + ) + assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, ( + "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future." + ) + + ## make vllm hf overrides match the training policy + generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( + "hf_config_overrides", {} + ) + + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_vllm, + generation_name="vLLM", + init_time_key="vllm_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) + + print( + f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", + flush=True, + ) + + elif backend == "sglang": + generation_config = cast(SGLangConfig, generation_config) + + # Set model_path if not already set + if "model_path" not in generation_config["sglang_cfg"]: + generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"] + + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_sglang, + generation_name="SGLang", + init_time_key="sglang_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) + + print( + f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", + flush=True, + ) + + # Record when worker initialization completes (for calculating other setup time) + worker_init_complete_time = time.perf_counter() - setup_start_time + + # print the node IP and GPU ID of the policy workers for debugging + policy.print_node_ip_and_gpu_id() + + # if it is not colocated inference, initialize collective communication for update weights + if not colocated_inference: + t0 = time.perf_counter() + ip, port = train_cluster.get_master_address_and_port() + print(f"Using ip: {ip}, port: {port} for collective communication", flush=True) + # world includes all training workers and all inference workers + train_world_size = train_cluster.world_size() + inference_world_size = inference_nodes * inference_gpus_per_node + world_size = train_world_size + inference_world_size + # init collective + futures_train = policy.init_collective( + ip, port, world_size, train_world_size=train_world_size + ) + futures_inference = policy_generation.init_collective( + ip, port, world_size, train_world_size=train_world_size + ) # type: ignore + # wait for all futures to complete + ray.get(futures_train + futures_inference) + worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0 + + # prepare refit info + state_dict_info = policy.prepare_refit_info() + if policy_generation is not None: + policy_generation.prepare_refit_info(state_dict_info) + + # Calculate total setup time + total_setup_time = time.perf_counter() - setup_start_time + worker_init_timing_metrics["total_setup_time_s"] = total_setup_time + + # Log worker initialization timing metrics to logger + if worker_init_timing_metrics: + print("\n▶ Worker Initialization Timing:") + + vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0) + policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0) + total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0) + + if vllm_time: + print(f" vLLM init: {vllm_time:.1f}s") + + if policy_time: + print(f" Policy init: {policy_time:.1f}s") + + # Calculate "other" time (time after worker init completes) + other_time = total_setup - worker_init_complete_time + worker_init_timing_metrics["other_setup_time_s"] = other_time + print(f" Other setup: {other_time:.1f}s") + + print(f" Total setup: {total_setup:.1f}s") + + # Log all metrics to the logger for analysis + logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup") + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print(f" Total setup time: {total_setup_time:.1f}s") + print("=" * 60 + "\n", flush=True) + + return ( + policy, + policy_generation, + (train_cluster, inference_cluster), + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_save_state, + master_config, + ) + + +# =============================================================================== +# Core Algorithm Functions +# =============================================================================== + + +def dynamic_sampling( + repeated_batch: BatchedDataDict[DatumSpec], + std: torch.Tensor, + baseline: torch.Tensor, + dynamic_sampling_num_gen_batches: int, + master_config: MasterConfig, + timer: Timer, + batch_cache: BatchedDataDict[DatumSpec] = None, +) -> BatchedDataDict[DatumSpec]: + """Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. + + This function filters the current batch to retain only those prompts that have a non-zero standard deviation. + If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, + we store it in the batch_cache to be used in later iterations. + If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, + the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. + is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop + to continue sampling or proceed to training. + This approach is based on the dynamic sampling algorithm from the DAPO paper: + https://arxiv.org/pdf/2503.14476. + + Args: + repeated_batch (BatchedDataDict[DatumSpec]): The current batch of data containing prompts, responses, rewards, baselines, and std. + std (torch.Tensor): Tensor representing the standard deviation for each prompt group. + baseline (torch.Tensor): Baseline values for each prompt group. + dynamic_sampling_num_gen_batches (int): Number of generation batches processed at the current step. + master_config (MasterConfig): Configuration containing GRPO and policy settings. + batch_cache (BatchedDataDict[DatumSpec], optional): Cache storing previously selected prompts with non-zero std. + + Returns: + tuple: A tuple containing: + - repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts. + - is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training. + - batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations. + """ + # is_batch_complete is used to indicate if the current batch was able to generate enough prompts with non-zero std. + is_batch_complete = True + + # Required batch size for training + train_prompts_size = ( + master_config["grpo"]["num_prompts_per_step"] + * master_config["grpo"]["num_generations_per_prompt"] + ) + # Store the baseline, std and total_reward for the current unfiltered batch. + repeated_batch["baseline"] = baseline + repeated_batch["std"] = std + total_rewards = repeated_batch["total_reward"] + dynamic_sampling_metrics = {} + + # Dynamic sampling algorithm (used in DAPO algorithm) + # This block implements dynamic sampling by selecting prompt groups with non-zero std. + # If sampled prompts (with non-zero std) are fewer than num_prompts_per_step * num_generations_per_prompt, continue sampling until dynamic_sampling_max_gen_batches is reached. + if master_config["grpo"]["use_dynamic_sampling"]: + with timer.time("dynamic_sampling"): + # Get the prompt indices with non-zero std + non_zero_std_mask = std != 0.0 + + keep_prompt_indices = torch.arange( + len(non_zero_std_mask), device=std.device + )[non_zero_std_mask].tolist() + + # Only select the inputs that have non-zero std + # total_reward is already a part of repeated_batch so we don't need to add it again + filtered_repeated_batch = repeated_batch.select_indices(keep_prompt_indices) + filtered_repeated_batch["std"] = std[keep_prompt_indices] + filtered_repeated_batch["baseline"] = baseline[keep_prompt_indices] + + # Store filtered and total rewards to track them separately + filtered_rewards = filtered_repeated_batch["total_reward"] + filtered_repeated_batch["total_reward"] = total_rewards + filtered_repeated_batch["filtered_reward"] = filtered_rewards + + # Store the total_reward for the current filtered batch. + # If none of the prompts in current batch have non-zero std, filtered_repeated_batch.size will be 0. + # In this case, the current batch will be ignored and the next batch will be processed and we generate responses for it. + if filtered_repeated_batch.size > 0: + # Concatenate the previous partially filled batch with the current batch. This serves as a cache to store and collect the prompts with non-zero std. + # This is used in the next iteration when the current batch is not enough to fill the buffer. + batch_cache = ( + filtered_repeated_batch + if batch_cache is None + else BatchedDataDict.from_batches( + [batch_cache, filtered_repeated_batch] + ) + ) + filtered_repeated_batch = batch_cache + + filtered_prompts_size = filtered_repeated_batch.size + print( + f"Detected {filtered_prompts_size} prompts with non-zero std; " + f"{train_prompts_size} are required and used for training." + ) + + # If the generation samples size is smaller than a fixed threshold (train_prompts_size), keep generating by processing the next batch + if filtered_prompts_size < train_prompts_size: + dynamic_sampling_max_gen_batches = master_config["grpo"][ + "dynamic_sampling_max_gen_batches" + ] + assert dynamic_sampling_max_gen_batches > 0, ( + "When using grpo.use_dynamic_sampling, grpo.dynamic_sampling_max_gen_batches must be > 0" + ) + if dynamic_sampling_num_gen_batches <= dynamic_sampling_max_gen_batches: + print( + f"Generation sample buffer size: {filtered_prompts_size} is smaller than train_prompts_size: {train_prompts_size}. Processed {dynamic_sampling_num_gen_batches} batches so far out of {dynamic_sampling_max_gen_batches}." + ) + is_batch_complete = False + else: + raise ValueError( + f"Dynamic sampling has reached the maximum allowed number of batches ({dynamic_sampling_max_gen_batches}). Consider evaluating the complexity of your data or adjusting the num_prompts_per_step or num_generations_per_prompt parameters to enhance the diversity of the samples." + ) + else: + num_discarded_valid_samples = filtered_prompts_size - train_prompts_size + dynamic_sampling_metrics[ + "dynamic_sampling_num_discarded_valid_samples" + ] = num_discarded_valid_samples + + # Slice the batch, rewards, baselines and std to ensure batch size is train_prompts_size + filtered_repeated_batch = filtered_repeated_batch.slice( + 0, train_prompts_size + ) + + batch_to_return = ( + filtered_repeated_batch + if master_config["grpo"]["use_dynamic_sampling"] + else repeated_batch + ) + return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics + + +def scale_rewards( + repeated_batch: BatchedDataDict[DatumSpec], reward_scaling_cfg: RewardScalingConfig +) -> BatchedDataDict[DatumSpec]: + """Linearly scales rewards from a source range to a target range. + + If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` + is clamped to the configured source interval [source_min, source_max] and then + rescaled to the target interval [target_min, target_max]. + + Default configuration: + source_min = 0.0 + source_max = 1.0 + target_min = 0.0 + target_max = 1.0 + """ + if reward_scaling_cfg["enabled"]: + rewards = repeated_batch["total_reward"] + source_min = float(reward_scaling_cfg["source_min"]) + source_max = float(reward_scaling_cfg["source_max"]) + target_min = float(reward_scaling_cfg["target_min"]) + target_max = float(reward_scaling_cfg["target_max"]) + + # Detect out-of-range values + out_of_range_mask = (rewards < source_min) | (rewards > source_max) + if torch.any(out_of_range_mask): + print( + f"[reward_scaling] WARNING: {int(out_of_range_mask.sum())} rewards " + f"are outside the configured source range [{source_min}, {source_max}]. " + f"Values will be clipped before scaling." + ) + + # Clamp and scale + def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: + r = torch.clamp(reward_tensor, min=source_min, max=source_max) + return target_min + (r - source_min) / (source_max - source_min) * ( + target_max - target_min + ) + + scaled_rewards = _scale(rewards) + repeated_batch["total_reward"] = scaled_rewards + for key in get_gdpo_reward_component_keys(repeated_batch): + repeated_batch[key] = _scale(repeated_batch[key]) + + return repeated_batch + + +def _should_use_async_rollouts(master_config: MasterConfig) -> bool: + """Determine if async rollouts should be used based on the configuration. + + Returns True if vLLM backend is used with async_engine enabled. + """ + generation_config = master_config["policy"]["generation"] + if generation_config is None: + return False + + backend = generation_config.get("backend", "") + if backend != "vllm": + return False + + vllm_cfg = generation_config.get("vllm_cfg", {}) + return vllm_cfg.get("async_engine", False) + + +def _should_use_nemo_gym(master_config: MasterConfig) -> bool: + """Determine if NeMo-Gym should be used for rollouts and validation based on the configuration.""" + env_config = master_config.get("env") or dict() + should_use_nemo_gym = bool(env_config.get("should_use_nemo_gym")) + if not should_use_nemo_gym: + return should_use_nemo_gym + + # Validate the setup for training with NeMo-Gym + assert _should_use_async_rollouts(master_config), ( + "❌ Error: In order to use NeMo-Gym, you must use vllm generation backend with `async_engine: true`!" + ) + + generation_config = master_config["policy"]["generation"] + + # We piggyback off of `_should_use_async_rollouts` to guarantee the existence of these configs. + should_expose_http_server = generation_config["vllm_cfg"].get("expose_http_server") + assert should_expose_http_server, ( + "In order to use NeMo-Gym, you must expose the vllm server via `expose_http_server: true`!" + ) + + return should_use_nemo_gym + + +def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool: + env_config = master_config.get("env") or dict() + should_log_nemo_gym_responses = bool( + env_config.get("should_log_nemo_gym_responses") + ) + + return should_log_nemo_gym_responses + + +def _create_advantage_estimator(master_config: MasterConfig): + """Create and return an advantage estimator based on configuration. + + Args: + master_config: The master configuration dictionary. + + Returns: + An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus). + + Raises: + ValueError: If the advantage estimator name is not recognized. + """ + grpo_config = master_config["grpo"] + loss_config = master_config["loss_fn"] + + # Provide backward-compatible defaults when adv_estimator is not in config. + # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline + # which older configs still use. + adv_estimator_config = grpo_config.get( + "adv_estimator", + { + "name": "grpo", + "normalize_rewards": grpo_config.get("normalize_rewards", True), + "use_leave_one_out_baseline": grpo_config.get( + "use_leave_one_out_baseline", False + ), + "minus_baseline": True, + }, + ) + + adv_estimator_name = adv_estimator_config["name"] + if adv_estimator_name == "gdpo": + adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" ✓ Using GDPO advantage estimator (multi-reward)") + elif adv_estimator_name == "grpo": + adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" ✓ Using GRPO advantage estimator") + elif adv_estimator_name == "reinforce_plus_plus": + adv_estimator = ReinforcePlusPlusAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(" ✓ Using Reinforce++ advantage estimator") + elif adv_estimator_name == "entropic_adaptive_beta": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(" ✓ Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)") + else: + raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}") + + return adv_estimator + + +def _extract_prompt_only_messages(message_logs: list) -> list: + """Extract only prompt messages (user/system) from message logs. + + This is used to get prompt IDs for advantage estimation, excluding + any assistant responses. + + Args: + message_logs: List of message logs, where each log is a list of messages. + + Returns: + List of message logs containing only user and system messages. + """ + prompt_only_message_logs = [] + for message_log in message_logs: + prompt_only_log = [] + for message in message_log: + if message["role"] == "user" or message["role"] == "system": + prompt_only_log.append(message) + prompt_only_message_logs.append(prompt_only_log) + return prompt_only_message_logs + + +def refit_policy_generation( + policy: ColocatablePolicyInterface, + policy_generation: GenerationInterface, + colocated_inference: bool, + _refit_buffer_size_gb: Optional[int] = None, + timer: Optional[Timer] = None, + kv_scales: Optional[dict[str, float]] = None, +) -> None: + """Refit the policy generation interface with the latest policy weights. + + Args: + policy: The policy to provide weights to the inference engine. + policy_generation: The inference engine to refit. + _refit_buffer_size_gb: The size of the buffer to use for refitting. + If it is None, the buffer size will be computed by the remaining memory. + This parameter is primarily used for testing. + timer: Optional Timer used to time the prepare/transfer/update phase + kv_scales: Optional dictionary of KV cache scales for FP8 quantization. + """ + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) + + # Create a context manager that does nothing when timer is None + timer_context = ( + timer.time("prepare_for_generation/transfer_and_update_weights") + if timer is not None + else nullcontext() + ) + with timer_context: + # update weights + update_success = False + if colocated_inference: + # get model param keys, which is grouped by size + if _refit_buffer_size_gb is not None: + buffer_size_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Empirically sets ratio as 30% to maximize efficiency. + # The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension. + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") + buffer_size_bytes = int( + policy.get_free_memory_bytes() * float(memory_ratio) + ) + + if isinstance(policy_generation, SGLangGeneration): + sglang_url_to_gpu_uuids = ( + policy_generation.get_sglang_url_to_gpu_uuids() + ) + # Stream weights via HTTP + flush_success = policy_generation.invalidate_kv_cache() + if not flush_success: + print("SGLang KV cache invalidation failed before weight update. ") + futures_train = policy.stream_weights_via_http( + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + # Wait for all workers to complete + ray.get(futures_train) + update_success = True + else: + # Original ZMQ IPC path for vLLM + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + else: + # update weights through nccl + # SGLang haven't implemented non-colocated inference mode. + if isinstance(policy_generation, SGLangGeneration): + raise NotImplementedError( + "SGLang haven't implemented non-colocated inference mode. " + ) + futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) + futures_inference = policy_generation.update_weights_from_collective() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) + + # check if update is successful + if not update_success: + error_tag = "cuda-ipc" if colocated_inference else "nccl" + error_message = ( + "❌ Error: Updating weights for the generation policy failed during refit.\n" + f"This often indicates an issue with {error_tag} or " + "a problem within the generation backend (e.g., vLLM worker).\n" + ) + raise RuntimeError(error_message) + + if colocated_inference: + policy.offload_after_refit() + policy_generation.prepare_for_generation(tags=["kv_cache"]) + + +def _log_mixed_rewards_and_advantages_information( + logger: Logger, + total_steps: int, + metrics: dict[str, Any], + baseline: torch.Tensor, + advantages: torch.Tensor, +) -> None: + # The histograms that are logged are logged with a prefix "train/" to the name, since that is what the remaining metrics will be logged with. + logger.log_histogram( + baseline.numpy(), total_steps + 1, "train/baseline_reward/histogram" + ) + metrics["baseline_reward/pct_0"] = 100 * (baseline == 0).float().mean().item() + metrics["baseline_reward/pct_1"] = 100 * (baseline == 1).float().mean().item() + metrics["baseline_reward/pct_mixed"] = ( + 100 - metrics["baseline_reward/pct_0"] - metrics["baseline_reward/pct_1"] + ) + + logger.log_histogram( + advantages.numpy(), total_steps + 1, "train/advantages/histogram" + ) + metrics["advantages/sum"] = advantages.float().sum().item() + metrics["advantages/mean"] = advantages.float().mean().item() + + +def compute_and_apply_seq_logprob_error_masking( + train_data: BatchedDataDict, + rewards: torch.Tensor, + seq_logprob_error_threshold: Optional[float], +) -> tuple[float, int, float]: + """Compute sequence-level logprob error metrics and optionally mask high-error sequences. + + This function computes the multiplicative probability error per sequence + (same calculation as token_mult_prob_error but aggregated per-sequence) and + optionally masks sequences that exceed the configured threshold. + + Args: + train_data: Training data dict containing token_mask, sample_mask, + prev_logprobs, and generation_logprobs. If masking is applied, + sample_mask will be updated in-place. + rewards: Reward tensor for computing statistics on masked sequences. + seq_logprob_error_threshold: If set, mask sequences with mult_prob_error + exceeding this threshold. If None, only compute metrics. + + Returns: + Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct) + """ + # Compute sequence-level logprob error metrics (always) + token_mask = train_data["token_mask"][:, 1:] + sample_mask = train_data["sample_mask"] + prev_logprobs = train_data["prev_logprobs"][:, 1:] + generation_logprobs = train_data["generation_logprobs"][:, 1:] + lp_error = torch.abs(generation_logprobs - prev_logprobs) + + # Use combined mask exactly as in loss function + mask = token_mask * sample_mask.unsqueeze(-1) + + # Calculate sequence-level multiplicative prob error + # EXACT same calculation as token_mult_prob_error but per-sequence + seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum( + dim=-1 + ).clamp(min=1) + max_seq_mult_prob_error = ( + seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0 + ) + + # Apply sequence-level masking if configured + num_masked_seqs = 0 + masked_correct_pct = 0.0 + + if seq_logprob_error_threshold is not None: + print( + f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...", + flush=True, + ) + + original_sample_mask = sample_mask.clone() + + # Create mask for sequences below threshold + seq_error_mask = ( + seq_mult_prob_error <= seq_logprob_error_threshold + ).float() * original_sample_mask + + diff_mask = original_sample_mask - seq_error_mask + num_masked_seqs = int(diff_mask.sum().item()) + + if num_masked_seqs > 0: + diff_mask_bool = diff_mask.bool() + masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item() + masked_correct_pct = masked_correct_count / num_masked_seqs + + # Update sample_mask in train_data + train_data["sample_mask"] = seq_error_mask + + print( + f" Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}", + flush=True, + ) + if num_masked_seqs > 0: + print( + f" • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)" + f" → {masked_correct_pct:.2%}", + flush=True, + ) + + return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct + + +# =============================================================================== +# Training & Validation +# =============================================================================== + + +def grpo_train( + policy: ColocatablePolicyInterface, + policy_generation: Optional[GenerationInterface], + wrapped_dataloader: StatefulDataLoader | MultipleDataloaderWrapper, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: LossFunction, + task_to_env: dict[str, EnvironmentInterface], + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + logger: Logger, + checkpointer: CheckpointManager, + grpo_save_state: GRPOSaveState, + master_config: MasterConfig, +) -> None: + """Run GRPO training algorithm.""" + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + memory_tracker = MemoryTracker() + + kv_scales_cache = None # Cache reused for computed kv scales + + NEED_REFIT = True + # If policy_generation is None, use the policy as the generation interface (megatron framework backend) + if policy_generation is None: + policy_generation = policy # type: ignore + NEED_REFIT = False + POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running + assert policy_generation is not None # for mypy type check + + if master_config["grpo"].get("skip_reference_policy_logprobs_calculation"): + assert master_config["loss_fn"]["reference_policy_kl_penalty"] == 0 + print( + "Reference policy logprob calculation will be skipped since `grpo.skip_reference_policy_logprobs_calculation` is set to True and `loss_fn.reference_policy_kl_penalty` is 0." + ) + + # Check if we need to sync KV cache scales + # When fallback to policy as the policy_generation, we use getattr to check. + sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False) + + # common config/state times + current_step = grpo_save_state["current_step"] # current step within an epoch + total_steps = grpo_save_state["total_steps"] # total steps across all epochs + max_num_steps = master_config["grpo"][ + "max_num_steps" + ] # max number of steps to train for + current_epoch = grpo_save_state["current_epoch"] # current epoch + max_num_epochs = master_config["grpo"][ + "max_num_epochs" + ] # max number of epochs to train for + consumed_samples = grpo_save_state[ + "consumed_samples" + ] # total samples consumed across all epochs + total_valid_tokens = grpo_save_state.get( + "total_valid_tokens", 0 + ) # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints + val_at_start = master_config["grpo"]["val_at_start"] + val_at_end = master_config["grpo"]["val_at_end"] + val_period = master_config["grpo"]["val_period"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + + # Initialize advantage estimator + adv_estimator = _create_advantage_estimator(master_config) + + # Run validation at the start if configured + # TODO: Add validation with kv scales if needed + if val_at_start and current_step == 0: + print("\n🔍 Running initial validation...", flush=True) + memory_tracker.snapshot_start_of_stage("Initial validation", dir()) + + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation(policy, policy_generation, colocated_inference) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=0, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + logger.log_metrics(val_metrics, current_step, prefix="validation") + logger.log_metrics(validation_timings, current_step, prefix="timing/validation") + + if master_config["data"]["use_multiple_dataloader"]: + warnings.warn( + "When using multiple dataloaders, MultipleDataloaderWrapper operates as an infinite iterator. " + "As a result, grpo.max_num_epochs will be ignored, and only grpo.max_num_steps will be used. " + "See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details." + ) + + while current_epoch < max_num_epochs and total_steps < max_num_steps: + memory_tracker.snapshot_start_of_stage("Preparing batch", dir()) + print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") + # batch cache is used for DAPO. We store prompts with non-zero standard deviation in this cache. + batch_cache: BatchedDataDict[DatumSpec] = None + # This is the number of batches we processed so far at each step to generate responses whose std is non-zero. Maximum threshold is set by dynamic_sampling_max_gen_batches. Used in the case of dynamic sampling. + dynamic_sampling_num_gen_batches = 0 + + # Run grpo/dapo training loop (single-turn) + for batch in wrapped_dataloader: + # A central place to store logging data that won't be deleted until the loop ends + metrics_logging_data = dict() + metrics = dict() + + if master_config["data"]["use_multiple_dataloader"]: + print( + f"\n{'=' * 25} Step {current_step + 1}/{max_num_steps} {'=' * 25}", + flush=True, + ) + else: + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(wrapped_dataloader), max_num_steps)} {'=' * 25}", + flush=True, + ) + + maybe_gpu_profile_step(policy, total_steps + 1) + if policy != policy_generation: + maybe_gpu_profile_step(policy_generation, total_steps + 1) + val_metrics, validation_timings = None, None + + with timer.time("total_step_time"): + # Prepare batch + print("▶ Preparing batch...", flush=True) + with timer.time("data_processing"): + # Repeat batch items + repeated_batch: BatchedDataDict[DatumSpec] = ( + batch.repeat_interleave( + master_config["grpo"]["num_generations_per_prompt"] + ) + ) + # Convert LLMMessageLogType to FlatMessagesType for generation + batched_flat, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + input_ids = batched_flat["token_ids"] + + # Generate responses - this updates the LLMMessageLogType in repeated_batch + memory_tracker.snapshot_start_of_stage("Generation", dir()) + print( + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, + ) + with timer.time("prepare_for_generation/total"): + if NEED_REFIT and POLICY_GENERATION_STALE: + # Compute KV scales if needed for FP8 quantization + if sync_kv_scales and kv_scales_cache is None: + print("▶ Computing KV cache scales...", flush=True) + policy.prepare_for_lp_inference() + # Align with training data processing to ensure parallel training compatibility + calib_flat, calib_input_lengths = ( + batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={ + "token_ids": tokenizer.pad_token_id + }, + make_sequence_length_divisible_by=master_config[ + "policy" + ]["make_sequence_length_divisible_by"], + ) + ) + # Create calibration data from flattened messages + calibration_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": calib_flat["token_ids"], + "input_lengths": calib_input_lengths, + } + ) + calibration_data.update( + calib_flat.get_multimodal_dict(as_tensors=False) + ) + calibration_data.to("cpu") + kv_scales_cache = policy.calibrate_qkv_fp8_scales( + calibration_data, include_q=True + )["layers"] + + refit_policy_generation( + policy, + policy_generation, + colocated_inference, + timer=timer, + kv_scales=kv_scales_cache if sync_kv_scales else None, + ) + POLICY_GENERATION_STALE = False + else: + if colocated_inference: + policy.offload_after_refit() # unload optimizer to make space for generation + policy_generation.prepare_for_generation() + + dynamic_sampling_num_gen_batches += 1 + if dynamic_sampling_num_gen_batches == 1 and hasattr( + policy_generation, "snapshot_step_metrics" + ): + policy_generation.snapshot_step_metrics() + with timer.time("generation"): + # Clear logger metrics for each generation step + if policy_generation is not None: + policy_generation.clear_logger_metrics() + # Use NeMo-Gym rollouts if enabled. We cascade NeMo-Gym first since NeMo-Gym requires async rollouts. + if _should_use_nemo_gym(master_config): + generation_config = master_config["policy"]["generation"] + nemo_gym_rollout_result = run_async_nemo_gym_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=None, + generation_config=generation_config, + max_rollout_turns=None, + greedy=False, + ) + input_ids = nemo_gym_rollout_result.input_ids + repeated_batch = nemo_gym_rollout_result.final_batch + rollout_metrics = nemo_gym_rollout_result.rollout_metrics + del nemo_gym_rollout_result + + # NeMo Gym responses can be very large and expensive to log. Here we have logic to opt-in to logging. + if not _should_log_nemo_gym_responses(master_config): + for key in list(rollout_metrics): + if "full_result" in key: + rollout_metrics.pop(key) + + # Use async rollouts if vLLM async engine is enabled + elif _should_use_async_rollouts(master_config): + ( + repeated_batch, + rollout_metrics, + ) = run_async_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"][ + "max_total_sequence_length" + ], + max_rollout_turns=master_config["grpo"][ + "max_rollout_turns" + ], + greedy=False, + ) + else: + repeated_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"][ + "max_total_sequence_length" + ], + max_rollout_turns=master_config["grpo"][ + "max_rollout_turns" + ], + greedy=False, + ) + policy_generation.finish_generation() + # Collect generation logger metrics for performance reporting after each generation step + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = ( + policy_generation.get_logger_metrics() + ) + + metrics_logging_data["mean_gen_tokens_per_sample"] = ( + rollout_metrics["mean_gen_tokens_per_sample"] + ) + logger.log_metrics(rollout_metrics, total_steps + 1, prefix="train") + + repeated_batch = scale_rewards( + repeated_batch, master_config["grpo"]["reward_scaling"] + ) + # Process rewards with custom reward function + if master_config["grpo"]["reward_shaping"]["enabled"]: + repeated_batch = apply_reward_shaping( + repeated_batch, master_config["grpo"]["reward_shaping"] + ) + + # Calculate rewards & advantages + memory_tracker.snapshot_start_of_stage("Processing rewards", dir()) + print("▶ Processing rewards...,", flush=True) + with timer.time("reward_calculation"): + # Extract rewards from final_batch + rewards = repeated_batch["total_reward"] + + print("▶ Computing advantages...", flush=True) + if master_config["grpo"].get("calculate_advantages_on_gpu"): + print("Computing advantages on GPU!") + # Just fix the device id for now + device_id = 0 + baseline, std = calculate_baseline_and_std_per_prompt( + input_ids.cuda(device_id), + rewards.cuda(device_id), + torch.ones_like(rewards).cuda(device_id), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], + ) + baseline = baseline.cpu() + std = std.cpu() + else: + baseline, std = calculate_baseline_and_std_per_prompt( + input_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], + ) + + # Apply dynamic sampling to filter prompts with non-zero std (DAPO algorithm) + repeated_batch, is_batch_complete, batch_cache, ds_metrics = ( + dynamic_sampling( + repeated_batch, + std, + baseline, + dynamic_sampling_num_gen_batches, + master_config, + timer, + batch_cache, + ) + ) + if ds_metrics: + ds_metrics["dynamic_sampling_num_gen_batches"] = ( + dynamic_sampling_num_gen_batches + ) + # Get the updated rewards and baselines. For DAPO, these rewards and baselines only correspond to the prompts with non-zero std. + rewards = ( + repeated_batch["total_reward"] + if not master_config["grpo"]["use_dynamic_sampling"] + else repeated_batch["filtered_reward"] + ) + baseline = repeated_batch["baseline"] + std = repeated_batch["std"] + + # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch. + if not is_batch_complete: + continue + + gen_step_metrics = {} + if hasattr(policy_generation, "get_step_metrics"): + gen_step_metrics = policy_generation.get_step_metrics() + + # Save baseline for logging (before deletion) + baseline_for_log = baseline.clone() + + # Extract prompt-only messages for advantage estimation + prompt_only_message_logs = _extract_prompt_only_messages( + repeated_batch["message_log"] + ) + prompt_batched_flat, _ = batched_message_log_to_flat_message( + prompt_only_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + prompt_ids_for_adv = prompt_batched_flat["token_ids"] + del prompt_only_message_logs + del prompt_batched_flat + del input_ids + del baseline + del std + + with timer.time("data_processing"): + use_overlong_filtering = master_config["grpo"]["overlong_filtering"] + if use_overlong_filtering: + loss_multiplier = repeated_batch["loss_multiplier"].clone() + truncated = repeated_batch["truncated"] + + if isinstance(truncated, list): + truncated = torch.tensor(truncated, dtype=torch.bool) + + loss_multiplier[truncated] = 0 + repeated_batch["loss_multiplier"] = loss_multiplier + # Add loss mask to each message in LLMMessageLogType + for i, message_log in enumerate(repeated_batch["message_log"]): + for j, message in enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + + # Convert updated LLMMessageLogType to FlatMessagesType for training + flat_messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) + + # Create training data from flattened messages + # Note: advantages will be computed and added after logprobs are available + train_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "generation_logprobs": flat_messages["generation_logprobs"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + } + ) + # this will be mini-batched inside the policy, so maintain the packed multimodal structure + # This is also used to populate part of the downstream logprob calculation data + extra_multimodal_data = flat_messages.get_multimodal_dict( + as_tensors=False + ) + train_data.update(extra_multimodal_data) + train_data.to("cpu") + + metrics_logging_data["content"] = flat_messages["content"] + + memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) + print("▶ Preparing for logprob inference...", flush=True) + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + + print("▶ Computing logprobs...", flush=True) + with timer.time("policy_and_reference_logprobs"): + # Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers. + logprob_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": train_data["input_ids"], + "input_lengths": train_data["input_lengths"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + **extra_multimodal_data, + } + ) + train_data["prev_logprobs"] = policy.get_logprobs( + logprob_data, timer=timer + )["logprobs"] + + if not master_config["grpo"].get( + "skip_reference_policy_logprobs_calculation" + ): + train_data["reference_policy_logprobs"] = ( + policy.get_reference_policy_logprobs( + logprob_data, + timer=timer, + )["reference_logprobs"] + ) + + del logprob_data + del extra_multimodal_data + + ( + max_seq_mult_prob_error, + num_masked_seqs, + masked_correct_pct, + ) = compute_and_apply_seq_logprob_error_masking( + train_data=train_data, + rewards=rewards, + seq_logprob_error_threshold=master_config["grpo"][ + "seq_logprob_error_threshold" + ], + ) + + # Compute advantages with adv_estimator using correct mask and logprobs + with timer.time("advantage_calculation"): + print("▶ Computing advantages...", flush=True) + # Get token-level mask: token_mask * sample_mask + token_mask = train_data["token_mask"] + sample_mask = train_data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + + train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, + mask=mask, + repeated_batch=repeated_batch, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) + del prompt_ids_for_adv + + # Log rewards and advantages information + _log_mixed_rewards_and_advantages_information( + logger=logger, + total_steps=total_steps, + metrics=metrics, + baseline=baseline_for_log, + advantages=train_data["advantages"], + ) + del baseline_for_log + + memory_tracker.snapshot_start_of_stage("Policy train", dir()) + print("▶ Preparing for training...", flush=True) + with timer.time("training_prep"): + policy.prepare_for_training() # set model train and reload optim to GPU + POLICY_GENERATION_STALE = True + + print("▶ Training policy...", flush=True) + with timer.time("policy_training"): + train_results = policy.train( + train_data, + loss_fn, + timer=timer, + ) + + # Recompute KV scales after policy training if needed + if sync_kv_scales: + with timer.time("recompute_kv_scales"): + print( + "▶ Recomputing KV cache scales after policy update...", + flush=True, + ) + kv_scales_cache = policy.calibrate_qkv_fp8_scales( + train_data, include_q=True + )["layers"] + # Set generation as stale to force refit with new scales + POLICY_GENERATION_STALE = True + + is_last_step = total_steps + 1 >= max_num_steps + if not master_config["data"]["use_multiple_dataloader"]: + is_last_step = is_last_step or ( + (current_epoch + 1 == max_num_epochs) + and (current_step + 1 == len(wrapped_dataloader)) + ) + + # Run validation if it's a validation step or last step with val_at_end + if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( + val_at_end and is_last_step + ): + memory_tracker.snapshot_start_of_stage("Validation", dir()) + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, + policy_generation, + colocated_inference, + kv_scales=kv_scales_cache if sync_kv_scales else None, + ) + POLICY_GENERATION_STALE = False + else: + if colocated_inference: + policy.offload_after_refit() # unload optimizer to make space for generation + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=total_steps + 1, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + # Get flat advantages and token mask for masked metrics computation + flat_advantages = train_data["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + + memory_tracker.snapshot_start_of_stage("Metrics", dir()) + metrics = { + **metrics, + "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + "reward": rewards.numpy(), + "mean_prompt_length": repeated_batch["length"].numpy(), + "total_num_tokens": input_lengths.numpy(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + **ds_metrics, + } + if "moe_metrics" in train_results: + metrics.update( + {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()} + ) + if master_config["grpo"]["use_dynamic_sampling"]: + metrics["filtered_reward"] = rewards.numpy() + metrics["reward"] = repeated_batch["total_reward"].numpy() + + metrics.update(train_results["all_mb_metrics"]) + metrics.update(gen_step_metrics) + for k, v in metrics.items(): + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + valid_values = [x for x in v if not np.isinf(x)] + metrics[k] = ( + np.min(valid_values).item() if valid_values else -1.0 + ) + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + valid_values = [x for x in v if not np.isinf(x)] + metrics[k] = ( + np.max(valid_values).item() if valid_values else -1.0 + ) + elif k in { + "lr", + "wd", + "reward", + "filtered_reward", + "global_valid_seqs", + "global_valid_toks", + "mean_prompt_length", + }: + metrics[k] = np.mean(v).item() + elif isinstance(v, (np.ndarray, list)): + metrics[k] = np.sum(v).item() + else: + print(f"Skipping aggregation for {k} ({type(v)})") + + metrics.update(rollout_metrics) + metrics["generation_logger_metrics"] = generation_logger_metrics + total_valid_tokens += metrics["global_valid_toks"] + + # Always log sequence-level error metrics (useful for deciding threshold) + metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error + metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs + metrics["masked_correct_pct"] = masked_correct_pct + + ## Checkpointing + consumed_samples += master_config["grpo"]["num_prompts_per_step"] + timeout.mark_iteration() + + should_save_by_step = ( + is_last_step + or (total_steps + 1) % master_config["checkpointing"]["save_period"] + == 0 + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + memory_tracker.snapshot_start_of_stage("Checkpointing", dir()) + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + policy.prepare_for_training() + + # +1 because step is 0-indexed + grpo_save_state["current_step"] = current_step + 1 + grpo_save_state["total_steps"] = total_steps + 1 + grpo_save_state["current_epoch"] = current_epoch + grpo_save_state["total_valid_tokens"] = total_valid_tokens + if val_metrics is not None: + grpo_save_state["val_reward"] = val_metrics["accuracy"] + elif "val_reward" in grpo_save_state: + del grpo_save_state["val_reward"] + grpo_save_state["consumed_samples"] = consumed_samples + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_reward --> 'val:reward'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in grpo_save_state: + del grpo_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + grpo_save_state[full_metric_name] = metrics_source[ + metric_name + ] + + with timer.time("checkpointing"): + print( + f"Saving checkpoint for step {total_steps + 1}...", + flush=True, + ) + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, grpo_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ) + if checkpointer.save_optimizer + else None, + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + if master_config["data"]["use_multiple_dataloader"]: + for ( + task_name, + task_dataloader, + ) in wrapped_dataloader.dataloaders.items(): + torch.save( + task_dataloader.state_dict(), + os.path.join( + checkpoint_path, + f"train_dataloader_{task_name}.pt", + ), + ) + else: + torch.save( + wrapped_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # Logging + # Log training data + memory_tracker.snapshot_start_of_stage("Logging", dir()) + if not _should_log_nemo_gym_responses(master_config): + log_data = {} + if "agent_ref" in repeated_batch: + log_data["agent_ref"] = repeated_batch["agent_ref"] + log_data["content"] = flat_messages["content"] + log_data["rewards"] = rewards.tolist() + if master_config["grpo"]["use_dynamic_sampling"]: + log_data["filtered_rewards"] = rewards.tolist() + log_data["rewards"] = repeated_batch["total_reward"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + log_data["token_ids"] = train_data["input_ids"].tolist() + log_data["token_loss_mask"] = train_data["token_mask"].tolist() + log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() + log_data["advantages"] = train_data["advantages"].tolist() + log_data["generation_logprobs"] = train_data[ + "generation_logprobs" + ].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps + 1}.jsonl" + ) + del log_data + del flat_messages + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + # track example with high token mult prob error above 1.05 + if metrics["token_mult_prob_error"] > 1.05: + logger.log_plot_token_mult_prob_error( + { + "prompt_lengths": repeated_batch["length"], + "full_lengths": input_lengths, + "generation_logprobs": train_data["generation_logprobs"], + "prev_logprobs": train_data["prev_logprobs"], + "token_mask": train_data["token_mask"], + "sample_mask": train_data["sample_mask"], + }, + total_steps + 1, + name="train/token_mult_prob_error_plot_sample", + ) + del train_data + if master_config["policy"]["generation"].get("vllm_cfg", {}).get( + "enable_vllm_metrics_logger", False + ) and master_config.get("logger", {}).get("wandb_enabled", False): + log_generation_metrics_to_wandb( + generation_logger_metrics, + total_steps + 1, + master_config["policy"]["generation"]["vllm_cfg"][ + "vllm_metrics_logger_interval" + ], + logger, + ) + + # Plot ISL/OSL/ISL+OSL histograms to wandb + if ( + master_config["policy"]["generation"] + .get("vllm_cfg", {}) + .get("async_engine", False) + ): + for metric_name in metrics.keys(): + if metric_name.startswith("histogram/"): + logger.log_histogram( + metrics[metric_name], + total_steps + 1, + f"generation_metrics/{metric_name}", + ) + + print("\n📊 Training Results:") + + print(f" • Loss: {metrics['loss']:.4f}") + print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") + if master_config["grpo"]["use_dynamic_sampling"]: + print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}") + print( + f" • Avg Total Reward: {np.mean(repeated_batch['total_reward'].numpy()):.4f}" + ) + else: + print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") + print( + f" • Mean Generation Length: {metrics_logging_data['mean_gen_tokens_per_sample']:.4f}", + flush=True, + ) + + print("\n⏱️ Timing:", flush=True) + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + + number_of_samples_per_step = ( + master_config["grpo"]["num_prompts_per_step"] + * master_config["grpo"]["num_generations_per_prompt"] + ) + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + + print(f" • Total step time: {total_time:.2f}s", flush=True) + + # Display all other timing metrics + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) + + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) + performance_metrics = print_performance_metrics( + train_results, metrics, timing_metrics, master_config + ) + + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics( + performance_metrics, total_steps + 1, prefix="performance" + ) + # step_finished=True here since this is the final log of our current step. + logger.log_metrics( + timing_metrics, + total_steps + 1, + prefix="timing/train", + step_finished=True, + ) + + # Reset the batch and set dynamic_sampling_num_gen_batches to 0 + batch_cache = None + dynamic_sampling_num_gen_batches = 0 + + # Clear mem + memory_tracker.snapshot_start_of_stage("After CPU memory clear", dir()) + + # processing rewards + del repeated_batch + del rewards + # train_data already deleted after logging above + # logging + del metrics + if "val_metrics" in dir(): + del val_metrics + + timer.reset() + current_step += 1 + total_steps += 1 + if should_save_by_timeout: + memory_tracker.snapshot_start_of_stage("", dir()) + print("Timeout has been reached, stopping training early", flush=True) + return + if total_steps >= max_num_steps: + memory_tracker.snapshot_start_of_stage("", dir()) + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return + + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch + + +def validate( + policy_generation: GenerationInterface, + val_dataloader: Optional[StatefulDataLoader], + tokenizer, + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + step: int, + master_config: MasterConfig, + logger: Optional[Logger] = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run validation on the validation dataset.""" + if val_dataloader is None: + assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, ( + "val_dataloader is None, so dpo.val_period must be 0" + ) + print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) + return {}, {} + + timer = Timer() + with timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...", flush=True) + + total_rewards = [] + total_lengths = [] + all_message_logs = [] # Collect all message logs + + max_batches = ( + master_config["grpo"]["max_val_samples"] + // master_config["grpo"]["val_batch_size"] + ) + for batch_idx, val_batch in enumerate(val_dataloader): + if batch_idx >= max_batches: + break + + additional_metrics_to_report = dict() + # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) + # Use async rollouts if vLLM async engine is enabled + # We cascade NeMo-Gym first since NeMo-Gym also uses async rollouts. + if _should_use_nemo_gym(master_config): + generation_config = master_config["policy"]["generation"] + nemo_gym_rollout_result = run_async_nemo_gym_rollout( + policy_generation=policy_generation, + input_batch=val_batch, + tokenizer=tokenizer, + task_to_env=val_task_to_env, + max_seq_len=None, + generation_config=generation_config, + max_rollout_turns=None, + greedy=False, + ) + val_batch = nemo_gym_rollout_result.final_batch + gen_metrics = nemo_gym_rollout_result.rollout_metrics + additional_metrics_to_report = gen_metrics + elif _should_use_async_rollouts(master_config): + val_batch, gen_metrics = run_async_multi_turn_rollout( + policy_generation, + val_batch, + tokenizer, + val_task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["grpo"]["max_rollout_turns"], + greedy=False, + ) + else: + val_batch, gen_metrics = run_multi_turn_rollout( + policy_generation, + val_batch, + tokenizer, + val_task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["grpo"]["max_rollout_turns"], + greedy=False, + ) + + total_rewards.extend(val_batch["total_reward"].tolist()) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + + # Collect message logs for later display + to_env = [ + get_keys_from_message_log( + val_batch["message_log"][i], ["role", "content"] + ) + for i in range(len(val_batch["message_log"])) + ] + + all_message_logs.extend(to_env) + + # Calculate validation metrics + num_samples = len(total_rewards) + if num_samples > 0: + rewards_t = torch.tensor(total_rewards, dtype=torch.float32) + accuracy = rewards_t.mean().item() + else: + accuracy = 0.0 + + avg_length = ( + sum(total_lengths) / len(total_lengths) if len(total_lengths) > 0 else 0.0 + ) + + val_metrics = { + "accuracy": accuracy, + "avg_length": avg_length, + **additional_metrics_to_report, + } + + # Print sample conversations only once at the end of validation + try: + print_message_log_samples( + all_message_logs, + total_rewards, + num_samples=min( + master_config["logger"]["num_val_samples_to_print"], + len(all_message_logs), + ), + step=step, + ) + except Exception as e: + print(f"\n ⚠️ Error displaying message samples: {str(e)}") + print(" ⚠️ Continuing validation without displaying samples...", flush=True) + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + # Print summary of validation results + print("\n📊 Validation Results:") + print(f" • Accuracy: {accuracy:.4f}") + print(f" • Average response length: {avg_length:.1f} tokens") + print(f" • Samples processed: {len(total_rewards)}", flush=True) + + # Print timing information + print("\n ⏱️ Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s", flush=True) + + # Log validation data to JSONL file + if logger is not None: + val_log_data = { + "content": all_message_logs, + "rewards": total_rewards, + } + logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl") + + # Make sure to reset the timer after validation + timer.reset() + + # Explicit GPU memory cleanup after validation + gc.collect() + torch.cuda.empty_cache() + + return val_metrics, timing_metrics + + +def async_grpo_train( + policy: ColocatablePolicyInterface, + policy_generation: Optional[GenerationInterface], + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: LossFunction, + task_to_env: dict[str, EnvironmentInterface], + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + logger: Logger, + checkpointer: CheckpointManager, + grpo_save_state: GRPOSaveState, + master_config: MasterConfig, + max_trajectory_age_steps: int = 1, +) -> None: + """Run asynchronous GRPO training with replay buffer. + + Args: + policy: Training policy + policy_generation: Generation interface + dataloader: Training data loader + val_dataloader: Validation data loader + tokenizer: Tokenizer + loss_fn: Loss function + task_to_env: Training environments + val_task_to_env: Validation environments + logger: Logger + checkpointer: Checkpoint manager + grpo_save_state: Training state + master_config: Master configuration + max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training + """ + # Ensure we are running with a compatible async generation backend + assert _should_use_async_rollouts(master_config), ( + "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. " + "Set policy.generation.vllm_cfg.async_engine to true in your config." + ) + assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, ( + "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!" + ) + + if master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"] > 1: + if not master_config["grpo"]["async_grpo"].get( + "in_flight_weight_updates", False + ): + print( + "⚠️ WARNING: In-flight weight updates must be enabled for async GRPO with max_trajectory_age_steps > 1. " + "Without in-flight weight updates, having more max_trajectory_age_steps will not give any performance benefit." + ) + + # Import async utilities only when needed + from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer + + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + NEED_REFIT = True + + # Setup generation interface + if policy_generation is None: + policy_generation = policy + NEED_REFIT = False + POLICY_GENERATION_STALE = True + assert policy_generation is not None + + # Training state + step = grpo_save_state["current_step"] + weight_version = step # Tracks refitted weight versions + consumed_samples = grpo_save_state["consumed_samples"] + total_valid_tokens = grpo_save_state.get( + "total_valid_tokens", 0 + ) # Default to 0 for backward compatibility with older checkpoints + val_period = master_config["grpo"]["val_period"] + val_at_start = master_config["grpo"]["val_at_start"] + val_at_end = master_config["grpo"]["val_at_end"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + + # Initialize advantage estimator + adv_estimator = _create_advantage_estimator(master_config) + + assert not colocated_inference, ( + "Colocated inference is not supported for async GRPO. Please use non-colocated inference." + ) + + # Calculate minimum buffer size from training requirements + # In per-prompt buffer mode, one buffer entry is 1 prompt * num_generations_per_prompt + num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] + samples_per_prompt_group = master_config["grpo"]["num_generations_per_prompt"] + train_gbs = master_config["policy"]["train_global_batch_size"] + + # Ensure the buffer has at least one step worth of prompt-groups before training + min_trajectories_needed = num_prompts_per_step + + print("📊 Buffer requirements calculation:") + print(f" - num_prompts_per_step: {num_prompts_per_step}") + print(f" - num_generations_per_prompt: {samples_per_prompt_group}") + print(f" - samples_per_prompt_group: {samples_per_prompt_group}") + print(f" - train_global_batch_size: {train_gbs}") + print(f" - min_trajectories_needed: {min_trajectories_needed} (async mode)") + + _replay_py_exec = get_actor_python_env( + "nemo_rl.algorithms.async_utils.ReplayBuffer" + ) + if _replay_py_exec.startswith("uv"): + # Lazily build a dedicated venv across all Ray nodes on-demand. + _replay_py_exec = create_local_venv_on_each_node( + _replay_py_exec, + "nemo_rl.algorithms.async_utils.ReplayBuffer", + ) + + _replay_runtime_env = { + "py_executable": _replay_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": _replay_py_exec, + "UV_PROJECT_ENVIRONMENT": _replay_py_exec, + }, + } + + # Calculate optimal buffer size based on generation limits to prevent length bias + # Each weight version generates exactly num_prompts_per_step trajectories + # With max_age_steps, we keep trajectories from multiple weight versions + num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] + late_arrival_slack = 2 + optimal_buffer_size = ( + num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack + ) + + replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote( + max_size=optimal_buffer_size + ) + + _tc_py_exec = get_actor_python_env( + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector" + ) + if _tc_py_exec.startswith("uv"): + _tc_py_exec = create_local_venv_on_each_node( + _tc_py_exec, + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector", + ) + + _tc_runtime_env = { + "py_executable": _tc_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": _tc_py_exec, + "UV_PROJECT_ENVIRONMENT": _tc_py_exec, + }, + } + + # Initialize trajectory collector with synchronized collection + trajectory_collector = AsyncTrajectoryCollector.options( + runtime_env=_tc_runtime_env + ).remote( + policy_generation=policy_generation, + tokenizer=tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=replay_buffer, + start_step=step, + ) + + # Start trajectory collection in background + collection_task = trajectory_collector.start_collection.remote(dataloader) + + # Ensure collector knows initial weight version + trajectory_collector.set_weight_version.remote(weight_version) + + print("📦 Started continuous background trajectory collection") + + print( + f"🚀 Starting async GRPO training with buffer_size={optimal_buffer_size}, max_age={max_trajectory_age_steps} steps" + ) + + print("⏳ Preparing policy generation for training...") + if NEED_REFIT and POLICY_GENERATION_STALE: + print("🔄 Refitting policy generation with actual model weights...") + try: + refit_policy_generation(policy, policy_generation, colocated_inference) + print("✅ Policy generation refit completed successfully") + POLICY_GENERATION_STALE = False + except Exception as e: + print(f"❌ Policy generation refit failed: {e}") + import traceback + + traceback.print_exc() + return + else: + print("🔄 Preparing policy generation for inference...") + try: + policy_generation.prepare_for_generation() + print("✅ Policy generation preparation completed successfully") + except Exception as e: + print(f"❌ Policy generation preparation failed: {e}") + import traceback + + traceback.print_exc() + return + + print("✅ Policy generation setup complete, proceeding to validation...") + + # Run validation at start if configured + if val_at_start and step == 0: + print("\n🔍 Running initial validation...") + # Pause trajectory collection during initial validation + trajectory_collector.pause.remote() + + try: + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=0, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + logger.log_metrics(val_metrics, step, prefix="validation") + logger.log_metrics(validation_timings, step, prefix="timing/validation") + print("✅ Initial validation completed successfully") + except Exception as e: + print(f"❌ Initial validation failed: {e}") + import traceback + + traceback.print_exc() + # Continue anyway since validation is optional + finally: + # Resume trajectory collection after initial validation + trajectory_collector.resume.remote() + + print("✅ All setup complete, starting buffer wait...") + # Clear logger metrics at start of training + if policy_generation is not None: + policy_generation.clear_logger_metrics() + + # Wait for initial buffer fill + print( + f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed} trajectories)..." + ) + wait_iterations = 0 + while True: + buffer_size_current = ray.get(replay_buffer.size.remote()) + + print( + f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}" + ) + + if buffer_size_current >= min_trajectories_needed: + break + + time.sleep(1.0) + + print("✅ Buffer ready! Starting training loop...") + + # Main training loop + try: + while step < master_config["grpo"]["max_num_steps"]: + print( + f"\n{'=' * 25} Step {step + 1}/{master_config['grpo']['max_num_steps']} {'=' * 25}" + ) + maybe_gpu_profile_step(policy, step + 1) + if policy != policy_generation: + maybe_gpu_profile_step(policy_generation, step + 1) + + with timer.time("total_step_time"): + # Sample trajectories from replay buffer + print("📦 Sampling from replay buffer...") + with timer.time("exposed_generation"): + buffer_size_current = ray.get(replay_buffer.size.remote()) + print( + f"📊 Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}" + ) + + # Sample the required number of per-prompt groups. + num_prompt_groups_needed = master_config["grpo"][ + "num_prompts_per_step" + ] + sample_result = ray.get( + replay_buffer.sample.remote( + num_prompt_groups=num_prompt_groups_needed, + current_weight_version=weight_version, + max_age_steps=max_trajectory_age_steps, + ) + ) + + if ( + sample_result is None + or len(sample_result["trajectories"]) + != num_prompt_groups_needed + ): + print( + "⏳ Buffer empty or not enough groups to form a full step, waiting..." + ) + + # Get buffer debug info to help diagnose the issue + buffer_debug = ray.get(replay_buffer.get_debug_info.remote()) + buffer_size = buffer_debug["total_trajectories"] + + if buffer_size > 0: + print( + f"🔍 Debug: Buffer has {buffer_size} trajectories but sampling requires exactly {num_prompt_groups_needed}." + ) + print(f" Current weight version: {weight_version}") + print(f" Max trajectory age: {max_trajectory_age_steps}") + print( + f" Trajectory versions in buffer: {buffer_debug['trajectory_versions']}" + ) + + time.sleep(0.5) + continue + + # Extract trajectories and metadata from sample result + trajectories = sample_result["trajectories"] + avg_trajectory_age = sample_result["avg_trajectory_age"] + + print( + f"✅ Sampled {len(trajectories)} trajectory groups from buffer (avg age: {avg_trajectory_age:.2f} steps)" + ) + + # Concatenate per-prompt groups into a single training batch + per_prompt_batches = [t["batch"] for t in trajectories] + repeated_batch = BatchedDataDict.from_batches(per_prompt_batches) + # Aggregate rollout metrics across groups (simple mean where applicable) + rollout_metrics = {} + for t in trajectories: + for k, v in t["rollout_metrics"].items(): + rollout_metrics.setdefault(k, []).append(v) + # TODO: this simple averaging might cause misleading information for such data as max_gen_tokens, etc. + rollout_metrics = { + k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v) + for k, v in rollout_metrics.items() + } + + # Enforce fixed training batch: num_prompts_per_step * num_generations_per_prompt + expected_batch_size = ( + master_config["grpo"]["num_prompts_per_step"] + * master_config["grpo"]["num_generations_per_prompt"] + ) + if repeated_batch.size != expected_batch_size: + print( + f"❌ Unexpected training batch size: got {repeated_batch.size}, expected {expected_batch_size}. Skipping step and waiting for correct buffer content." + ) + time.sleep(0.5) + continue + + # Optional sanity: ensure DP divisibility to avoid sharding issues + dp_size = policy.sharding_annotations.get_axis_size("data_parallel") + if expected_batch_size % dp_size != 0: + raise AssertionError( + f"Configuration error: (num_prompts_per_step * num_generations_per_prompt) = {expected_batch_size} must be divisible by data_parallel size {dp_size}." + ) + + print(f"Got trajectory batch (size: {repeated_batch.size})") + + print("▶ Processing rewards...") + with timer.time("reward_calculation"): + # Extract prompt-only messages for advantage estimation + prompt_only_message_logs = _extract_prompt_only_messages( + repeated_batch["message_log"] + ) + prompt_batched_flat, _ = batched_message_log_to_flat_message( + prompt_only_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + prompt_ids_for_adv = prompt_batched_flat["token_ids"] + del prompt_only_message_logs + del prompt_batched_flat + + rewards = repeated_batch["total_reward"] + + print( + f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" + ) + + # Prepare training data (same as sync version) + with timer.time("data_processing"): + # Add loss mask to each message + for i, message_log in enumerate(repeated_batch["message_log"]): + for j, message in enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + + # Convert to flat format for training + flat_messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) + + # Create training data + # Note: advantages will be computed and added after logprobs are available + train_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "generation_logprobs": flat_messages["generation_logprobs"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + } + ) + train_data.to("cpu") + + # Training phase (same as sync version) + print("▶ Preparing for logprob inference...") + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + + print("▶ Computing logprobs...") + with timer.time("policy_and_reference_logprobs"): + fprop_logprobs = policy.get_logprobs( + train_data, + timer=timer, + )["logprobs"] + reference_logprobs = policy.get_reference_policy_logprobs( + train_data, + timer=timer, + )["reference_logprobs"] + train_data["prev_logprobs"] = fprop_logprobs + train_data["reference_policy_logprobs"] = reference_logprobs + + ( + max_seq_mult_prob_error, + num_masked_seqs, + masked_correct_pct, + ) = compute_and_apply_seq_logprob_error_masking( + train_data=train_data, + rewards=rewards, + seq_logprob_error_threshold=master_config["grpo"][ + "seq_logprob_error_threshold" + ], + ) + + # Compute advantages with adv_estimator using correct mask and logprobs + with timer.time("advantage_calculation"): + print("▶ Computing advantages...", flush=True) + # Get token-level mask: token_mask * sample_mask + token_mask = train_data["token_mask"] + sample_mask = train_data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + + train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, + mask=mask, + repeated_batch=repeated_batch, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) + del prompt_ids_for_adv + + # Log advantages stats + # Note: For GRPOAdvantageEstimator with normalize_rewards=True, these are + # already normalized advantages (equivalent to "Normalized advantages stats" + # in older versions). For ReinforcePlusPlusAdvantageEstimator, advantages + # are globally normalized across valid tokens. + advantages = train_data["advantages"] + print( + f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" + ) + + print("▶ Preparing for training...") + with timer.time("training_prep"): + policy.prepare_for_training() + POLICY_GENERATION_STALE = True + + print("▶ Training policy...") + with timer.time("policy_training"): + train_results = policy.train( + train_data, + loss_fn, + timer=timer, + ) + + print("🔄 Synchronizing policy weights to trajectory collector…") + generation_logger_metrics = None + if NEED_REFIT: + # Measure pending-generation wait as exposed_generation time + print("🔄 Coordinating with trajectory collector before refit...") + with timer.time("exposed_generation"): + ray.get(trajectory_collector.prepare_for_refit.remote()) + + # Collect generation logger metrics for performance reporting + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = ( + policy_generation.get_logger_metrics() + ) + + # Only the actual refit/weight transfer should be counted as weight_sync + print("🔄 Performing policy generation refit...") + with timer.time("weight_sync"): + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + + # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version + weight_version += 1 + trajectory_collector.set_weight_version.remote(weight_version) + trajectory_collector.resume_after_refit.remote() + + # Clear logger metrics after each refit (weight sync), starting a new logging cycle + if policy_generation is not None: + policy_generation.clear_logger_metrics() + + # Validation + val_metrics, validation_timings = None, None + is_last_step = step + 1 == master_config["grpo"]["max_num_steps"] + + # Run validation if it's a validation step or last step with val_at_end + if (val_period > 0 and (step + 1) % val_period == 0) or ( + val_at_end and is_last_step + ): + # Pause trajectory collection during validation to reduce memory pressure + trajectory_collector.pause.remote() + + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=step + 1, + master_config=master_config, + logger=logger, + ) + policy_generation.finish_generation() + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" + ) + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + # Explicit GPU memory cleanup after validation in async mode + import gc + + gc.collect() + torch.cuda.empty_cache() + + # Resume trajectory collection after validation + trajectory_collector.resume.remote() + # Get flat advantages and token mask for masked metrics computation + flat_advantages = train_data["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + # Save content for logging before deleting flat_messages + flat_messages_content = flat_messages.get("content", []) + del flat_messages + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + + metrics = { + "loss": train_results["loss"].numpy(), + "reward": rewards.numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + "mean_prompt_length": repeated_batch["length"].numpy(), + "total_num_tokens": input_lengths.numpy(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + } + if "moe_metrics" in train_results: + metrics.update( + {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()} + ) + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + valid_values = [x for x in v if not np.isinf(x)] + metrics[k] = ( + np.min(valid_values).item() if valid_values else -1.0 + ) + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + valid_values = [x for x in v if not np.isinf(x)] + metrics[k] = ( + np.max(valid_values).item() if valid_values else -1.0 + ) + elif k in { + "lr", + "wd", + "reward", + "global_valid_seqs", + "global_valid_toks", + "mean_prompt_length", + }: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + metrics.update(rollout_metrics) + if generation_logger_metrics is not None: + metrics["generation_logger_metrics"] = generation_logger_metrics + total_valid_tokens += metrics["global_valid_toks"] + + # Always log sequence-level error metrics (useful for deciding threshold) + metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error + metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs + metrics["masked_correct_pct"] = masked_correct_pct + + # Checkpointing (same as sync version) + consumed_samples += master_config["grpo"]["num_prompts_per_step"] + timeout.mark_iteration() + + should_save_by_step = ( + is_last_step + or (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + grpo_save_state["current_step"] = step + 1 + grpo_save_state["total_valid_tokens"] = total_valid_tokens + if val_metrics is not None: + grpo_save_state["val_reward"] = val_metrics["accuracy"] + elif "val_reward" in grpo_save_state: + del grpo_save_state["val_reward"] + grpo_save_state["consumed_samples"] = consumed_samples + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary.' + f" If you are using an old config, please updated checkpointing.metric_name to the new format, " + f" e.g. 'val_reward --> 'val:accuracy'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in grpo_save_state: + del grpo_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + grpo_save_state[full_metric_name] = metrics_source[ + metric_name + ] + + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, grpo_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ) + if checkpointer.save_optimizer + else None, + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + # Get dataloader state from trajectory collector + actual_dataloader_state = ray.get( + trajectory_collector.get_dataloader_state.remote() + ) + torch.save( + actual_dataloader_state, + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # Logging + # Log training data (match sync GRPO logging payload for parity) + log_data = {} + if "agent_ref" in repeated_batch: + log_data["agent_ref"] = repeated_batch["agent_ref"] + log_data["content"] = flat_messages_content + log_data["rewards"] = rewards.tolist() + if master_config["grpo"]["use_dynamic_sampling"]: + # In dynamic sampling, `rewards` corresponds to filtered rewards + log_data["filtered_rewards"] = rewards.tolist() + log_data["rewards"] = repeated_batch["total_reward"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + log_data["token_ids"] = train_data["input_ids"].tolist() + log_data["token_loss_mask"] = train_data["token_mask"].tolist() + log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() + log_data["advantages"] = train_data["advantages"].tolist() + log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{step + 1}.jsonl" + ) + del train_data + del flat_messages_content + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) + + # Add buffer stats + buffer_size_current = ray.get(replay_buffer.size.remote()) + metrics["buffer_size"] = buffer_size_current + metrics["avg_trajectory_age"] = avg_trajectory_age + + if master_config["policy"]["generation"].get("vllm_cfg", {}).get( + "enable_vllm_metrics_logger", False + ) and master_config.get("logger", {}).get("wandb_enabled", False): + log_generation_metrics_to_wandb( + generation_logger_metrics, + step + 1, + master_config["policy"]["generation"]["vllm_cfg"][ + "vllm_metrics_logger_interval" + ], + logger, + ) + + # Plot ISL/OSL/ISL+OSL histograms to wandb + if ( + master_config["policy"]["generation"] + .get("vllm_cfg", {}) + .get("async_engine", False) + ): + for metric_name in metrics.keys(): + if metric_name.startswith("histogram/"): + logger.log_histogram( + metrics[metric_name], + step + 1, + f"generation_metrics/{metric_name}", + ) + + print("\n📊 Training Results:") + print(f" • Loss: {metrics['loss']:.4f}") + print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") + print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") + print(f" • Buffer Size: {buffer_size_current}") + print(f" • Avg Trajectory Age: {avg_trajectory_age:.2f} steps") + + print("\n⏱️ Timing:") + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) + performance_metrics = print_performance_metrics( + train_results, metrics, timing_metrics, master_config + ) + + logger.log_metrics(performance_metrics, step + 1, prefix="performance") + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + + timer.reset() + step += 1 + if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) + return + if step >= master_config["grpo"]["max_num_steps"]: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return + + except Exception as e: + print(f"❌ Error in async loop: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + print("🛑 Stopping trajectory collection...") + try: + ray.kill(trajectory_collector) + except Exception as e: + print(f"Error stopping trajectory collector: {e}") + + try: + ray.kill(replay_buffer) + except Exception as e: + print(f"Error stopping replay buffer: {e}") + + print("Async GRPO training complete!") diff --git a/shim/nemo_rl/environments/__init__.py b/shim/nemo_rl/environments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/shim/nemo_rl/environments/erdos_discovery_environment.py b/shim/nemo_rl/environments/erdos_discovery_environment.py new file mode 100644 index 0000000000..70c4710fca --- /dev/null +++ b/shim/nemo_rl/environments/erdos_discovery_environment.py @@ -0,0 +1,362 @@ +"""Erdős Discovery Environment for NeMo RL. + +Implements EnvironmentInterface for TTT-Discover with the Erdős Minimum +Overlap Problem. Calls the NeMo Gym resource server for code execution +and reward computation. + +Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) + +The environment: + 1. Receives LLM-generated code from the GRPO rollout + 2. Sends it to the Erdős Gym resource server for sandboxed execution + scoring + 3. Returns reward = 1/bound (or 0 on failure) + 4. Tracks best constructions and buffer statistics via metrics +""" + +import logging +import math +from typing import Any, Optional + +import aiohttp +import ray +import torch + +from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn + +logger = logging.getLogger(__name__) + + +# ═══════════════════════════════════════════════════════════════════ +# Inline reward computation (no Gym server needed for debug/testing) +# ═══════════════════════════════════════════════════════════════════ + +def _inline_compute_reward(response_text: str, timeout: int = 60) -> dict: + """Compute reward directly in-process. No HTTP call needed.""" + import re + import signal + import builtins + import math as _math + import itertools as _itertools + import functools as _functools + import collections as _collections + + import numpy as _np + from numpy.fft import rfft, irfft + + _ALLOWED_MODULES = frozenset({ + "numpy", "np", "math", "cmath", "random", + "itertools", "functools", "collections", "fractions", "decimal", + }) + _SAFE_BUILTIN_NAMES = [ + "abs", "all", "any", "bool", "dict", "divmod", "enumerate", + "filter", "float", "format", "int", "isinstance", "issubclass", + "iter", "len", "list", "map", "max", "min", "next", "object", + "print", "range", "repr", "reversed", "round", "set", "slice", + "sorted", "str", "sum", "tuple", "type", "zip", + "Exception", "ValueError", "TypeError", "KeyError", "IndexError", + "StopIteration", "RuntimeError", "NotImplementedError", + "OverflowError", "ZeroDivisionError", "AttributeError", + ] + + # Extract code + code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) + blocks = code_re.findall(response_text) + code = blocks[-1].strip() if blocks else response_text.strip() + + # Build sandbox + import random as _random + safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES if hasattr(builtins, k)} + def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): + if name.split(".")[0] not in _ALLOWED_MODULES: + raise ImportError(f"Module '{name}' not allowed") + return builtins.__import__(name, globals, locals, fromlist, level) + safe_builtins["__import__"] = _safe_import + namespace = { + "__builtins__": safe_builtins, + "np": _np, "numpy": _np, "math": _math, "random": _random, + "itertools": _itertools, "functools": _functools, "collections": _collections, + } + + try: + class _Timeout(Exception): + pass + def _handler(s, f): + raise _Timeout() + old = signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout) + try: + exec(compile(code, "", "exec"), namespace) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old) + + if "f" not in namespace: + return {"reward": 0.0, "bound": None, "error_msg": "no variable f"} + + f = _np.asarray(namespace["f"], dtype=float).flatten() + + # Validate + if len(f) < 1 or len(f) > 1000: + return {"reward": 0.0, "bound": None, "error_msg": "bad length"} + if _np.any(~_np.isfinite(f)) or _np.any(f < 0) or _np.any(f > 1): + return {"reward": 0.0, "bound": None, "error_msg": "bad values"} + if abs(float(_np.mean(f)) - 0.5) > 1e-3: + return {"reward": 0.0, "bound": None, "error_msg": "bad mean"} + + # Compute bound + n = len(f) + F = rfft(f, n=2*n) + autocorr = irfft(F * _np.conj(F), n=2*n) + bound = float(2 * n * _np.max(autocorr.real) / (_np.sum(f)**2)) + + if bound <= 0 or not _math.isfinite(bound): + return {"reward": 0.0, "bound": None, "error_msg": "bad bound"} + return {"reward": 1.0 / bound, "bound": bound, "error_msg": ""} + + except Exception as e: + return {"reward": 0.0, "bound": None, "error_msg": str(e)[:200]} + +# Type alias matching NeMo RL's convention +LLMMessageLogType = list[dict[str, Any]] +ErdosMetadata = dict[str, Any] + + +@ray.remote(max_restarts=-1, max_task_retries=-1) +class ErdosDiscoveryEnvironment(EnvironmentInterface[ErdosMetadata]): + """Erdős Minimum Overlap Problem environment for GRPO training. + + Communicates with the NeMo Gym Erdős resource server via HTTP for: + - /verify: code execution + reward computation + - /select_state: PUCT state selection for prompts + - /seed_session: buffer initialization + - /compute_entropic_advantages: LOO entropic advantages + - /update_buffer: add new discoveries to PUCT tree + + Config (under env.erdos_discovery): + resource_server_url: Base URL of the Erdős Gym resource server. + seed: Random seed for PUCT buffer initialization. + num_initial_states: States to seed the buffer with. + sandbox_timeout: Code execution timeout in seconds. + """ + + def __init__(self, config: dict): + self.config = config + self.resource_server_url = config.get( + "resource_server_url", "http://localhost:8080" + ) + self.seed = config.get("seed", None) + self.num_initial_states = config.get("num_initial_states", 16) + self.sandbox_timeout = config.get("sandbox_timeout", 600) + self.request_timeout = config.get("request_timeout", 660) + + self.best_reward = 0.0 + self.best_bound = float("inf") + self.total_verified = 0 + self.total_valid = 0 + self._session_initialized = False + self._inline_mode = (self.resource_server_url == "inline") + if self._inline_mode: + logger.info("ErdosDiscovery: running in INLINE mode (no Gym server)") + self._session_initialized = True # No server to init + + async def _ensure_session(self): + """Initialize the PUCT buffer on the resource server if not done.""" + if self._session_initialized: + return + try: + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + f"{self.resource_server_url}/seed_session", + json={ + "num_initial_states": self.num_initial_states, + "seed": self.seed, + }, + ) as resp: + data = await resp.json() + self.best_reward = data.get("best_initial_reward", 0.0) + self.best_bound = data.get( + "best_initial_bound", float("inf") + ) + logger.info( + "ErdosDiscovery: seeded buffer with %d states, " + "best_reward=%.4f, best_bound=%.6f", + data.get("num_states", 0), + self.best_reward, + self.best_bound, + ) + self._session_initialized = True + except Exception as e: + logger.error("ErdosDiscovery: seed_session failed: %s", e) + + async def _verify_single( + self, + session: Optional[aiohttp.ClientSession], + response_text: str, + parent_state: Optional[list[float]] = None, + ) -> dict: + """Call /verify on the resource server, or compute inline.""" + if self._inline_mode: + return _inline_compute_reward( + response_text, timeout=self.sandbox_timeout + ) + # Build a minimal NeMoGymResponse-like payload + # The resource server extracts output_text from response.output_text + body = { + "responses_create_params": { + "input": [{"role": "user", "content": ""}], + }, + "response": { + "id": "verify", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": response_text}], + } + ], + "output_text": response_text, + }, + "parent_state": parent_state, + } + try: + timeout = aiohttp.ClientTimeout(total=self.request_timeout) + async with session.post( + f"{self.resource_server_url}/verify", + json=body, + timeout=timeout, + ) as resp: + return await resp.json() + except Exception as e: + logger.warning("ErdosDiscovery: verify failed: %s", e) + return {"reward": 0.0, "bound": None, "error_msg": str(e)} + + def step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[ErdosMetadata], + ) -> EnvironmentReturn[ErdosMetadata]: + """Evaluate a batch of LLM responses. + + Extracts the assistant's last message from each conversation, + sends it to the resource server for code execution + scoring, + returns rewards. + """ + import asyncio + + return asyncio.get_event_loop().run_until_complete( + self._async_step(message_log_batch, metadata) + ) + + async def _async_step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[ErdosMetadata], + ) -> EnvironmentReturn[ErdosMetadata]: + await self._ensure_session() + + batch_size = len(message_log_batch) + rewards = torch.zeros(batch_size) + terminateds = torch.ones(batch_size) # Always single-turn + observations = [{}] * batch_size + answers = [None] * batch_size + updated_metadata = list(metadata) + + if self._inline_mode: + session = None + else: + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.request_timeout) + ) + + try: + tasks = [] + for i, message_log in enumerate(message_log_batch): + # Extract the last assistant message + response_text = "" + for msg in reversed(message_log): + if msg.get("role") == "assistant": + response_text = msg.get("content", "") + break + + # Get parent_state from metadata if available + parent_state = None + if metadata and i < len(metadata): + parent_state = metadata[i].get("parent_state", None) + + tasks.append( + self._verify_single(session, response_text, parent_state) + ) + + results = await asyncio.gather(*tasks, return_exceptions=True) + finally: + if session is not None: + await session.close() + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.warning( + "ErdosDiscovery: verify exception for sample %d: %s", + i, + result, + ) + continue + + reward = result.get("reward", 0.0) + rewards[i] = reward + self.total_verified += 1 + + if reward > 0: + self.total_valid += 1 + bound = result.get("bound", None) + if reward > self.best_reward: + self.best_reward = reward + self.best_bound = bound or ( + 1.0 / reward if reward > 0 else float("inf") + ) + + answers[i] = ( + f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" + ) + + # Update metadata with verification results + if i < len(updated_metadata): + updated_metadata[i] = { + **updated_metadata[i], + "reward": reward, + "bound": result.get("bound"), + "error_msg": result.get("error_msg", ""), + "best_reward_ever": result.get( + "best_reward_ever", self.best_reward + ), + } + + return EnvironmentReturn( + observations=observations, + metadata=updated_metadata, + next_stop_strings=[None] * batch_size, + rewards=rewards, + terminateds=terminateds, + answers=answers, + ) + + def global_post_process_and_metrics( + self, batch: dict + ) -> tuple[dict, dict]: + """Compute and return environment-level metrics.""" + valid_rate = ( + self.total_valid / max(self.total_verified, 1) + ) + metrics = { + "env/best_reward": self.best_reward, + "env/best_bound": self.best_bound + if self.best_bound < float("inf") + else 0.0, + "env/total_verified": self.total_verified, + "env/valid_rate": valid_rate, + } + return batch, metrics + + def shutdown(self): + """Cleanup.""" + pass diff --git a/shim/nemo_rl/environments/utils.py b/shim/nemo_rl/environments/utils.py new file mode 100644 index 0000000000..a1bc6ace3f --- /dev/null +++ b/shim/nemo_rl/environments/utils.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Dict, NotRequired, TypedDict + +from hydra.utils import get_object + +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.utils.venvs import create_local_venv_on_each_node + + +# Environment registry entry schema. +class EnvRegistryEntry(TypedDict, total=False): + actor_class_fqn: str + default_processor: NotRequired[str] + + +# Environment registry. Key is the env name, value is a dictionary with the actor class FQN and optional default processor. +ENV_REGISTRY: Dict[str, EnvRegistryEntry] = { + "math_default": { + "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", + }, + "math": { + "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", + }, + "math_multi_reward": { + "actor_class_fqn": "nemo_rl.environments.math_environment.MathMultiRewardEnvironment", + }, + "code": { + "actor_class_fqn": "nemo_rl.environments.code_environment.CodeEnvironment", + }, + "reward_model": { + "actor_class_fqn": "nemo_rl.environments.reward_model_environment.RewardModelEnvironment", + }, + "code_jaccard": { + "actor_class_fqn": "nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment", + }, + "vlm": { + "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", + }, + "erdos_discovery": { + "actor_class_fqn": "nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment", + }, + "nemo_gym": { + "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", + }, +} + + +def chunk_list_to_workers(to_chunk: list[Any], num_workers: int) -> list[list[Any]]: + """Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. + + If the list is not divisible by the number of workers, the last worker may have fewer elements. + If there are more workers than elements, the first len(list) workers will have a single element each, + and the remaining workers will have empty lists. + + Args: + list: The list to be chunked. + num_workers: The number of workers to distribute the list to. + + Returns: + A list of lists, where each sublist contains elements assigned to a worker. + + Examples: + ```{doctest} + >>> from nemo_rl.environments.utils import chunk_list_to_workers + >>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) + [[1, 2], [3, 4], [5]] + ``` + """ + if not to_chunk: + return [[] for _ in range(num_workers)] + + # Handle case where we have more workers than elements + if len(to_chunk) <= num_workers: + result = [[item] for item in to_chunk] + # Add empty lists for remaining workers + result.extend([[] for _ in range(num_workers - len(to_chunk))]) + return result + + # Calculate chunk size (ceiling division to ensure all elements are covered) + chunk_size = (len(to_chunk) + num_workers - 1) // num_workers + + # Create chunks + chunks = [] + for i in range(0, len(to_chunk), chunk_size): + chunks.append(to_chunk[i : i + chunk_size]) + + # If we somehow ended up with more chunks than workers (shouldn't happen with ceiling division) + # merge the last chunks + if len(chunks) > num_workers: + chunks[num_workers - 1 :] = [sum(chunks[num_workers - 1 :], [])] + + return chunks + + +def create_env(env_name: str, env_config: dict) -> EnvironmentInterface: + assert env_name in ENV_REGISTRY, ( + f"Env name {env_name} is not registered in ENV_REGISTRY. Please call register_env() to register the environment." + ) + actor_class_fqn = ENV_REGISTRY[env_name]["actor_class_fqn"] + actor_class = get_object(actor_class_fqn) + actor_py_exec = get_actor_python_env(actor_class_fqn) + extra_env_vars = {} + if actor_py_exec.startswith("uv"): + actor_py_exec = create_local_venv_on_each_node( + actor_py_exec, + actor_class_fqn, + ) + extra_env_vars = { + "VIRTUAL_ENV": actor_py_exec, + "UV_PROJECT_ENVIRONMENT": actor_py_exec, + } + env = actor_class.options( # type: ignore # it's wrapped with ray.remote + runtime_env={ + "py_executable": actor_py_exec, + "env_vars": {**dict(os.environ), **extra_env_vars}, + } + ).remote(env_config) + return env + + +def register_env(env_name: str, actor_class_fqn: str) -> None: + if env_name in ENV_REGISTRY: + raise ValueError(f"Env name {env_name} already registered") + + ENV_REGISTRY[env_name] = {"actor_class_fqn": actor_class_fqn} diff --git a/shim/nemo_rl/utils/__init__.py b/shim/nemo_rl/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/shim/nemo_rl/utils/puct_buffer.py b/shim/nemo_rl/utils/puct_buffer.py new file mode 100644 index 0000000000..53f808d783 --- /dev/null +++ b/shim/nemo_rl/utils/puct_buffer.py @@ -0,0 +1,561 @@ +""" +PUCT buffer for TTT-Discover state reuse. + +Reference: "Learning to Discover at Test Time" (arXiv:2601.04116) + +The buffer maintains a tree of (state, reward) nodes. At each training step, +PUCT scoring selects which states to warm-start rollouts from, balancing: + - Exploitation: states whose children have achieved high rewards (high Q) + - Exploration: states that haven't been visited much yet (low n) + +Pure data structure — no ML framework dependencies. +""" + +import math +import dataclasses +from typing import Any, Optional + +import numpy as np + + +# --------------------------------------------------------------------------- +# Internal node +# --------------------------------------------------------------------------- + +@dataclasses.dataclass +class _Node: + state: Any + reward: float # reward of THIS state (from its own evaluation) + parent_key: Any # key of parent node, or None for roots + children_keys: list # keys of direct children + n: int # visit count (number of times selected for expansion) + Q: float # max reward among all descendants (or own reward if leaf) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_key(state: Any) -> Any: + """Convert state to a hashable key. + + Supports: str, int, float, tuple, list, np.ndarray, and arbitrary objects + (fallback: id-based, so two different objects with equal content are + treated as distinct — acceptable for LLM response strings). + """ + if isinstance(state, (str, int, float, bool)): + return state + if isinstance(state, np.ndarray): + return (state.dtype, state.shape, state.tobytes()) + if isinstance(state, (list, tuple)): + return tuple(_make_key(x) for x in state) + # Fallback: identity-based key — wrap id so it doesn't collide with ints + return ("__id__", id(state)) + + +# --------------------------------------------------------------------------- +# PUCTBuffer +# --------------------------------------------------------------------------- + +class PUCTBuffer: + """ + Tree-structured buffer with PUCT selection. + + PUCT score for node s: + score(s) = Q(s) + c · P(s) · sqrt(1 + T) / (1 + n(s)) + + Where: + Q(s) = max reward among all descendants of s (own reward if leaf) + P(s) = rank-based prior: rank states by reward, normalize by total rank + n(s) = visit count of s + T = total visit count across all nodes + c = exploration constant (default 1.0) + """ + + def __init__(self, c: float = 1.0) -> None: + self.c = c + self._nodes: dict[Any, _Node] = {} # key → _Node + self._T: int = 0 # total expansions so far + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add(self, state: Any, reward: float, parent_state: Any = None) -> None: + """Insert a new node into the buffer. + + If the state is already present, this is a no-op (deduplication). + If parent_state is given and present in the buffer, the new node is + linked as a child and Q values are propagated upward. + + Args: + state: The state to insert (any type with a consistent identity). + reward: Scalar reward associated with this state. + parent_state: Parent state, or None for a root node. + """ + key = _make_key(state) + if key in self._nodes: + return # already present — deduplicate + + parent_key = _make_key(parent_state) if parent_state is not None else None + node = _Node( + state=state, + reward=float(reward), + parent_key=parent_key, + children_keys=[], + n=0, + Q=float(reward), # leaf: Q = own reward + ) + self._nodes[key] = node + + if parent_key is not None and parent_key in self._nodes: + self._nodes[parent_key].children_keys.append(key) + self._propagate_Q(parent_key) + + def select( + self, batch_size: int, num_groups: int = 8 + ) -> list[tuple[Any, list]]: + """Select states to warm-start rollouts from. + + Scores each node with PUCT, picks the top `num_groups` distinct states, + and returns `batch_size` (state, context) pairs grouped so that each + group of `batch_size // num_groups` entries shares the same state. + + Context is the ancestry path from root to the selected node: + [(ancestor_state, ancestor_reward), ..., (selected_state, selected_reward)] + The env uses this to build the prompt (previous attempts / warm start). + + Visit counts are incremented for the selected nodes, and T is updated. + + Args: + batch_size: Total number of (state, context) pairs to return. + Must be divisible by num_groups. + num_groups: Number of distinct initial states to select. + + Returns: + List of (state, context) tuples, length == batch_size. + """ + if not self._nodes: + raise ValueError("Buffer is empty — call add() before select()") + if batch_size % num_groups != 0: + raise ValueError( + f"batch_size ({batch_size}) must be divisible by num_groups ({num_groups})" + ) + rollouts_per_group = batch_size // num_groups + + priors = self._rank_priors() + scores = { + key: self._puct_score(node, priors[key]) + for key, node in self._nodes.items() + } + + # Top num_groups keys by PUCT score (at most len(nodes) if buffer is small) + k = min(num_groups, len(self._nodes)) + top_keys = sorted(scores, key=lambda x: scores[x], reverse=True)[:k] + + result: list[tuple[Any, list]] = [] + for key in top_keys: + node = self._nodes[key] + context = self._ancestry(key) + pair = (node.state, context) + result.extend([pair] * rollouts_per_group) + # Increment visit count for this selection + node.n += 1 + self._T += 1 + + return result + + def update( + self, parent_state: Any, child_state: Any, reward: float + ) -> None: + """Add a child node and update Q values up the tree. + + Convenience wrapper around add() that makes the parent/child + relationship explicit. + + Args: + parent_state: The state that was selected and rolled out from. + child_state: The resulting new state produced by the rollout. + reward: Reward of the new child state. + """ + self.add(child_state, reward, parent_state=parent_state) + + def best(self) -> tuple[Any, float]: + """Return the (state, reward) with the highest reward ever seen. + + Returns: + (state, reward) tuple. + """ + if not self._nodes: + raise ValueError("Buffer is empty") + best_key = max(self._nodes, key=lambda k: self._nodes[k].reward) + node = self._nodes[best_key] + return node.state, node.reward + + def __len__(self) -> int: + return len(self._nodes) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _puct_score(self, node: _Node, prior: float) -> float: + return node.Q + self.c * prior * math.sqrt(1 + self._T) / (1 + node.n) + + def _rank_priors(self) -> dict[Any, float]: + """Rank-based prior: rank by node reward, normalize by sum of ranks. + + Rank 1 = lowest reward, rank N = highest. Ties get the same rank + (average of tied ranks), consistent with scipy.stats.rankdata. + """ + keys = list(self._nodes.keys()) + rewards = np.array([self._nodes[k].reward for k in keys], dtype=float) + + # argsort twice gives rank (0-indexed); add 1 to make 1-indexed + order = np.argsort(rewards) + ranks = np.empty_like(order, dtype=float) + ranks[order] = np.arange(1, len(rewards) + 1, dtype=float) + + # Handle ties: assign average rank to tied rewards. + # Use ranks[tied].mean() — not tied.mean()+1, which would use array + # indices instead of the already-assigned rank values. + # (simple O(N²) loop is fine for buffer sizes we care about) + for i, r in enumerate(rewards): + tied = np.where(rewards == r)[0] + if len(tied) > 1: + ranks[tied] = ranks[tied].mean() + + total = ranks.sum() + return {k: float(ranks[i] / total) for i, k in enumerate(keys)} + + def _propagate_Q(self, key: Any) -> None: + """Propagate max-Q upward from `key` to the root.""" + node = self._nodes[key] + if node.children_keys: + child_rewards = [ + self._nodes[ck].Q + for ck in node.children_keys + if ck in self._nodes + ] + new_Q = max(node.reward, max(child_rewards)) if child_rewards else node.reward + else: + new_Q = node.reward + + if new_Q == node.Q: + return # no change — stop propagation + + node.Q = new_Q + if node.parent_key is not None and node.parent_key in self._nodes: + self._propagate_Q(node.parent_key) + + def _ancestry(self, key: Any) -> list[tuple[Any, float]]: + """Return the path from root to `key` as [(state, reward), ...].""" + path = [] + cur = key + while cur is not None: + node = self._nodes[cur] + path.append((node.state, node.reward)) + cur = node.parent_key + path.reverse() + return path + + +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- + +def _run_tests() -> None: + import sys + + failures: list[str] = [] + + def check(name: str, cond: bool, msg: str = "") -> None: + if not cond: + failures.append(f"FAIL [{name}]: {msg}") + else: + print(f" PASS [{name}]") + + print("=== puct_buffer unit tests ===\n") + + # ------------------------------------------------------------------ + # Basic add / best + # ------------------------------------------------------------------ + print("-- add / best --") + + buf = PUCTBuffer(c=1.0) + buf.add("s0", 0.5) + buf.add("s1", 0.8) + buf.add("s2", 0.3) + + state, reward = buf.best() + check("best_returns_max_reward_state", reward == 0.8, f"reward={reward}") + check("best_returns_correct_state", state == "s1", f"state={state!r}") + check("len_after_adds", len(buf) == 3, f"len={len(buf)}") + + # Duplicate add is a no-op + buf.add("s0", 99.0) + check("duplicate_add_noop", len(buf) == 3, "duplicate changed buffer size") + check("duplicate_reward_unchanged", buf._nodes[_make_key("s0")].reward == 0.5) + + # ------------------------------------------------------------------ + # Q uses MAX not mean + # ------------------------------------------------------------------ + print("\n-- Q = MAX not mean --") + + buf2 = PUCTBuffer() + buf2.add("root", 0.0) + buf2.add("child_low", 0.1, parent_state="root") + buf2.add("child_high", 0.9, parent_state="root") + + root_node = buf2._nodes[_make_key("root")] + check( + "Q_is_max_not_mean", + root_node.Q == 0.9, + f"root.Q={root_node.Q}, expected 0.9 (max), mean would be 0.5", + ) + + # Add another child with even higher reward — Q should update + buf2.add("child_best", 0.95, parent_state="root") + check( + "Q_updates_when_better_child_added", + root_node.Q == 0.95, + f"root.Q={root_node.Q}, expected 0.95", + ) + + # ------------------------------------------------------------------ + # Q propagates through grandchildren (MAX of all descendants) + # ------------------------------------------------------------------ + print("\n-- Q propagation --") + + buf3 = PUCTBuffer() + buf3.add("r", 0.0) + buf3.add("c1", 0.3, parent_state="r") + buf3.add("gc", 0.99, parent_state="c1") # grandchild + + r_node = buf3._nodes[_make_key("r")] + c1_node = buf3._nodes[_make_key("c1")] + check("grandchild_Q_propagates_to_child", c1_node.Q == 0.99, f"c1.Q={c1_node.Q}") + check("grandchild_Q_propagates_to_root", r_node.Q == 0.99, f"r.Q={r_node.Q}") + + # Parent with high own reward should NOT lose Q when children underperform + buf3b = PUCTBuffer() + buf3b.add("great_parent", 0.9) + buf3b.add("weak_child", 0.2, parent_state="great_parent") + gp_node = buf3b._nodes[_make_key("great_parent")] + check( + "parent_Q_not_lowered_by_weak_child", + gp_node.Q == 0.9, + f"great_parent.Q={gp_node.Q}, expected 0.9 (own reward dominates)", + ) + + # ------------------------------------------------------------------ + # Rank priors: ties get correct average rank (not index-based) + # ------------------------------------------------------------------ + print("\n-- rank prior tie handling --") + + buf_ties = PUCTBuffer() + # rewards: s0=0.1 (rank 1), s1=0.5 (tied), s2=0.3 (rank 2), s3=0.5 (tied) + # After tie-averaging: s0→1, s2→2, s1&s3→(3+4)/2=3.5 + buf_ties.add("s0", 0.1) + buf_ties.add("s1", 0.5) + buf_ties.add("s2", 0.3) + buf_ties.add("s3", 0.5) + priors_ties = buf_ties._rank_priors() + p1 = priors_ties[_make_key("s1")] + p3 = priors_ties[_make_key("s3")] + p2 = priors_ties[_make_key("s2")] + check("tied_states_equal_prior", abs(p1 - p3) < 1e-9, f"p1={p1:.6f} p3={p3:.6f}") + check("tied_states_outrank_lower", p1 > p2, f"tied={p1:.4f} vs s2={p2:.4f}") + + # ------------------------------------------------------------------ + # update() convenience wrapper + # ------------------------------------------------------------------ + print("\n-- update() --") + + buf4 = PUCTBuffer() + buf4.add("p", 0.5) + buf4.update("p", "child_via_update", 0.7) + check("update_adds_child", len(buf4) == 2, f"len={len(buf4)}") + check("update_links_child", "child_via_update" in [ + buf4._nodes[ck].state for ck in buf4._nodes[_make_key("p")].children_keys + ]) + + # ------------------------------------------------------------------ + # Exploration: unvisited high-reward states get selected + # ------------------------------------------------------------------ + print("\n-- exploration: unvisited high-reward states --") + + buf5 = PUCTBuffer(c=1.0) + # Old state, visited many times + buf5.add("visited", 0.6) + buf5._nodes[_make_key("visited")].n = 100 + # New high-reward state, never visited + buf5.add("fresh_high", 0.9) + + selected = buf5.select(batch_size=2, num_groups=2) + selected_states = [s for s, _ in selected] + check( + "unvisited_high_reward_selected", + "fresh_high" in selected_states, + f"selected states: {selected_states}", + ) + + # ------------------------------------------------------------------ + # Exploitation: Q(parent) rises after adding a high-reward child, making + # the parent score higher than a sibling with no children. + # We verify PUCT scores directly — not via select() — because select() + # would correctly pick the child itself (even better warm-start). + # ------------------------------------------------------------------ + print("\n-- exploitation: high-Q parent outscores peer --") + + buf6 = PUCTBuffer(c=0.01) # low exploration → scores dominated by Q + buf6.add("peer_no_children", 0.5) + buf6.add("parent_explored", 0.5) + # Give parent_explored a great child: Q should propagate to 0.99 + buf6.add("great_child_2", 0.99, parent_state="parent_explored") + + priors6 = buf6._rank_priors() + pk_peer = _make_key("peer_no_children") + pk_parent = _make_key("parent_explored") + score_peer = buf6._puct_score(buf6._nodes[pk_peer], priors6[pk_peer]) + score_parent = buf6._puct_score(buf6._nodes[pk_parent], priors6[pk_parent]) + + check( + "parent_Q_raised_by_great_child", + buf6._nodes[pk_parent].Q == 0.99, + f"parent.Q={buf6._nodes[pk_parent].Q}", + ) + check( + "high_Q_parent_outscores_peer", + score_parent > score_peer, + f"score_parent={score_parent:.4f}, score_peer={score_peer:.4f}", + ) + + # ------------------------------------------------------------------ + # select() group structure + # ------------------------------------------------------------------ + print("\n-- select() group structure --") + + buf7 = PUCTBuffer() + for i in range(10): + buf7.add(f"s{i}", float(i) / 10) + + result = buf7.select(batch_size=16, num_groups=4) + check("select_total_length", len(result) == 16, f"len={len(result)}") + + # Each group of 4 should share the same state + groups_of_4 = [result[i*4:(i+1)*4] for i in range(4)] + for gi, group in enumerate(groups_of_4): + states_in_group = [s for s, _ in group] + check( + f"group_{gi}_same_state", + len(set(states_in_group)) == 1, + f"group {gi} has mixed states: {states_in_group}", + ) + + # Each group should have a DIFFERENT initial state from the others + group_states = [group[0][0] for group in groups_of_4] + check( + "groups_have_distinct_states", + len(set(group_states)) == 4, + f"group states: {group_states}", + ) + + # ------------------------------------------------------------------ + # select() raises on batch_size not divisible by num_groups + # ------------------------------------------------------------------ + print("\n-- select() error handling --") + + buf8 = PUCTBuffer() + buf8.add("x", 1.0) + try: + buf8.select(batch_size=7, num_groups=3) + check("indivisible_batch_raises", False, "should have raised ValueError") + except ValueError: + check("indivisible_batch_raises", True) + + # select() on empty buffer raises + buf_empty = PUCTBuffer() + try: + buf_empty.select(batch_size=4, num_groups=2) + check("empty_buffer_select_raises", False, "should have raised ValueError") + except ValueError: + check("empty_buffer_select_raises", True) + + # ------------------------------------------------------------------ + # Context (ancestry path) + # ------------------------------------------------------------------ + print("\n-- context / ancestry path --") + + buf9 = PUCTBuffer() + buf9.add("root", 0.1) + buf9.add("child", 0.5, parent_state="root") + buf9.add("grand", 0.9, parent_state="child") + + # Force select to pick "grand" by making it best by far + buf9._nodes[_make_key("grand")].reward = 10.0 + buf9._propagate_Q(_make_key("child")) + buf9._propagate_Q(_make_key("root")) + + result9 = buf9.select(batch_size=1, num_groups=1) + state9, context9 = result9[0] + check("context_is_list", isinstance(context9, list)) + check( + "context_starts_at_root", + context9[0][0] == "root", + f"context[0]={context9[0]}", + ) + check( + "context_ends_at_selected", + context9[-1][0] == state9, + f"context[-1]={context9[-1]}, state={state9!r}", + ) + check( + "context_length_equals_depth", + len(context9) == 3, + f"len={len(context9)}, expected 3", + ) + + # ------------------------------------------------------------------ + # Visit count increments on select + # ------------------------------------------------------------------ + print("\n-- visit count tracking --") + + buf10 = PUCTBuffer() + buf10.add("a", 0.5) + buf10.add("b", 0.6) + n_before_a = buf10._nodes[_make_key("a")].n + buf10.select(batch_size=4, num_groups=2) + T_after = buf10._T + check("T_incremented_by_num_groups", T_after == 2, f"T={T_after}") + total_n = sum(n.n for n in buf10._nodes.values()) + check("total_n_equals_T", total_n == T_after, f"sum(n)={total_n}, T={T_after}") + + # ------------------------------------------------------------------ + # numpy array states + # ------------------------------------------------------------------ + print("\n-- numpy array states --") + + buf11 = PUCTBuffer() + arr_a = np.array([0.1, 0.5, 0.4]) + arr_b = np.array([0.3, 0.3, 0.4]) + buf11.add(arr_a, 0.7) + buf11.add(arr_b, 0.9) + check("numpy_states_len", len(buf11) == 2, f"len={len(buf11)}") + best_s, best_r = buf11.best() + check("numpy_best_reward", best_r == 0.9, f"best_r={best_r}") + check("numpy_best_state", np.array_equal(best_s, arr_b), f"best_s={best_s}") + + # ------------------------------------------------------------------ + print() + if failures: + for f in failures: + print(f) + print(f"\n{len(failures)} test(s) FAILED") + import sys; sys.exit(1) + else: + print("All tests passed.") + + +if __name__ == "__main__": + _run_tests() diff --git a/shim/run_discover.py b/shim/run_discover.py new file mode 100644 index 0000000000..6b2d7bc80a --- /dev/null +++ b/shim/run_discover.py @@ -0,0 +1,349 @@ +"""Run script for TTT-Discover GRPO training on the Erdős Minimum Overlap Problem. + +This follows the sliding_puzzle pattern: custom IterableDataset that generates +prompts dynamically from a PUCT buffer, wired into the standard GRPO loop. + +Usage: + # Start the Gym resource server first (separate process/node): + cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" + + # Then run training: + cd ~/RL && uv run python examples/run_discover.py [--config examples/configs/grpo_erdos_discover.yaml] + +Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) +""" + +import itertools +import argparse +import itertools +import logging +import os +import sys +from typing import Optional + +import aiohttp +import asyncio +import numpy as np +import ray +import torch +from torch.utils.data import IterableDataset + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer, set_seed +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.erdos_discovery_environment import ( + ErdosDiscoveryEnvironment, +) +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, register_omegaconf_resolvers + +logger = logging.getLogger(__name__) + +# ═══════════════════════════════════════════════════════════════════ +# Problem description (same as in the Gym resource server) +# ═══════════════════════════════════════════════════════════════════ + +PROBLEM_DESCRIPTION = """\ +Erdos Minimum Overlap Problem +============================== + +Goal: Find a step function f (Python list or NumPy array) giving the +tightest possible upper bound on the Erdos minimum overlap constant c. + +Background: + For integer n, partition {1,...,2n} into equal sets A, B. + M_k = #{(a,b) : a in A, b in B, a-b=k}. + c = lim_{n->inf} min_{A,B} max_k M_k / n. + +Known bounds: 0.379005 < c < 0.380927 (Haugland 2016) +Current best upper bound: 0.380876 (2026) + +Upper Bound via Step Functions: + f : [0,1] -> [0,1] with mean(f) = 0.5 gives: + bound = 2*n*max(autocorr(f)) / sum(f)^2 + where autocorr is computed via FFT. + Smaller bound -> higher reward (reward = 1/bound). + +Constraints: 1 <= len(f) <= 1000, 0 <= f[i] <= 1, mean(f) ~ 0.5 (tol 1e-3). + +Output: Python code defining variable `f` in a ```python block. +Allowed: numpy, math, random, itertools, functools, collections. +Execution limit: 600 seconds. Target: bound < 0.380876.\ +""" + + +# ═══════════════════════════════════════════════════════════════════ +# Datum generation +# ═══════════════════════════════════════════════════════════════════ + + +def generate_discover_datum( + tokenizer, + state_info: dict, + idx: int, + task_name: str = "erdos_discovery", +) -> DatumSpec: + """Create a DatumSpec from a PUCT-selected state. + + Args: + tokenizer: HuggingFace tokenizer. + state_info: Dict from /select_state with keys: + state, context, reward, system_prompt, user_prompt. + idx: Datum index. + task_name: Task name for env routing. + + Returns: + DatumSpec ready for the GRPO training loop. + """ + system_prompt = state_info.get("system_prompt", PROBLEM_DESCRIPTION) + user_prompt = state_info["user_prompt"] + + messages: LLMMessageLogType = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # Tokenize the prompt + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long) + + # Attach token_ids to messages for NeMo RL's message_log format + for msg in messages: + msg_text = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + msg_ids = tokenizer.encode(msg_text, add_special_tokens=False) + msg["token_ids"] = torch.tensor(msg_ids, dtype=torch.long) + + return DatumSpec( + message_log=messages, + length=len(prompt_ids), + extra_env_info={ + "parent_state": state_info.get("state"), + "context": state_info.get("context"), + "reward": state_info.get("reward", 0.0), + }, + loss_multiplier=1.0, + idx=idx, + task_name=task_name, + ) + + +# ═══════════════════════════════════════════════════════════════════ +# Dynamic dataset backed by PUCT buffer +# ═══════════════════════════════════════════════════════════════════ + + +class DiscoverDataset(IterableDataset): + """Iterable dataset that fetches prompts from the PUCT buffer each step. + + Each iteration fetches `num_groups_per_step` states from the Gym resource + server's /select_state endpoint and yields them as DatumSpecs. + + The dataset loops indefinitely — the training loop controls termination + via max_num_steps in the GRPO config. + """ + + def __init__( + self, + tokenizer, + resource_server_url: str, + num_groups_per_step: int = 8, + task_name: str = "erdos_discovery", + length: int = 1000, # Nominal length for dataloader + ): + self.tokenizer = tokenizer + self.resource_server_url = resource_server_url + self.num_groups_per_step = num_groups_per_step + self.task_name = task_name + self.length = length + self._idx_counter = itertools.count() + + def _fetch_states_sync(self) -> list[dict]: + """Synchronously fetch states from the PUCT buffer.""" + import requests + + try: + resp = requests.post( + f"{self.resource_server_url}/select_state", + json={ + "batch_size": self.num_groups_per_step, + "num_groups": self.num_groups_per_step, + }, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + return data.get("states", []) + except Exception as e: + logger.error("Failed to fetch states from PUCT buffer: %s", e) + # Return fallback: single default prompt + return [ + { + "state": [0.5] * 50, + "context": [], + "reward": 0.5, + "system_prompt": PROBLEM_DESCRIPTION, + "user_prompt": ( + "Starting construction (bound=2.000000, 50 pieces):\n" + "[0.5000, 0.5000, ..., 0.5000]\n\n" + "Improve on this construction. Write Python code that " + "defines a better step function `f`. Think carefully." + ), + } + ] + + def __iter__(self): + for _ in itertools.count(): + states = self._fetch_states_sync() + for state_info in states: + idx = next(self._idx_counter) + yield generate_discover_datum( + self.tokenizer, + state_info, + idx=idx, + task_name=self.task_name, + ) + + def __len__(self): + return self.length + + +# ═══════════════════════════════════════════════════════════════════ +# Setup +# ═══════════════════════════════════════════════════════════════════ + + +def setup_discover_data(config: MasterConfig, tokenizer): + """Create dataset, environment, and wire them together. + + Returns: + (train_dataset, val_dataset, task_to_env, val_task_to_env) + """ + env_config = config.get("env", {}).get("erdos_discovery", {}) + resource_server_url = env_config.get( + "resource_server_url", "http://localhost:8080" + ) + num_groups_per_step = env_config.get("num_groups_per_step", 8) + task_name = "erdos_discovery" + + # Create the dynamic dataset + train_dataset = DiscoverDataset( + tokenizer=tokenizer, + resource_server_url=resource_server_url, + num_groups_per_step=num_groups_per_step, + task_name=task_name, + length=config["grpo"]["max_num_steps"] * num_groups_per_step, + ) + + # Validation dataset: same thing (could be a fixed set, but for discovery + # we just re-sample from the buffer) + val_dataset = DiscoverDataset( + tokenizer=tokenizer, + resource_server_url=resource_server_url, + num_groups_per_step=num_groups_per_step, + task_name=task_name, + length=num_groups_per_step, + ) + + # Create the environment as a Ray actor + env = ErdosDiscoveryEnvironment.options( + num_gpus=0, + max_restarts=-1, + max_task_retries=-1, + ).remote(config=env_config) + + task_to_env = {task_name: env} + val_task_to_env = {task_name: env} + + return train_dataset, val_dataset, task_to_env, val_task_to_env + + +# ═══════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════ + + +def main(): + import os + from omegaconf import OmegaConf + from nemo_rl.utils.config import load_config, register_omegaconf_resolvers + + register_omegaconf_resolvers() + + # Parse --config argument + config_path = None + for i, arg in enumerate(sys.argv[1:], 1): + if arg.startswith("--config="): + config_path = arg.split("=", 1)[1] + elif arg == "--config" and i < len(sys.argv) - 1: + config_path = sys.argv[i + 1] + elif not arg.startswith("--") and config_path is None: + config_path = arg + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "grpo_erdos_discover_debug.yaml" + ) + + print(f"Loading config from: {config_path}") + config = load_config(config_path) + + # Resolve OmegaConf interpolations (e.g. ${policy.model_name}) + oc = OmegaConf.create(config) + config = OmegaConf.to_container(oc, resolve=True) + + # Initialize Ray + init_ray() + set_seed(config.get("seed", 42)) + + # Tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Generation config + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Setup data + environment + train_dataset, val_dataset, task_to_env, val_task_to_env = ( + setup_discover_data(config, tokenizer) + ) + + # Setup policy, generation, cluster, dataloader, etc. + ( + policy, + policy_generation, + clusters, + dataloader, + val_dataloader, + loss_fn, + nemo_logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, train_dataset, val_dataset) + + # Run GRPO training + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + nemo_logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() From 6c58bd69c5ceff0d3c41c10ff4eb70d28d9c0b5d Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 06:22:07 +0000 Subject: [PATCH 08/48] fix: register mul/div OmegaConf resolvers for v0.5.0 compat --- examples/configs/grpo_erdos_discover_debug.yaml | 3 +++ examples/run_discover.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index e34cd2d030..3401b9b06e 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -37,6 +37,9 @@ policy: alpha: 1.0 dropout: 0.0 + dynamic_batching: + enabled: false + generation: backend: "vllm" max_new_tokens: 2048 diff --git a/examples/run_discover.py b/examples/run_discover.py index e391a8e8c0..12e0d1dfe6 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -272,6 +272,13 @@ def main(): import os from omegaconf import OmegaConf from nemo_rl.utils.config import load_config + + # Register custom resolvers needed by the base config + if not OmegaConf.has_resolver("mul"): + OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + if not OmegaConf.has_resolver("div"): + OmegaConf.register_new_resolver("div", lambda a, b: a // b) + try: from nemo_rl.utils.config import register_omegaconf_resolvers register_omegaconf_resolvers() From 760fa9fddd84d5f40b8148d01227481736132eea Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 08:13:24 +0000 Subject: [PATCH 09/48] fix: sync step method for Ray actor event loop compatibility --- .../erdos_discovery_environment.py | 69 ++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 70c4710fca..5348bfa9d3 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -244,8 +244,73 @@ def step( """ import asyncio - return asyncio.get_event_loop().run_until_complete( - self._async_step(message_log_batch, metadata) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # Ray actors run inside an event loop — use nest_asyncio or run sync + return self._sync_step(message_log_batch, metadata) + else: + return asyncio.run( + self._async_step(message_log_batch, metadata) + ) + + def _sync_step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[ErdosMetadata], + ) -> EnvironmentReturn[ErdosMetadata]: + """Synchronous step for use inside running event loops (Ray actors).""" + batch_size = len(message_log_batch) + rewards = torch.zeros(batch_size) + terminateds = torch.ones(batch_size) + observations = [{}] * batch_size + answers = [None] * batch_size + updated_metadata = list(metadata) + + for i, message_log in enumerate(message_log_batch): + response_text = "" + for msg in reversed(message_log): + if msg.get("role") == "assistant": + response_text = msg.get("content", "") + break + + if self._inline_mode: + result = _inline_compute_reward( + response_text, timeout=self.sandbox_timeout + ) + else: + result = {"reward": 0.0, "bound": None, "error_msg": "sync mode requires inline"} + + reward = result.get("reward", 0.0) + rewards[i] = reward + self.total_verified += 1 + + if reward > 0: + self.total_valid += 1 + bound = result.get("bound") + if reward > self.best_reward: + self.best_reward = reward + self.best_bound = bound or (1.0 / reward if reward > 0 else float("inf")) + answers[i] = f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" + + if i < len(updated_metadata): + updated_metadata[i] = { + **updated_metadata[i], + "reward": reward, + "bound": result.get("bound"), + "error_msg": result.get("error_msg", ""), + } + + return EnvironmentReturn( + observations=observations, + metadata=updated_metadata, + next_stop_strings=[None] * batch_size, + rewards=rewards, + terminateds=terminateds, + answers=answers, ) async def _async_step( From 7e0d70faf908b47218b712e04ecb12d15a933aeb Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 08:39:37 +0000 Subject: [PATCH 10/48] fix: observations must include content key for rollout engine --- nemo_rl/environments/erdos_discovery_environment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 5348bfa9d3..e192c3619b 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -266,7 +266,7 @@ def _sync_step( batch_size = len(message_log_batch) rewards = torch.zeros(batch_size) terminateds = torch.ones(batch_size) - observations = [{}] * batch_size + observations = [{"role": "user", "content": ""} for _ in range(batch_size)] answers = [None] * batch_size updated_metadata = list(metadata) @@ -323,7 +323,7 @@ async def _async_step( batch_size = len(message_log_batch) rewards = torch.zeros(batch_size) terminateds = torch.ones(batch_size) # Always single-turn - observations = [{}] * batch_size + observations = [{"role": "user", "content": ""} for _ in range(batch_size)] answers = [None] * batch_size updated_metadata = list(metadata) From c37caa695587b62b8d6916287dd8273ba2b63aef Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 14:12:28 +0000 Subject: [PATCH 11/48] fix: disable CPU offload for debug (1.5B fits on GPU) --- examples/configs/grpo_erdos_discover_debug.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index 3401b9b06e..528b5b9640 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -27,7 +27,7 @@ policy: dtensor_cfg: enabled: true - cpu_offload: true + cpu_offload: false activation_checkpointing: true sequence_parallel: false From fb2dab5a3b6d3ec3e1596d6c69a8d0b9a534b6eb Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 14:49:52 +0000 Subject: [PATCH 12/48] =?UTF-8?q?docs:=20add=20LESSONS=5FLEARNED.md=20for?= =?UTF-8?q?=20Erd=C5=91s=20TTT-Discover=20on=20NeMo=20RL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Covers everything we learned getting this running on the d2dfac12 B200 cluster: ray.sub patches, container compat, async fixes, config gotchas, and the full working launch pattern. --- nemo_rl/environments/LESSONS_LEARNED.md | 182 ++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 nemo_rl/environments/LESSONS_LEARNED.md diff --git a/nemo_rl/environments/LESSONS_LEARNED.md b/nemo_rl/environments/LESSONS_LEARNED.md new file mode 100644 index 0000000000..503e518dd1 --- /dev/null +++ b/nemo_rl/environments/LESSONS_LEARNED.md @@ -0,0 +1,182 @@ +# TTT-Discover on NeMo RL — Lessons Learned + +## What This Is + +TTT-Discover (arXiv:2601.16175) running the Erdős Minimum Overlap Problem +on NeMo RL's GRPO framework. The LLM writes Python code defining a step +function, which is executed in a sandbox and scored via FFT autocorrelation. + +## What Worked + +### Final Working Setup +- **Container**: `nvcr.io#nvidia/nemo-rl:v0.5.0` +- **Cluster**: 2 nodes, 8x B200 per node (16 GPUs total) +- **Model**: Qwen/Qwen2.5-1.5B-Instruct with LoRA (r=16) for debug +- **Ray orchestration**: Dakota's patched `ray.sub` (see below) +- **~27s per training step** (5 debug steps completed successfully) + +### Launch Pattern +```bash +cd ~/RL + +CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" +MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" +COMMAND=" +export HF_HUB_ENABLE_HF_TRANSFER=0 +export TORCH_CUDA_ARCH_LIST='9.0 10.0' +export PYTHONPATH=/path/to/RL:\${PYTHONPATH:-} +cd /opt/nemo-rl +python examples/run_discover.py --config examples/configs/grpo_erdos_discover_debug.yaml +" + +COMMAND="$COMMAND" CONTAINER="$CONTAINER" MOUNTS="$MOUNTS" GPUS_PER_NODE=8 \ +sbatch --nodes=2 --partition=batch --exclusive --time=01:00:00 ray.sub +``` + +### Key Config (grpo_erdos_discover_debug.yaml) +```yaml +defaults: "grpo_math_1B.yaml" # Inherit all base settings + +grpo: + num_prompts_per_step: 4 # 4 PUCT-selected states + num_generations_per_prompt: 8 # 8 rollouts per state + max_num_steps: 5 # Debug: 5 steps only + max_rollout_turns: 1 # Single-turn code generation + adv_estimator: + name: entropic_adaptive_beta # NOT standard grpo + gamma: 0.6931471805599453 # ln(2) + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + max_total_sequence_length: 4096 + dtensor_cfg: + cpu_offload: false # IMPORTANT: must be false or logprobs crash + lora_cfg: + enabled: true + rank: 16 + +env: + erdos_discovery: + resource_server_url: "inline" # No Gym server needed for debug +``` + +## Cluster-Specific Fixes (d2dfac12 / Together AI B200) + +### ray.sub Patches +Use Dakota's patched `ray.sub` from `~/dakota-ref/ray.sub`. Key changes from upstream: +1. **MPI**: `--mpi=pmi2` (not `pmix`) +2. **Container writable**: `--container-writable` instead of `--no-container-mount-home` +3. **NCCL IB config** in ray.sub itself (not just COMMAND): + ```bash + export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 + export NCCL_SOCKET_IFNAME=bond0 + export UCX_NET_DEVICES=bond0 + ``` +4. **No `--account`** needed on this cluster +5. `srun --overlap` works via `enroot exec` (Dakota's ray.sub handles this) + +### NGC Container Auth +```bash +mkdir -p ~/.config/enroot +cat > ~/.config/enroot/.credentials << 'EOF' +machine nvcr.io login $oauthtoken password nvapi-YOUR_KEY_HERE +EOF +``` + +### Container Pull Time +First pull takes ~10-12 minutes per node. Cached after that, but cache is +per-node (different nodes need their own pull). Plan for this in your first run. + +## What Broke and How We Fixed It + +### 1. TransformerEngine Won't Build (bare metal) +**Problem**: `uv sync --extra automodel` fails because TransformerEngine needs +`cudnn.h` which doesn't exist on the head node. +**Fix**: Use the NeMo RL container (has TE pre-built). Don't try bare metal +with LoRA — LoRA requires DTensorPolicyWorkerV2 which needs the `automodel` +extra which needs TransformerEngine. + +### 2. Ray Version Mismatch +**Problem**: Container has Ray 2.49.2, but `uv run` picks up Ray 2.54.0 from +the mounted `.venv`. +**Fix**: Use `python` directly (container's Python), not `uv run`. Set your +code path via `PYTHONPATH` instead. + +### 3. Code Compatibility with v0.5.0 Container +**Problem**: Our branch's `nemo_rl` code is newer than v0.5.0 (has `decord` +imports, `register_omegaconf_resolvers`, etc. that the container doesn't have). +**Fix**: Don't overwrite the container's code at `/opt/nemo-rl`. Instead: + - Copy only NEW files (custom estimator, environment, run script) + - Monkey-patch `grpo.py` and `utils.py` at runtime to register new components + - Add `mul`/`div` OmegaConf resolvers manually if `register_omegaconf_resolvers` isn't available + +### 4. Ray Actor Event Loop +**Problem**: `asyncio.get_event_loop().run_until_complete()` in environment +`step()` crashes with "This event loop is already running" inside Ray actors. +**Fix**: Detect running loop and use synchronous path instead: +```python +try: + loop = asyncio.get_running_loop() +except RuntimeError: + loop = None +if loop and loop.is_running(): + return self._sync_step(...) # No async +else: + return asyncio.run(self._async_step(...)) +``` + +### 5. Empty Observations +**Problem**: `KeyError: 'content'` in rollouts.py — the rollout engine expects +observations from `env.step()` to have a `content` key. +**Fix**: Return `[{"role": "user", "content": ""} for _ in range(batch_size)]` +not `[{}] * batch_size`. + +### 6. CPU Offload + LogProbs +**Problem**: `RuntimeError: Expected all tensors to be on the same device` — +model weights on CPU (from `cpu_offload: true`) but input_ids on cuda. +**Fix**: Set `dtensor_cfg.cpu_offload: false` for small debug models. +For large models, double-check the offload/reload flow works with your config. + +### 7. `${mul:...}` OmegaConf Resolver +**Problem**: Base config uses `${mul:a,b}` interpolation but v0.5.0 container +doesn't register this resolver. +**Fix**: Register it manually in your run script: +```python +from omegaconf import OmegaConf +if not OmegaConf.has_resolver("mul"): + OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +``` + +## Architecture + +``` +NeMo RL Container (v0.5.0) +├── /opt/nemo-rl/ ← Container's code (base) +│ ├── nemo_rl/algorithms/ +│ │ ├── grpo.py ← monkey-patched at runtime +│ │ └── entropic_advantage_estimator.py ← our NEW file +│ ├── nemo_rl/environments/ +│ │ ├── utils.py ← monkey-patched at runtime +│ │ └── erdos_discovery_environment.py ← our NEW file +│ └── nemo_rl/utils/ +│ └── puct_buffer.py ← our NEW file +│ +├── /home/mormio/RL/ ← Mounted source (for cp at startup) +└── /home/shared/models/ ← Mounted model weights +``` + +## Scaling to 120B + +For gpt-oss-120b-bf16 (MoE, 128 experts): +- Use 8 nodes (2 training + 6 inference) or similar +- LoRA r=32, alpha=1.0 +- EP=8 for expert parallelism (see Dakota's 120B config) +- May need `cpu_offload: true` + careful memory management +- Set `max_total_sequence_length: 32768` for full context +- Consider `generation.colocated: false` with separate inference nodes + +## References +- Paper: "Learning to Discover at Test Time" (arXiv:2601.16175) +- Reference impl: https://github.com/test-time-training/discover +- Dakota's working configs: `~/dakota-ref/` +- NeMo RL docs: https://docs.nvidia.com/nemo/rl/latest/index.html From 7daed16aa2041b820e34ad0194464f90276a7910 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 15:12:19 +0000 Subject: [PATCH 13/48] =?UTF-8?q?feat:=20add=20120B=20Nemotron=20Super=20l?= =?UTF-8?q?aunch=20config=20for=20Erd=C5=91s=20TTT-Discover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 8 nodes: 2 inference (vLLM TP=8) + 6 training (Megatron TP=4, EP=8) Uses Dakota custom super-v3 container with NemotronH vLLM support. No LoRA initially (Megatron backend), full fine-tune with optimizer CPU offload + activation checkpointing. --- examples/configs/grpo_erdos_discover.yaml | 115 +++++++++++++--------- launch_erdos_120b.sh | 101 +++++++++++++++++++ 2 files changed, 171 insertions(+), 45 deletions(-) create mode 100755 launch_erdos_120b.sh diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 59be71b1b5..f005822e06 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -1,52 +1,81 @@ -# TTT-Discover GRPO config for Erdős Minimum Overlap Problem. -# Reference: arXiv:2601.16175 -# -# Usage: -# uv run python examples/run_discover.py --config examples/configs/grpo_erdos_discover.yaml -# -# Requires the Gym Erdős resource server running separately: -# cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" - -seed: 42 +# TTT-Discover Erdős — Nemotron-3-Super-120B-A12B +# Based on Dakota's working grpo_superv3.yaml adapted for Erdős TTT-Discover. +# 8 nodes: 2 inference (vLLM TP=8) + 6 training (Megatron TP=4 EP=8) +defaults: "grpo_math_1B.yaml" grpo: - num_prompts_per_step: 8 # PUCT selects 8 states per step - num_generations_per_prompt: 64 # 64 rollouts per state = 512 total + num_prompts_per_step: 8 # 8 PUCT-selected states per step + num_generations_per_prompt: 64 # 64 rollouts per state (paper default) max_num_epochs: 1 max_num_steps: 50 # 50 training steps (paper default) - max_rollout_turns: 1 # Single-turn: generate code, get reward + max_rollout_turns: 1 # Single-turn code generation remove_constant_reward_groups: true adv_estimator: name: entropic_adaptive_beta gamma: 0.6931471805599453 # ln(2) loss_fn: - kl_penalty_coef: 0.1 # KL penalty (paper uses 0.1) + kl_penalty_coef: 0.1 ratio_clip: 0.2 - token_level_loss: false # Sequence-level policy ratio - importance_sampling: true + token_level_loss: false policy: - model_name: "gpt-oss-120b-bf16" - tokenizer: "gpt-oss-120b-bf16" - max_total_sequence_length: 32768 - train_global_batch_size: 512 # 8 groups × 64 rollouts - train_micro_batch_size: 4 + model_name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" + tokenizer: + name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" + chat_template_kwargs: null + max_total_sequence_length: 16384 + train_global_batch_size: 512 # 8 groups × 64 rollouts + train_micro_batch_size: 1 + logprob_batch_size: 1 - dtensor_cfg: + # Megatron backend for 120B MoE + megatron_cfg: enabled: true - tensor_parallel_size: 1 - sequence_parallel: false + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + expert_model_parallel_size: 8 # 128 experts / 8 EP = 16 experts per rank + sequence_parallel: true + activation_checkpointing: true + empty_unused_memory_level: 2 + optimizer_cpu_offload: true + optimizer: + optimizer_cpu_offload: true + optimizer_offload_fraction: 1.0 + + dtensor_cfg: + enabled: false lora_cfg: - enabled: true - rank: 32 - alpha: 1.0 - dropout: 0.0 + enabled: false # No LoRA for now (Megatron backend) + + generation: + backend: "vllm" + colocated: + enabled: false + resources: + num_nodes: 2 # 2 nodes for vLLM inference + gpus_per_node: 8 + max_new_tokens: 16384 + temperature: 1.0 + top_p: 1.0 + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + tensor_parallel_size: 8 # Full node TP for inference + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 16384 + + dynamic_batching: + enabled: false optimizer: name: adamw - lr: 4.0e-5 + lr: 4.0e-5 # Paper default betas: [0.9, 0.999] weight_decay: 0.01 scheduler: @@ -55,33 +84,29 @@ optimizer: min_lr_ratio: 0.1 data: - shuffle: false # PUCT handles selection order - # No static dataset — DiscoverDataset generates prompts dynamically + shuffle: false env: erdos_discovery: - resource_server_url: "http://localhost:8080" + resource_server_url: "inline" num_initial_states: 16 num_groups_per_step: 8 sandbox_timeout: 600 request_timeout: 660 + should_use_nemo_gym: false cluster: gpus_per_node: 8 - num_nodes: 2 # 2 training nodes + num_nodes: 8 -generation: - backend: vllm - colocated: false # Inference on separate nodes - temperature: 1.0 - top_p: 1.0 - max_new_tokens: 16384 # Long context for code generation +logger: + log_dir: "results/erdos-120b" + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false checkpointing: enabled: true - checkpoint_dir: "results/ttt-discover-erdos" - save_frequency: 5 # Every 5 steps - -wandb: - enabled: true - project: "ttt-discover-erdos" + checkpoint_dir: "results/erdos-120b" + save_period: 5 diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh new file mode 100755 index 0000000000..ded2091f78 --- /dev/null +++ b/launch_erdos_120b.sh @@ -0,0 +1,101 @@ +#!/bin/bash +# TTT-Discover Erdős — Nemotron-3-Super-120B on 8 nodes +# 2 nodes inference (vLLM TP=8), 6 nodes training (Megatron TP=4 EP=8) +# Based on Dakota's working run_super_grpo.sh +set -euo pipefail +cd /home/mormio/RL + +CONTAINER="/home/shared/containers/nemo-rl-super-v3.sqsh" +MODEL_PATH="/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" +EXP="results/erdos-120b-$(date +%Y%m%d_%H%M)" +mkdir -p "$EXP" + +MOUNTS="$PWD:$PWD,$MODEL_PATH:$MODEL_PATH,$HOME/.cache:$HOME/.cache" + +COMMAND=" +cd /opt/nemo-rl && \ +export NCCL_BUFFSIZE=33554432 && \ +export CUDA_DEVICE_ORDER=PCI_BUS_ID && \ +export NCCL_IB_AR_THRESHOLD=0 && \ +export NCCL_IB_PCI_RELAXED_ORDERING=1 && \ +export NCCL_IB_QPS_PER_CONNECTION=2 && \ +export NCCL_IB_SPLIT_DATA_ON_QPS=0 && \ +export NCCL_IGNORE_CPU_AFFINITY=1 && \ +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 && \ +export NCCL_SOCKET_IFNAME=bond0 && \ +export UCX_NET_DEVICES=bond0 && \ +export HF_HUB_ENABLE_HF_TRANSFER=0 && \ +export TORCH_CUDA_ARCH_LIST='9.0 10.0' && \ +export NRL_IGNORE_VERSION_MISMATCH=1 && \ + +# Copy our custom files into the container's /opt/nemo-rl +SRC=/home/mormio/RL +cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ +cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ +cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ +cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ +cp \$SRC/examples/configs/grpo_erdos_discover.yaml /opt/nemo-rl/examples/configs/ + +# Patch grpo.py to register entropic estimator +python -c \" +path = '/opt/nemo-rl/nemo_rl/algorithms/grpo.py' +with open(path) as f: + content = f.read() +if 'entropic_adaptive_beta' not in content: + old = ' else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator' + new = ''' elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") + else: + raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\") + + return adv_estimator''' + content = content.replace(old, new) + with open(path, 'w') as f: + f.write(content) + print('Patched grpo.py') +\" && \ + +# Patch utils.py to register erdos_discovery env +python -c \" +path = '/opt/nemo-rl/nemo_rl/environments/utils.py' +with open(path) as f: + content = f.read() +if 'erdos_discovery' not in content: + content = content.replace( + '\\\"nemo_gym\\\": {', + '\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {' + ) + with open(path, 'w') as f: + f.write(content) + print('Patched utils.py') +\" && \ + +python examples/run_discover.py \ + --config examples/configs/grpo_erdos_discover.yaml +" + +echo "Submitting Erdős TTT-Discover 120B..." +echo " Container: $CONTAINER" +echo " Model: $MODEL_PATH" +echo " Nodes: 8 (2 inference + 6 training)" +echo " Exp: $EXP" + +COMMAND="$COMMAND" \ +CONTAINER="$CONTAINER" \ +MOUNTS="$MOUNTS" \ +GPUS_PER_NODE=8 \ +sbatch \ + --nodes=8 --partition=batch --exclusive \ + --job-name=erdos-120b --time=12:00:00 \ + --output="$EXP/slurm-%j.out" \ + --error="$EXP/slurm-%j.err" \ + --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ + ray.sub + +echo "Logs: $EXP/" From 2dd29f26af5d0b8d335db9b552a0ed2aede16c25 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 15:30:01 +0000 Subject: [PATCH 14/48] fix: inherit from grpo_superv3.yaml for correct Megatron+NemotronH config --- examples/configs/grpo_erdos_discover.yaml | 69 ++- examples/configs/grpo_superv3.yaml | 489 ++++++++++++++++++++++ 2 files changed, 513 insertions(+), 45 deletions(-) create mode 100644 examples/configs/grpo_superv3.yaml diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index f005822e06..1846e598f4 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -1,18 +1,16 @@ # TTT-Discover Erdős — Nemotron-3-Super-120B-A12B -# Based on Dakota's working grpo_superv3.yaml adapted for Erdős TTT-Discover. -# 8 nodes: 2 inference (vLLM TP=8) + 6 training (Megatron TP=4 EP=8) -defaults: "grpo_math_1B.yaml" +# Inherits from grpo_superv3.yaml (the working Nemotron Super config) +defaults: "grpo_superv3.yaml" grpo: - num_prompts_per_step: 8 # 8 PUCT-selected states per step - num_generations_per_prompt: 64 # 64 rollouts per state (paper default) - max_num_epochs: 1 - max_num_steps: 50 # 50 training steps (paper default) - max_rollout_turns: 1 # Single-turn code generation + num_prompts_per_step: 8 + num_generations_per_prompt: 64 + max_num_steps: 50 + max_rollout_turns: 1 remove_constant_reward_groups: true adv_estimator: name: entropic_adaptive_beta - gamma: 0.6931471805599453 # ln(2) + gamma: 0.6931471805599453 loss_fn: kl_penalty_coef: 0.1 @@ -25,17 +23,27 @@ policy: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null max_total_sequence_length: 16384 - train_global_batch_size: 512 # 8 groups × 64 rollouts + train_global_batch_size: 512 train_micro_batch_size: 1 logprob_batch_size: 1 - # Megatron backend for 120B MoE + generation: + colocated: + enabled: false + resources: + num_nodes: 2 + gpus_per_node: 8 + max_new_tokens: 16384 + vllm_cfg: + tensor_parallel_size: 8 + gpu_memory_utilization: 0.85 + max_model_len: 16384 + megatron_cfg: - enabled: true tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 context_parallel_size: 1 - expert_model_parallel_size: 8 # 128 experts / 8 EP = 16 experts per rank + expert_model_parallel_size: 8 sequence_parallel: true activation_checkpointing: true empty_unused_memory_level: 2 @@ -44,47 +52,18 @@ policy: optimizer_cpu_offload: true optimizer_offload_fraction: 1.0 - dtensor_cfg: + dynamic_batching: enabled: false lora_cfg: - enabled: false # No LoRA for now (Megatron backend) - - generation: - backend: "vllm" - colocated: - enabled: false - resources: - num_nodes: 2 # 2 nodes for vLLM inference - gpus_per_node: 8 - max_new_tokens: 16384 - temperature: 1.0 - top_p: 1.0 - stop_token_ids: null - stop_strings: null - vllm_cfg: - async_engine: false - tensor_parallel_size: 8 # Full node TP for inference - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.85 - max_model_len: 16384 - - dynamic_batching: enabled: false optimizer: - name: adamw - lr: 4.0e-5 # Paper default - betas: [0.9, 0.999] - weight_decay: 0.01 - scheduler: - name: cosine - warmup_steps: 2 - min_lr_ratio: 0.1 + lr: 4.0e-5 data: shuffle: false + max_input_seq_length: 16384 env: erdos_discovery: diff --git a/examples/configs/grpo_superv3.yaml b/examples/configs/grpo_superv3.yaml new file mode 100644 index 0000000000..5cc06be7dd --- /dev/null +++ b/examples/configs/grpo_superv3.yaml @@ -0,0 +1,489 @@ +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val:total_reward/mean" + higher_is_better: true + keep_top_k: 1000000 + save_period: 10 + checkpoint_must_save_by: "00:03:30:00" + model_save_format: "safetensors" + save_consolidated: false + +grpo: + num_prompts_per_step: 128 + num_generations_per_prompt: 16 + num_val_generations_per_prompt: 2 + max_rollout_turns: 1 + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + # Clipping bounds for normalized advantages to prevent extreme values from small std + # Set to null to disable clipping (default), or e.g. -100/100 to clip + advantage_clip_low: null + advantage_clip_high: null + val_period: 5 + val_at_start: false + val_at_end: false + overlong_filtering: false + max_val_samples: null + val_batch_size: 256 + seed: 42 + + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + + penalize_invalid_tool_call: true + invalid_tool_call_advantage: -5.0 + penalize_malformed_thinking: true + malformed_thinking_advantage: -5.0 + + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + stop_properly_penalty_coef: null + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + in_flight_weight_updates: false + recompute_kv_cache_after_weight_updates: false + + seq_logprob_error_threshold: 2 + +loss_fn: + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: "k3" + kl_input_clamp_value: null + kl_output_clamp_value: null + + ratio_clip_min: 0.2 + ratio_clip_max: 0.28 + ratio_clip_c: null + use_on_policy_kl_approximation: true + use_importance_sampling_correction: true + truncated_importance_sampling_ratio: null + truncated_importance_sampling_ratio_min: null + truncated_importance_sampling_type: tis + sequence_level_importance_ratios: false + token_level_loss: true + force_on_policy_ratio: false + use_kl_in_reward: false + +policy: + model_name: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/checkpoints/nano-v3-sft-64gbs-nickel-capybara-5e-5-constant-wd-0-load-bal-1e-4-lcx3-pretool-base-temp1-iter-0013600-hf" + tokenizer: + name: ${policy.model_name} + chat_template_kwargs: null + hf_config_overrides: {} + train_global_batch_size: 2048 + train_micro_batch_size: 1 + generation_batch_size: 64 + logprob_batch_size: 1 + max_total_sequence_length: 16384 + precision: "bfloat16" + logprob_chunk_size: 2048 + offload_optimizer_for_logprob: false + + dtensor_cfg: + _v2: true + enabled: false + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + megatron_cfg: + enabled: true + empty_unused_memory_level: 1 + activation_checkpointing: true + tensor_model_parallel_size: 2 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 8 + pipeline_model_parallel_size: 2 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 4 + pipeline_dtype: ${policy.precision} + sequence_parallel: true + freeze_moe_router: true + moe_router_dtype: "fp32" + moe_router_load_balancing_type: "none" + moe_router_bias_update_rate: 1.0e-3 + moe_permute_fusion: true + moe_enable_deepep: false + moe_token_dispatcher_type: "alltoall" + moe_aux_loss_coeff: 0.0 + moe_router_enable_expert_bias: true + moe_shared_expert_overlap: false + apply_rope_fusion: True + bias_activation_fusion: False + defer_fp32_logits: True + moe_per_layer_logging: True + + mtp_loss_scaling_factor: 0.0 + mtp_use_repeated_layer: true + mtp_num_layers: 0 + mtp_detach_heads: true + + optimizer: + optimizer: "adam" + lr: 3.0e-6 + min_lr: 3.0e-6 + weight_decay: 0.0 + bf16: true + fp16: false + params_dtype: "float32" + + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + sgd_momentum: 0.9 + + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + lr_warmup_iters: 10 + lr_warmup_init: 3e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + + env_vars: null + + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} + max_grad_norm: 1.0 + + optimizer: null + scheduler: null + + generation: + port_range_low: 11001 + port_range_high: 15000 + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: true + precision: ${policy.precision} + kv_cache_dtype: "auto" + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.5 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + expose_http_server: true + http_server_serving_chat_kwargs: + enable_auto_tools: true + tool_parser: qwen3_coder + reasoning_parser: nano_v3 + reasoning_parser_plugin: nemo_rl/utils/nano_v3_reasoning_parser.py + + + vllm_kwargs: + mamba_ssm_cache_dtype: "float32" +# compilation_config: +# mode: 0 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + +data: + max_input_seq_length: null + shuffle: false + num_workers: 1 + train: + data_path: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/data/nano-v3-posttraining-data/curriculum_v7_acrid-teal_main_rename.train.jsonl" + validation: + data_path: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/data/nano-v3-posttraining-data/curriculum_v7_acrid-teal_main_rename.val.jsonl" + default: + dataset_name: NemoGymDataset + env_name: "nemo_gym" + prompt_file: null + system_prompt_file: null + processor: "nemo_gym_data_processor" + +env: + should_use_nemo_gym: true + use_genrm_compare: true + genrm_agent_names: + - "genrm_simple_agent" + - "genrm_simple_agent_reasoning_off" + genrm_compare_server_name: "genrm_compare" + nemo_gym: + num_gpu_nodes: 4 + port_range_low: 15001 + port_range_high: 20000 + invalid_tool_call_patterns: + - "" + - "" + - "" + - "" + thinking_tags: + - "" + - "" + config_paths: + - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - resources_servers/math_with_judge/configs/math_with_judge.yaml + - resources_servers/code_gen/configs/code_gen.yaml + - resources_servers/workplace_assistant/configs/workplace_assistant.yaml + - resources_servers/mcqa/configs/mcqa.yaml + - resources_servers/instruction_following/configs/instruction_following.yaml + - resources_servers/structured_outputs/configs/structured_outputs_json.yaml + - resources_servers/equivalence_llm_judge/configs/lc_judge.yaml + - resources_servers/calendar/configs/calendar.yaml + - resources_servers/genrm_compare/configs/genrm_compare.yaml + - resources_servers/equivalence_llm_judge/configs/nl2bash-equivalency.yaml + - resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml + - resources_servers/single_step_tool_use_with_argument_comparison/configs/single_step_tool_use_with_argument_comparison.yaml + - resources_servers/reasoning_gym/configs/reasoning_gym.yaml + - resources_servers/terminal_pivot/configs/terminal_pivot.yaml + - resources_servers/ns_tools/configs/ns_tools.yaml + - resources_servers/math_formal_lean/configs/math_formal_lean_multi_turn.yaml + - resources_servers/swerl_gen/configs/swerl_gen.yaml + - resources_servers/jailbreak_detection/configs/jailbreak_detection_nemotron_combined_reward_tp8.yaml + - resources_servers/over_refusal_detection/configs/over_refusal_detection_nemotron_tp8.yaml + - resources_servers/multichallenge/configs/multichallenge.yaml + - resources_servers/inverse_if/configs/inverse_if.yaml + - resources_servers/single_step_tool_use_with_argument_comparison/configs/search_pivot_single_step_tool_use_with_argument_comparison.yaml + - resources_servers/single_step_tool_use_with_argument_comparison/configs/toolcall_schema_single_step_tool_use_with_argument_comparison.yaml + + jailbreak_detection: + resources_servers: + jailbreak_detection: + judge_model_server: + type: responses_api_models + name: safety_judge_model + + safety_judge_model: + responses_api_models: + vllm_model: + entrypoint: app.py + base_url: http://127.0.0.1:8001/v1 + api_key: dummy_key + model: /scratch/fsw/portfolios/llmservice/users/makeshn/super_v3/model_checkpoints/Nemotron-Content-Safety-Reasoning-4B + return_token_id_information: false + uses_reasoning_parser: false + spinup_server: true + router_dp_size: 8 + server_args: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 96000 + model_loader_extra_config: + enable_multithread_load: true + num_threads: 2 + server_env: + VLLM_ATTENTION_BACKEND: TRITON_ATTN + + terminal_pivot_simple_agent: + responses_api_agents: + simple_agent: + model_server: + name: policy_model + + nl2bash_judge_model: + responses_api_models: + vllm_model: + entrypoint: app.py + base_url: http://127.0.0.1:10000/v1 + api_key: dummy_key + model: "/scratch/fsw/portfolios/llmservice/users/jiaqiz/models/Qwen3-235B-A22B-Instruct-2507-FP8" + return_token_id_information: False + uses_reasoning_parser: False + spinup_server: True + router_dp_size: 2 + server_args: + tensor_parallel_size: 8 + data_parallel_size: 1 + enable_expert_parallel: True + enable_auto_tool_choice: true + tool_call_parser: hermes + gpu_memory_utilization: 0.85 + max_model_len: 131072 + model_loader_extra_config: + enable_multithread_load: true + num_threads: 112 + + inverse_if: + resources_servers: + inverse_if: + judge_model_server: + type: responses_api_models + name: nl2bash_judge_model + + multichallenge: + resources_servers: + multichallenge: + judge_model_server: + type: responses_api_models + name: nl2bash_judge_model + judge_responses_create_params: + max_output_tokens: 8192 + + equivalence_llm_judge: + resources_servers: + equivalence_llm_judge: + judge_model_server: + name: nl2bash_judge_model + judge_responses_create_params: + max_output_tokens: 8192 + + genrm_compare: + resources_servers: + genrm_compare: + # Points to the GenRM model server defined above + genrm_model_server: + type: responses_api_models + name: genrm_model + # GenRM request parameters + genrm_responses_create_params: + max_output_tokens: 16384 + temperature: 0.6 + top_p: 0.95 + # Comparison settings + comparison_strategy: "circular" + num_judges_per_comparison: 1 + use_principle: true + default_principle: "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt. Begin your evaluation by generating your own answer to the prompt. You must provide your answer before judging any answers. When evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information. Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive. Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt." + aggregator_method: "simple_tiebreaker" + reasoning_bonus: 0.5 + answer_bonus: 0.5 + top_percentile: 0.2 + group_reasoning_length_penalty_coeff: 0.1 + group_answer_length_penalty_coeff: 0.1 + group_style_penalty_coeff: 0.1 + default_score: 3.0 + default_ranking: 3.5 + + genrm_model: + responses_api_models: + vllm_model: + entrypoint: app.py + base_url: http://127.0.0.1:8000/v1 + api_key: dummy_key + model: "/lustre/fsw/portfolios/llmservice/users/jiaqiz/models/qwen235b_principle_comparison_genrm_step1230" + uses_reasoning_parser: True + return_token_id_information: False + spinup_server: True + router_dp_size: 4 + server_args: + tensor_parallel_size: 8 + reasoning_parser: deepseek_r1 + gpu_memory_utilization: 0.85 + max_model_len: 60000 + model_loader_extra_config: + enable_multithread_load: true + num_threads: 112 + + + lc_judge: + resources_servers: + equivalence_llm_judge: + judge_model_server: + name: nl2bash_judge_model + judge_responses_create_params: + max_output_tokens: 8192 + + math_with_judge: + resources_servers: + math_with_judge: + judge_model_server: + name: nl2bash_judge_model + judge_responses_create_params: + max_output_tokens: 8192 + should_use_judge: true + code_gen: + resources_servers: + code_gen: + num_processes: 1024 + unit_test_timeout_secs: 10 + debug: false + +logger: + log_dir: "logs" + num_val_samples_to_print: 0 + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: true + swanlab_enabled: false + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 1 + +# uncomment to enable effort level training +effort_levels: + low_string: "{reasoning effort: low}" + low_weight: 0.1 + low_penalty: 1 + low_ub: 3000 From 2cbb9bec844383ad35a9ae64f8eebd468b6f21da Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 15:43:22 +0000 Subject: [PATCH 15/48] fix: set async_engine false for 120B to avoid engine core crash --- examples/configs/grpo_erdos_discover.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 1846e598f4..96ee2e8c1e 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -35,6 +35,7 @@ policy: gpus_per_node: 8 max_new_tokens: 16384 vllm_cfg: + async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 max_model_len: 16384 From d265e721323652c6812fe74888a9f918836b01cd Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 15:56:23 +0000 Subject: [PATCH 16/48] =?UTF-8?q?fix:=20batch=20size=20504=20(8=C3=9763)?= =?UTF-8?q?=20divisible=20by=20DP=3D12?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/configs/grpo_erdos_discover.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 96ee2e8c1e..308294f5c2 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -4,7 +4,7 @@ defaults: "grpo_superv3.yaml" grpo: num_prompts_per_step: 8 - num_generations_per_prompt: 64 + num_generations_per_prompt: 63 max_num_steps: 50 max_rollout_turns: 1 remove_constant_reward_groups: true @@ -23,7 +23,7 @@ policy: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null max_total_sequence_length: 16384 - train_global_batch_size: 512 + train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 From dcb07dc80f9c302dd131f791bcb8515df1dc9df5 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 16:09:04 +0000 Subject: [PATCH 17/48] fix: version-agnostic setup() unpacking for super-v3 container compat --- examples/run_discover.py | 79 +++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/examples/run_discover.py b/examples/run_discover.py index 12e0d1dfe6..91006dbde0 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -324,35 +324,56 @@ def main(): setup_discover_data(config, tokenizer) ) - # Setup policy, generation, cluster, dataloader, etc. - ( - policy, - policy_generation, - clusters, - dataloader, - val_dataloader, - loss_fn, - nemo_logger, - checkpointer, - grpo_state, - master_config, - ) = setup(config, tokenizer, train_dataset, val_dataset) - - # Run GRPO training - grpo_train( - policy, - policy_generation, - dataloader, - val_dataloader, - tokenizer, - loss_fn, - task_to_env, - val_task_to_env, - nemo_logger, - checkpointer, - grpo_state, - master_config, - ) + # Setup returns vary across container versions — unpack dynamically + setup_result = setup(config, tokenizer, train_dataset, val_dataset) + + # Inspect the grpo_train signature to know what to pass + import inspect + train_sig = inspect.signature(grpo_train) + train_params = list(train_sig.parameters.keys()) + print(f" setup() returned {len(setup_result)} values") + print(f" grpo_train() expects {len(train_params)} params: {train_params[:5]}...") + + # The standard pattern: setup returns everything grpo_train needs, + # except task_to_env and val_task_to_env which we provide. + # Detect where to inject them based on parameter names. + setup_list = list(setup_result) + + # Build kwargs for grpo_train by matching setup outputs + our envs + # Common signatures: + # v0.5.0: setup returns (policy, gen, dl, val_dl, tokenizer, loss, env, val_env, logger, ckpt, state, config) + # super-v3: may return more + # Strategy: pass setup outputs positionally, but swap in our envs + if len(setup_list) == 12: + # v0.5.0 style: already includes env placeholders at positions 6,7 + setup_list[6] = task_to_env + setup_list[7] = val_task_to_env + grpo_train(*setup_list) + elif len(setup_list) == 10: + # Older style without envs + policy, policy_generation, dataloader, val_dataloader, tokenizer_out, loss_fn, nemo_logger, checkpointer, grpo_state, master_config = setup_list + grpo_train( + policy, policy_generation, dataloader, val_dataloader, + tokenizer_out, loss_fn, task_to_env, val_task_to_env, + nemo_logger, checkpointer, grpo_state, master_config, + ) + else: + # Unknown format — try passing everything with envs injected + # Find the positions of env-like params in grpo_train signature + env_idx = None + for i, p in enumerate(train_params): + if 'task_to_env' in p and 'val' not in p: + env_idx = i + break + if env_idx is not None: + # Insert our envs at the right position + args = list(setup_list) + args.insert(env_idx, task_to_env) + args.insert(env_idx + 1, val_task_to_env) + grpo_train(*args[:len(train_params)]) + else: + print(f"WARNING: Could not determine grpo_train signature, trying positional") + grpo_train(*setup_list) if __name__ == "__main__": From 081d5b5a4d4c53ab82f36cd9a07048c52f19483b Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 16:23:41 +0000 Subject: [PATCH 18/48] =?UTF-8?q?fix:=20explicit=20setup=E2=86=92grpo=5Ftr?= =?UTF-8?q?ain=20wiring=20for=20super-v3=20container=20(11=20return=20valu?= =?UTF-8?q?es)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/run_discover.py | 81 ++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/examples/run_discover.py b/examples/run_discover.py index 91006dbde0..797b8863a7 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -324,56 +324,47 @@ def main(): setup_discover_data(config, tokenizer) ) - # Setup returns vary across container versions — unpack dynamically + # Setup and grpo_train have different signatures across container versions. + # super-v3 setup() returns 11 values: + # policy, policy_gen, nemo_gym, clusters, dataloader, val_dataloader, + # loss_fn, logger, checkpointer, grpo_state, master_config + # grpo_train() expects 12 params: + # policy, policy_gen, dataloader, val_dataloader, tokenizer, + # loss_fn, task_to_env, val_task_to_env, logger, checkpointer, + # grpo_state, master_config setup_result = setup(config, tokenizer, train_dataset, val_dataset) - - # Inspect the grpo_train signature to know what to pass - import inspect - train_sig = inspect.signature(grpo_train) - train_params = list(train_sig.parameters.keys()) - print(f" setup() returned {len(setup_result)} values") - print(f" grpo_train() expects {len(train_params)} params: {train_params[:5]}...") - - # The standard pattern: setup returns everything grpo_train needs, - # except task_to_env and val_task_to_env which we provide. - # Detect where to inject them based on parameter names. setup_list = list(setup_result) - - # Build kwargs for grpo_train by matching setup outputs + our envs - # Common signatures: - # v0.5.0: setup returns (policy, gen, dl, val_dl, tokenizer, loss, env, val_env, logger, ckpt, state, config) - # super-v3: may return more - # Strategy: pass setup outputs positionally, but swap in our envs - if len(setup_list) == 12: - # v0.5.0 style: already includes env placeholders at positions 6,7 - setup_list[6] = task_to_env - setup_list[7] = val_task_to_env - grpo_train(*setup_list) - elif len(setup_list) == 10: - # Older style without envs - policy, policy_generation, dataloader, val_dataloader, tokenizer_out, loss_fn, nemo_logger, checkpointer, grpo_state, master_config = setup_list + n = len(setup_list) + print(f" setup() returned {n} values") + + if n == 11: + # super-v3 container + (policy, policy_generation, _nemo_gym, _clusters, + dataloader, val_dataloader, loss_fn, + nemo_logger, checkpointer, grpo_state, master_config) = setup_list + grpo_train( + policy, policy_generation, + dataloader, val_dataloader, + tokenizer, loss_fn, + task_to_env, val_task_to_env, + nemo_logger, checkpointer, + grpo_state, master_config, + ) + elif n == 10: + # v0.5.0 container (no nemo_gym) + (policy, policy_generation, dataloader, val_dataloader, + loss_fn, nemo_logger, checkpointer, grpo_state, + master_config, _extra) = setup_list grpo_train( - policy, policy_generation, dataloader, val_dataloader, - tokenizer_out, loss_fn, task_to_env, val_task_to_env, - nemo_logger, checkpointer, grpo_state, master_config, + policy, policy_generation, + dataloader, val_dataloader, + tokenizer, loss_fn, + task_to_env, val_task_to_env, + nemo_logger, checkpointer, + grpo_state, master_config, ) else: - # Unknown format — try passing everything with envs injected - # Find the positions of env-like params in grpo_train signature - env_idx = None - for i, p in enumerate(train_params): - if 'task_to_env' in p and 'val' not in p: - env_idx = i - break - if env_idx is not None: - # Insert our envs at the right position - args = list(setup_list) - args.insert(env_idx, task_to_env) - args.insert(env_idx + 1, val_task_to_env) - grpo_train(*args[:len(train_params)]) - else: - print(f"WARNING: Could not determine grpo_train signature, trying positional") - grpo_train(*setup_list) + raise RuntimeError(f"Unexpected setup() return count: {n}. Check container version.") if __name__ == "__main__": From def908951b03cb90acb3fb43ae6e60cb54e68bfe Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 16:47:04 +0000 Subject: [PATCH 19/48] fix: reduce seq length to 4096 to avoid OOM during training --- examples/configs/grpo_erdos_discover.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 308294f5c2..90f1a83c03 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -22,7 +22,7 @@ policy: tokenizer: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null - max_total_sequence_length: 16384 + max_total_sequence_length: 4096 train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 @@ -33,12 +33,12 @@ policy: resources: num_nodes: 2 gpus_per_node: 8 - max_new_tokens: 16384 + max_new_tokens: 3072 vllm_cfg: async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 - max_model_len: 16384 + max_model_len: 4096 megatron_cfg: tensor_model_parallel_size: 4 @@ -64,7 +64,7 @@ optimizer: data: shuffle: false - max_input_seq_length: 16384 + max_input_seq_length: 4096 env: erdos_discovery: From c3b0971372004d817c6d1a5f10525ed3b84ee569 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 18:33:09 +0000 Subject: [PATCH 20/48] feat: port exact reference TTT-Discover env + prompts Rewrites erdos_discovery_environment.py and run_discover.py to match the reference implementation at github.com/test-time-training/discover: - C5 = max(np.correlate(h, 1-h, mode="full") * dx) formulation - Code must define run(seed, budget_s) returning (h_values, c5_bound, n_points) - scipy/cvxpy allowed in sandbox - State context shows parent code + improvement tracking (State.to_prompt) - Full ErdosMinOverlapEnv.get_question() prompt with problem description - reward = 1 / (1e-8 + c5_bound) - Initial states: random n in [40,100], perturbed h=0.5 - Removes inline/HTTP mode split (always computes directly) - DiscoverDataset generates diverse initial states each step --- examples/configs/grpo_erdos_discover.yaml | 5 +- .../configs/grpo_erdos_discover_debug.yaml | 4 +- examples/run_discover.py | 196 ++--- .../erdos_discovery_environment.py | 719 ++++++++++-------- test_gptoss_vllm.sh | 42 + 5 files changed, 509 insertions(+), 457 deletions(-) create mode 100755 test_gptoss_vllm.sh diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 90f1a83c03..fb58646a6b 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -68,11 +68,10 @@ data: env: erdos_discovery: - resource_server_url: "inline" num_initial_states: 16 num_groups_per_step: 8 - sandbox_timeout: 600 - request_timeout: 660 + sandbox_timeout: 1000 + should_use_nemo_gym: false cluster: diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index 528b5b9640..e6282123b2 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -64,11 +64,9 @@ data: env: erdos_discovery: - resource_server_url: "inline" num_initial_states: 8 num_groups_per_step: 4 - sandbox_timeout: 60 - request_timeout: 120 + sandbox_timeout: 120 checkpointing: enabled: false diff --git a/examples/run_discover.py b/examples/run_discover.py index 797b8863a7..256eda4c1b 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -1,28 +1,18 @@ """Run script for TTT-Discover GRPO training on the Erdős Minimum Overlap Problem. -This follows the sliding_puzzle pattern: custom IterableDataset that generates -prompts dynamically from a PUCT buffer, wired into the standard GRPO loop. +Matches the reference implementation at: + https://github.com/test-time-training/discover/blob/main/examples/erdos_min_overlap/env.py -Usage: - # Start the Gym resource server first (separate process/node): - cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" - - # Then run training: - cd ~/RL && uv run python examples/run_discover.py [--config examples/configs/grpo_erdos_discover.yaml] - -Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) +Usage (inside NeMo RL container): + python examples/run_discover.py --config examples/configs/grpo_erdos_discover.yaml """ -import itertools -import argparse import itertools import logging import os import sys from typing import Optional -import aiohttp -import asyncio import numpy as np import ray import torch @@ -34,44 +24,13 @@ from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.erdos_discovery_environment import ( ErdosDiscoveryEnvironment, + build_erdos_question, + create_initial_state, ) from nemo_rl.models.generation import configure_generation_config -from nemo_rl.utils.config import load_config logger = logging.getLogger(__name__) -# ═══════════════════════════════════════════════════════════════════ -# Problem description (same as in the Gym resource server) -# ═══════════════════════════════════════════════════════════════════ - -PROBLEM_DESCRIPTION = """\ -Erdos Minimum Overlap Problem -============================== - -Goal: Find a step function f (Python list or NumPy array) giving the -tightest possible upper bound on the Erdos minimum overlap constant c. - -Background: - For integer n, partition {1,...,2n} into equal sets A, B. - M_k = #{(a,b) : a in A, b in B, a-b=k}. - c = lim_{n->inf} min_{A,B} max_k M_k / n. - -Known bounds: 0.379005 < c < 0.380927 (Haugland 2016) -Current best upper bound: 0.380876 (2026) - -Upper Bound via Step Functions: - f : [0,1] -> [0,1] with mean(f) = 0.5 gives: - bound = 2*n*max(autocorr(f)) / sum(f)^2 - where autocorr is computed via FFT. - Smaller bound -> higher reward (reward = 1/bound). - -Constraints: 1 <= len(f) <= 1000, 0 <= f[i] <= 1, mean(f) ~ 0.5 (tol 1e-3). - -Output: Python code defining variable `f` in a ```python block. -Allowed: numpy, math, random, itertools, functools, collections. -Execution limit: 600 seconds. Target: bound < 0.380876.\ -""" - # ═══════════════════════════════════════════════════════════════════ # Datum generation @@ -80,38 +39,25 @@ def generate_discover_datum( tokenizer, - state_info: dict, + state: dict, idx: int, task_name: str = "erdos_discovery", ) -> DatumSpec: - """Create a DatumSpec from a PUCT-selected state. - - Args: - tokenizer: HuggingFace tokenizer. - state_info: Dict from /select_state with keys: - state, context, reward, system_prompt, user_prompt. - idx: Datum index. - task_name: Task name for env routing. + """Create a DatumSpec from a state dict. - Returns: - DatumSpec ready for the GRPO training loop. + The prompt is built using the reference TTT-Discover get_question() format. """ - system_prompt = state_info.get("system_prompt", PROBLEM_DESCRIPTION) - user_prompt = state_info["user_prompt"] + user_prompt = build_erdos_question(state) messages: LLMMessageLogType = [ - {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - # Tokenize the prompt prompt_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) - prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long) - # Attach token_ids to messages for NeMo RL's message_log format for msg in messages: msg_text = tokenizer.apply_chat_template( [msg], tokenize=False, add_generation_prompt=False @@ -123,9 +69,12 @@ def generate_discover_datum( message_log=messages, length=len(prompt_ids), extra_env_info={ - "parent_state": state_info.get("state"), - "context": state_info.get("context"), - "reward": state_info.get("reward", 0.0), + "construction": state.get("construction"), + "c5_bound": state.get("c5_bound"), + "n_points": state.get("n_points"), + "code": state.get("code", ""), + "parent_c5": state.get("parent_c5"), + "observation": state.get("observation", ""), }, loss_multiplier=1.0, idx=idx, @@ -134,77 +83,45 @@ def generate_discover_datum( # ═══════════════════════════════════════════════════════════════════ -# Dynamic dataset backed by PUCT buffer +# Dataset backed by initial states (PUCT selection comes later) # ═══════════════════════════════════════════════════════════════════ class DiscoverDataset(IterableDataset): - """Iterable dataset that fetches prompts from the PUCT buffer each step. + """Dataset that generates prompts from Erdős initial states. - Each iteration fetches `num_groups_per_step` states from the Gym resource - server's /select_state endpoint and yields them as DatumSpecs. + Each iteration generates diverse initial states and yields them as + DatumSpecs with the reference TTT-Discover prompt format. - The dataset loops indefinitely — the training loop controls termination - via max_num_steps in the GRPO config. + For now, initial states are random perturbations of h=0.5 (matching + the reference). Future: PUCT buffer selects states based on prior + discoveries. """ def __init__( self, tokenizer, - resource_server_url: str, - num_groups_per_step: int = 8, + num_states_per_step: int = 8, task_name: str = "erdos_discovery", - length: int = 1000, # Nominal length for dataloader + length: int = 1000, + seed: int = 42, ): self.tokenizer = tokenizer - self.resource_server_url = resource_server_url - self.num_groups_per_step = num_groups_per_step + self.num_states_per_step = num_states_per_step self.task_name = task_name self.length = length self._idx_counter = itertools.count() - - def _fetch_states_sync(self) -> list[dict]: - """Synchronously fetch states from the PUCT buffer.""" - import requests - - try: - resp = requests.post( - f"{self.resource_server_url}/select_state", - json={ - "batch_size": self.num_groups_per_step, - "num_groups": self.num_groups_per_step, - }, - timeout=30, - ) - resp.raise_for_status() - data = resp.json() - return data.get("states", []) - except Exception as e: - logger.error("Failed to fetch states from PUCT buffer: %s", e) - # Return fallback: single default prompt - return [ - { - "state": [0.5] * 50, - "context": [], - "reward": 0.5, - "system_prompt": PROBLEM_DESCRIPTION, - "user_prompt": ( - "Starting construction (bound=2.000000, 50 pieces):\n" - "[0.5000, 0.5000, ..., 0.5000]\n\n" - "Improve on this construction. Write Python code that " - "defines a better step function `f`. Think carefully." - ), - } - ] + self._rng = np.random.default_rng(seed) def __iter__(self): for _ in itertools.count(): - states = self._fetch_states_sync() - for state_info in states: + # Generate fresh initial states each step + for _ in range(self.num_states_per_step): + state = create_initial_state(self._rng) idx = next(self._idx_counter) yield generate_discover_datum( self.tokenizer, - state_info, + state, idx=idx, task_name=self.task_name, ) @@ -219,38 +136,27 @@ def __len__(self): def setup_discover_data(config: MasterConfig, tokenizer): - """Create dataset, environment, and wire them together. - - Returns: - (train_dataset, val_dataset, task_to_env, val_task_to_env) - """ + """Create dataset, environment, and wire them together.""" env_config = config.get("env", {}).get("erdos_discovery", {}) - resource_server_url = env_config.get( - "resource_server_url", "http://localhost:8080" - ) - num_groups_per_step = env_config.get("num_groups_per_step", 8) + num_states = config.get("grpo", {}).get("num_prompts_per_step", 8) task_name = "erdos_discovery" - # Create the dynamic dataset train_dataset = DiscoverDataset( tokenizer=tokenizer, - resource_server_url=resource_server_url, - num_groups_per_step=num_groups_per_step, + num_states_per_step=num_states, task_name=task_name, - length=config["grpo"]["max_num_steps"] * num_groups_per_step, + length=config.get("grpo", {}).get("max_num_steps", 50) * num_states, + seed=config.get("seed", 42), ) - # Validation dataset: same thing (could be a fixed set, but for discovery - # we just re-sample from the buffer) val_dataset = DiscoverDataset( tokenizer=tokenizer, - resource_server_url=resource_server_url, - num_groups_per_step=num_groups_per_step, + num_states_per_step=num_states, task_name=task_name, - length=num_groups_per_step, + length=num_states, + seed=config.get("seed", 42) + 1, ) - # Create the environment as a Ray actor env = ErdosDiscoveryEnvironment.options( num_gpus=0, max_restarts=-1, @@ -269,11 +175,10 @@ def setup_discover_data(config: MasterConfig, tokenizer): def main(): - import os from omegaconf import OmegaConf from nemo_rl.utils.config import load_config - # Register custom resolvers needed by the base config + # Register custom resolvers if not OmegaConf.has_resolver("mul"): OmegaConf.register_new_resolver("mul", lambda a, b: a * b) if not OmegaConf.has_resolver("div"): @@ -283,7 +188,7 @@ def main(): from nemo_rl.utils.config import register_omegaconf_resolvers register_omegaconf_resolvers() except ImportError: - pass # v0.5.0 container doesn't have this + pass # Parse --config argument config_path = None @@ -297,13 +202,13 @@ def main(): if config_path is None: config_path = os.path.join( - os.path.dirname(__file__), "configs", "grpo_erdos_discover_debug.yaml" + os.path.dirname(__file__), "configs", "grpo_erdos_discover.yaml" ) print(f"Loading config from: {config_path}") config = load_config(config_path) - # Resolve OmegaConf interpolations (e.g. ${policy.model_name}) + # Resolve OmegaConf interpolations oc = OmegaConf.create(config) config = OmegaConf.to_container(oc, resolve=True) @@ -324,14 +229,7 @@ def main(): setup_discover_data(config, tokenizer) ) - # Setup and grpo_train have different signatures across container versions. - # super-v3 setup() returns 11 values: - # policy, policy_gen, nemo_gym, clusters, dataloader, val_dataloader, - # loss_fn, logger, checkpointer, grpo_state, master_config - # grpo_train() expects 12 params: - # policy, policy_gen, dataloader, val_dataloader, tokenizer, - # loss_fn, task_to_env, val_task_to_env, logger, checkpointer, - # grpo_state, master_config + # Setup returns vary across container versions setup_result = setup(config, tokenizer, train_dataset, val_dataset) setup_list = list(setup_result) n = len(setup_list) @@ -351,7 +249,7 @@ def main(): grpo_state, master_config, ) elif n == 10: - # v0.5.0 container (no nemo_gym) + # v0.5.0 container (policy, policy_generation, dataloader, val_dataloader, loss_fn, nemo_logger, checkpointer, grpo_state, master_config, _extra) = setup_list @@ -364,7 +262,7 @@ def main(): grpo_state, master_config, ) else: - raise RuntimeError(f"Unexpected setup() return count: {n}. Check container version.") + raise RuntimeError(f"Unexpected setup() return count: {n}") if __name__ == "__main__": diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index e192c3619b..be8c14b9d8 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -1,268 +1,454 @@ -"""Erdős Discovery Environment for NeMo RL. +"""Erdős Minimum Overlap Discovery Environment — matches reference TTT-Discover implementation. -Implements EnvironmentInterface for TTT-Discover with the Erdős Minimum -Overlap Problem. Calls the NeMo Gym resource server for code execution -and reward computation. +Reference: https://github.com/test-time-training/discover/blob/main/examples/erdos_min_overlap/env.py +Paper: "Learning to Discover at Test Time" (arXiv:2601.16175) -Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) - -The environment: - 1. Receives LLM-generated code from the GRPO rollout - 2. Sends it to the Erdős Gym resource server for sandboxed execution + scoring - 3. Returns reward = 1/bound (or 0 on failure) - 4. Tracks best constructions and buffer statistics via metrics +Key differences from our v1: +- Uses C5 = max(np.correlate(h, 1-h, mode="full") * dx) formulation (h over [0,2]) +- Code must define run(seed=42, budget_s=1000) returning (h_values, c5_bound, n_points) +- Allows scipy, cvxpy in addition to numpy/math +- State context shows parent code + improvement direction +- reward = 1 / (1e-8 + c5_bound) """ +import asyncio import logging import math +import re +import signal +import time from typing import Any, Optional -import aiohttp +import numpy as np import ray import torch -from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn +from nemo_rl.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) +from nemo_rl.data.interfaces import LLMMessageLogType logger = logging.getLogger(__name__) +# ═══════════════════════════════════════════════════════════════════ +# Reference C5 verification (from TTT-Discover env.py) +# ═══════════════════════════════════════════════════════════════════ + +def verify_c5_solution(h_values, c5_achieved, n_points): + """Verify a C5 solution — exact copy from reference implementation.""" + if not isinstance(h_values, np.ndarray): + try: + h_values = np.array(h_values, dtype=np.float64) + except (ValueError, TypeError) as e: + raise ValueError(f"Cannot convert h_values to numpy array: {e}") + + if len(h_values.shape) != 1: + raise ValueError(f"h_values must be 1D array, got shape {h_values.shape}") + + if h_values.shape[0] != n_points: + raise ValueError(f"Expected h shape ({n_points},), got {h_values.shape}") + + if not np.all(np.isfinite(h_values)): + raise ValueError("h_values contain NaN or inf values") + + if np.any(h_values < 0) or np.any(h_values > 1): + raise ValueError(f"h(x) is not in [0, 1]. Range: [{h_values.min()}, {h_values.max()}]") + + n = n_points + target_sum = n / 2.0 + current_sum = np.sum(h_values) + + if current_sum != target_sum: + h_values = h_values * (target_sum / current_sum) + if np.any(h_values < 0) or np.any(h_values > 1): + raise ValueError( + f"After normalization, h(x) is not in [0, 1]. " + f"Range: [{h_values.min()}, {h_values.max()}]" + ) + + dx = 2.0 / n_points + j_values = 1.0 - h_values + correlation = np.correlate(h_values, j_values, mode="full") * dx + computed_c5 = np.max(correlation) + + if not np.isfinite(computed_c5): + raise ValueError(f"Computed C5 is not finite: {computed_c5}") + + if not np.isclose(computed_c5, c5_achieved, atol=1e-4): + raise ValueError(f"C5 mismatch: reported {c5_achieved:.6f}, computed {computed_c5:.6f}") + + return computed_c5 + # ═══════════════════════════════════════════════════════════════════ -# Inline reward computation (no Gym server needed for debug/testing) +# Sandbox execution # ═══════════════════════════════════════════════════════════════════ -def _inline_compute_reward(response_text: str, timeout: int = 60) -> dict: - """Compute reward directly in-process. No HTTP call needed.""" - import re - import signal +_ALLOWED_MODULES = frozenset({ + "numpy", "np", "math", "cmath", "random", "scipy", "cvxpy", + "itertools", "functools", "collections", "fractions", "decimal", + "copy", "operator", "time", +}) + + +def _execute_run_function(code: str, timeout: int = 1000, n_cpus: int = 2) -> dict: + """Execute code that defines run(), call it, verify the result. + + Matches the reference SandboxRewardEvaluator flow. + """ import builtins - import math as _math - import itertools as _itertools - import functools as _functools - import collections as _collections - - import numpy as _np - from numpy.fft import rfft, irfft - - _ALLOWED_MODULES = frozenset({ - "numpy", "np", "math", "cmath", "random", - "itertools", "functools", "collections", "fractions", "decimal", - }) + _SAFE_BUILTIN_NAMES = [ "abs", "all", "any", "bool", "dict", "divmod", "enumerate", "filter", "float", "format", "int", "isinstance", "issubclass", "iter", "len", "list", "map", "max", "min", "next", "object", "print", "range", "repr", "reversed", "round", "set", "slice", - "sorted", "str", "sum", "tuple", "type", "zip", + "sorted", "str", "sum", "tuple", "type", "zip", "True", "False", + "None", "complex", "frozenset", "bytes", "bytearray", "memoryview", + "property", "staticmethod", "classmethod", "super", "hash", "id", + "input", "ord", "chr", "hex", "oct", "bin", "pow", "Exception", "ValueError", "TypeError", "KeyError", "IndexError", "StopIteration", "RuntimeError", "NotImplementedError", "OverflowError", "ZeroDivisionError", "AttributeError", + "ImportError", "FileNotFoundError", "OSError", "ArithmeticError", ] - # Extract code - code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) - blocks = code_re.findall(response_text) - code = blocks[-1].strip() if blocks else response_text.strip() + safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES + if hasattr(builtins, k)} - # Build sandbox - import random as _random - safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES if hasattr(builtins, k)} def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): - if name.split(".")[0] not in _ALLOWED_MODULES: + base = name.split(".")[0] + if base not in _ALLOWED_MODULES: raise ImportError(f"Module '{name}' not allowed") return builtins.__import__(name, globals, locals, fromlist, level) + safe_builtins["__import__"] = _safe_import + + import random as _random namespace = { "__builtins__": safe_builtins, - "np": _np, "numpy": _np, "math": _math, "random": _random, - "itertools": _itertools, "functools": _functools, "collections": _collections, + "np": np, + "numpy": np, + "math": math, + "random": _random, } + # Add evaluate_erdos_solution to namespace (reference injects this) + def _evaluate_erdos_solution(h_values, c5_bound, n_points): + verify_c5_solution(h_values, c5_bound, n_points) + return float(c5_bound) + + namespace["evaluate_erdos_solution"] = _evaluate_erdos_solution + + stdout_capture = [] + original_print = builtins.print + def capturing_print(*args, **kwargs): + import io + buf = io.StringIO() + kwargs["file"] = buf + original_print(*args, **kwargs) + stdout_capture.append(buf.getvalue()) + namespace["__builtins__"]["print"] = capturing_print + + class _Timeout(Exception): + pass + + def _handler(s, f): + raise _Timeout(f"Execution timed out after {timeout}s") + try: - class _Timeout(Exception): - pass - def _handler(s, f): - raise _Timeout() - old = signal.signal(signal.SIGALRM, _handler) + old_handler = signal.signal(signal.SIGALRM, _handler) signal.alarm(timeout) try: exec(compile(code, "", "exec"), namespace) finally: signal.alarm(0) - signal.signal(signal.SIGALRM, old) + signal.signal(signal.SIGALRM, old_handler) + + if "run" not in namespace: + return { + "reward": 0.0, "raw_score": None, + "error_msg": "No 'run' function defined", + "stdout": "".join(stdout_capture), + } + + # Call run() + signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout) + try: + result = namespace["run"](seed=42, budget_s=timeout) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + if not isinstance(result, tuple) or len(result) != 3: + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"run() must return (h_values, c5_bound, n_points), got {type(result)}", + "stdout": "".join(stdout_capture), + } + + h_values, c5_bound, n_points = result + h_values = np.asarray(h_values, dtype=np.float64) + + # Verify + computed_c5 = verify_c5_solution(h_values, c5_bound, n_points) + + if computed_c5 <= 0 or not np.isfinite(computed_c5): + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"Invalid C5: {computed_c5}", + "stdout": "".join(stdout_capture), + } + + return { + "reward": float(1.0 / (1e-8 + computed_c5)), + "raw_score": float(computed_c5), + "error_msg": "", + "result_construction": h_values.tolist(), + "n_points": int(n_points), + "stdout": "".join(stdout_capture), + } - if "f" not in namespace: - return {"reward": 0.0, "bound": None, "error_msg": "no variable f"} + except _Timeout as e: + return { + "reward": 0.0, "raw_score": None, + "error_msg": str(e), + "stdout": "".join(stdout_capture), + } + except Exception as e: + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"{type(e).__name__}: {str(e)[:300]}", + "stdout": "".join(stdout_capture), + } - f = _np.asarray(namespace["f"], dtype=float).flatten() - # Validate - if len(f) < 1 or len(f) > 1000: - return {"reward": 0.0, "bound": None, "error_msg": "bad length"} - if _np.any(~_np.isfinite(f)) or _np.any(f < 0) or _np.any(f > 1): - return {"reward": 0.0, "bound": None, "error_msg": "bad values"} - if abs(float(_np.mean(f)) - 0.5) > 1e-3: - return {"reward": 0.0, "bound": None, "error_msg": "bad mean"} +def _extract_code(response: str) -> str: + """Extract Python code from LLM response.""" + code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) + blocks = code_re.findall(response) + if blocks: + return blocks[-1].strip() + # If no code block, try the whole response + return response.strip() - # Compute bound - n = len(f) - F = rfft(f, n=2*n) - autocorr = irfft(F * _np.conj(F), n=2*n) - bound = float(2 * n * _np.max(autocorr.real) / (_np.sum(f)**2)) - if bound <= 0 or not _math.isfinite(bound): - return {"reward": 0.0, "bound": None, "error_msg": "bad bound"} - return {"reward": 1.0 / bound, "bound": bound, "error_msg": ""} +# ═══════════════════════════════════════════════════════════════════ +# Initial state generation (from reference) +# ═══════════════════════════════════════════════════════════════════ - except Exception as e: - return {"reward": 0.0, "bound": None, "error_msg": str(e)[:200]} +def create_initial_state(rng=None): + """Create a random initial state — matches reference exactly.""" + if rng is None: + rng = np.random.default_rng() + n_points = int(rng.integers(40, 100)) + construction = np.ones(n_points) * 0.5 + perturbation = rng.uniform(-0.4, 0.4, n_points) + perturbation = perturbation - np.mean(perturbation) + construction = construction + perturbation + dx = 2.0 / n_points + correlation = np.correlate(construction, 1 - construction, mode="full") * dx + c5_bound = float(np.max(correlation)) + return { + "construction": construction.tolist(), + "c5_bound": c5_bound, + "n_points": n_points, + "code": "", + "parent_c5": None, + "observation": "", + } -# Type alias matching NeMo RL's convention -LLMMessageLogType = list[dict[str, Any]] -ErdosMetadata = dict[str, Any] +# ═══════════════════════════════════════════════════════════════════ +# Prompt construction (from reference State.to_prompt + ErdosMinOverlapEnv.get_question) +# ═══════════════════════════════════════════════════════════════════ -@ray.remote(max_restarts=-1, max_task_retries=-1) -class ErdosDiscoveryEnvironment(EnvironmentInterface[ErdosMetadata]): - """Erdős Minimum Overlap Problem environment for GRPO training. +TARGET_C5 = 0.3808 - Communicates with the NeMo Gym Erdős resource server via HTTP for: - - /verify: code execution + reward computation - - /select_state: PUCT state selection for prompts - - /seed_session: buffer initialization - - /compute_entropic_advantages: LOO entropic advantages - - /update_buffer: add new discoveries to PUCT tree - Config (under env.erdos_discovery): - resource_server_url: Base URL of the Erdős Gym resource server. - seed: Random seed for PUCT buffer initialization. - num_initial_states: States to seed the buffer with. - sandbox_timeout: Code execution timeout in seconds. - """ +def state_to_prompt(state: dict) -> str: + """Build the state context portion — matches reference State.to_prompt().""" + c5 = state.get("c5_bound", None) + parent_c5 = state.get("parent_c5", None) + code = state.get("code", "") + observation = state.get("observation", "") - def __init__(self, config: dict): - self.config = config - self.resource_server_url = config.get( - "resource_server_url", "http://localhost:8080" + value_ctx = "You are iteratively optimizing C₅ bound." + + if code and code.strip(): + value_ctx += f"\nHere is the last code we ran:\n```python\n{code}\n```" + else: + value_ctx += "\nNo previous code available." + + if parent_c5 is not None and c5 is not None: + current_gap = c5 - TARGET_C5 + value_ctx += ( + f"\nHere is the C₅ bound before and after running the code above " + f"(lower is better): {parent_c5:.6f} -> {c5:.6f}" + ) + value_ctx += ( + f"\nTarget: {TARGET_C5}. Current gap: {current_gap:.6f}. " + f"Further improvements will also be generously rewarded." + ) + elif c5 is not None: + current_gap = c5 - TARGET_C5 + value_ctx += f"\nCurrent C₅ bound (lower is better): {c5:.6f}" + value_ctx += ( + f"\nTarget: {TARGET_C5}. Current gap: {current_gap:.6f}. " + f"Further improvements will also be generously rewarded." ) - self.seed = config.get("seed", None) + else: + value_ctx += f"\nTarget C₅ bound: {TARGET_C5}" + + if observation and observation.strip(): + stdout = observation.strip() + if len(stdout) > 500: + stdout = "\n\n\t\t ...(TRUNCATED)...\n" + stdout[-500:] + value_ctx += f"\n\n--- Previous Program Output ---\n{stdout}\n--- End Output ---" + + return value_ctx + + +def build_erdos_question(state: dict) -> str: + """Build the full Erdős question — matches reference ErdosMinOverlapEnv.get_question().""" + state_ctx = state_to_prompt(state) + + construction = state.get("construction", []) + n = len(construction) if construction else 0 + + construction_section = "" + if construction and n > 0: + construction_section = ( + f"\nYou may want to start your search from the current construction, " + f"which you can access through the `initial_h_values` global variable " + f"(n={n} samples).\n" + f"You are encouraged to explore solutions that use other starting points " + f"to prevent getting stuck in a local optimum.\n" + ) + + code = state.get("code", "") + if code and code.strip(): + code_section = ( + "Reason about how you could further improve this construction.\n" + "Ideally, try to do something different than the above algorithm. " + "Could be using different algorithmic ideas, adjusting your heuristics, " + "adjusting / sweeping your hyperparemeters, etc.\n" + "Unless you make a meaningful improvement, you will not be rewarded." + ) + else: + code_section = "Write code to optimize this construction." + + return f"""You are an expert in harmonic analysis, numerical optimization, and mathematical discovery. +Your task is to find an improved upper bound for the Erdős minimum overlap problem constant C₅. + +## Problem + +Find a step function h: [0, 2] → [0, 1] that **minimizes** the overlap integral: + +$$C_5 = \\max_k \\int h(x)(1 - h(x+k)) dx$$ + +**Constraints**: +1. h(x) ∈ [0, 1] for all x +2. ∫₀² h(x) dx = 1 + +**Discretization**: Represent h as n_points samples over [0, 2]. +With dx = 2.0 / n_points: +- 0 ≤ h[i] ≤ 1 for all i +- sum(h) * dx = 1 (equivalently: sum(h) == n_points / 2 exactly) + +The evaluation computes: C₅ = max(np.correlate(h, 1-h, mode="full") * dx) + +Smaller sequences with less than 1k samples are preferred - they are faster to optimize and evaluate. + +**Lower C₅ values are better** - they provide tighter upper bounds on the Erdős constant. + +## Budget & Resources +- **Time budget**: 1000s for your code to run +- **CPUs**: 2 available + +## Rules +- Define `run(seed=42, budget_s=1000, **kwargs)` that returns `(h_values, c5_bound, n_points)` +- Use scipy, numpy, cvxpy[CBC,CVXOPT,GLOP,GLPK,GUROBI,MOSEK,PDLP,SCIP,XPRESS,ECOS], math +- Make all helper functions top level, no closures or lambdas +- No filesystem or network IO +- `evaluate_erdos_solution()` and `initial_h_values` (an initial construction, if available) are pre-imported +- Your function must complete within budget_s seconds and return the best solution found + +**Lower is better**. Current record: C₅ ≤ 0.38092. Our goal is to find a construction that shows C₅ ≤ 0.38080. + +{state_ctx} +{construction_section} +{code_section} +""" + + +# ═══════════════════════════════════════════════════════════════════ +# NeMo RL Environment +# ═══════════════════════════════════════════════════════════════════ + +ErdosMetadata = dict[str, Any] + + +@ray.remote +class ErdosDiscoveryEnvironment(EnvironmentInterface): + """Erdős Minimum Overlap environment for NeMo RL GRPO. + + Matches the reference TTT-Discover implementation: + - C5 formulation with np.correlate + - run() function entrypoint + - scipy/cvxpy allowed + - State context with parent code + improvement tracking + """ + + def __init__(self, config: dict = None): + config = config or {} + self.sandbox_timeout = config.get("sandbox_timeout", 1000) self.num_initial_states = config.get("num_initial_states", 16) - self.sandbox_timeout = config.get("sandbox_timeout", 600) - self.request_timeout = config.get("request_timeout", 660) + # Tracking self.best_reward = 0.0 - self.best_bound = float("inf") + self.best_c5 = float("inf") self.total_verified = 0 self.total_valid = 0 - self._session_initialized = False - self._inline_mode = (self.resource_server_url == "inline") - if self._inline_mode: - logger.info("ErdosDiscovery: running in INLINE mode (no Gym server)") - self._session_initialized = True # No server to init - - async def _ensure_session(self): - """Initialize the PUCT buffer on the resource server if not done.""" - if self._session_initialized: - return - try: - timeout = aiohttp.ClientTimeout(total=30) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post( - f"{self.resource_server_url}/seed_session", - json={ - "num_initial_states": self.num_initial_states, - "seed": self.seed, - }, - ) as resp: - data = await resp.json() - self.best_reward = data.get("best_initial_reward", 0.0) - self.best_bound = data.get( - "best_initial_bound", float("inf") - ) - logger.info( - "ErdosDiscovery: seeded buffer with %d states, " - "best_reward=%.4f, best_bound=%.6f", - data.get("num_states", 0), - self.best_reward, - self.best_bound, - ) - self._session_initialized = True - except Exception as e: - logger.error("ErdosDiscovery: seed_session failed: %s", e) - - async def _verify_single( - self, - session: Optional[aiohttp.ClientSession], - response_text: str, - parent_state: Optional[list[float]] = None, - ) -> dict: - """Call /verify on the resource server, or compute inline.""" - if self._inline_mode: - return _inline_compute_reward( - response_text, timeout=self.sandbox_timeout - ) - # Build a minimal NeMoGymResponse-like payload - # The resource server extracts output_text from response.output_text - body = { - "responses_create_params": { - "input": [{"role": "user", "content": ""}], - }, - "response": { - "id": "verify", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": response_text}], - } - ], - "output_text": response_text, - }, - "parent_state": parent_state, - } - try: - timeout = aiohttp.ClientTimeout(total=self.request_timeout) - async with session.post( - f"{self.resource_server_url}/verify", - json=body, - timeout=timeout, - ) as resp: - return await resp.json() - except Exception as e: - logger.warning("ErdosDiscovery: verify failed: %s", e) - return {"reward": 0.0, "bound": None, "error_msg": str(e)} + + # PUCT buffer for state management + self._states = [] + self._initialize_states() + + def _initialize_states(self): + """Generate initial random states.""" + rng = np.random.default_rng(42) + for _ in range(self.num_initial_states): + self._states.append(create_initial_state(rng)) + + def get_initial_states(self, n: int = None) -> list[dict]: + """Return initial states for prompt generation.""" + if n is None: + return self._states + return self._states[:n] def step( self, message_log_batch: list[LLMMessageLogType], metadata: list[ErdosMetadata], ) -> EnvironmentReturn[ErdosMetadata]: - """Evaluate a batch of LLM responses. - - Extracts the assistant's last message from each conversation, - sends it to the resource server for code execution + scoring, - returns rewards. - """ - import asyncio - + """Evaluate a batch of LLM responses.""" try: loop = asyncio.get_running_loop() except RuntimeError: loop = None - if loop and loop.is_running(): - # Ray actors run inside an event loop — use nest_asyncio or run sync - return self._sync_step(message_log_batch, metadata) - else: - return asyncio.run( - self._async_step(message_log_batch, metadata) - ) + # Always use sync path (Ray actors run in event loops) + return self._sync_step(message_log_batch, metadata) def _sync_step( self, message_log_batch: list[LLMMessageLogType], metadata: list[ErdosMetadata], ) -> EnvironmentReturn[ErdosMetadata]: - """Synchronous step for use inside running event loops (Ray actors).""" + """Synchronous step — executes code and computes C5 reward.""" batch_size = len(message_log_batch) rewards = torch.zeros(batch_size) terminateds = torch.ones(batch_size) @@ -271,129 +457,53 @@ def _sync_step( updated_metadata = list(metadata) for i, message_log in enumerate(message_log_batch): + # Extract assistant response response_text = "" for msg in reversed(message_log): if msg.get("role") == "assistant": response_text = msg.get("content", "") break - if self._inline_mode: - result = _inline_compute_reward( - response_text, timeout=self.sandbox_timeout - ) - else: - result = {"reward": 0.0, "bound": None, "error_msg": "sync mode requires inline"} - - reward = result.get("reward", 0.0) - rewards[i] = reward - self.total_verified += 1 + # Extract code and execute + code = _extract_code(response_text) - if reward > 0: - self.total_valid += 1 - bound = result.get("bound") - if reward > self.best_reward: - self.best_reward = reward - self.best_bound = bound or (1.0 / reward if reward > 0 else float("inf")) - answers[i] = f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" - - if i < len(updated_metadata): - updated_metadata[i] = { - **updated_metadata[i], - "reward": reward, - "bound": result.get("bound"), - "error_msg": result.get("error_msg", ""), - } - - return EnvironmentReturn( - observations=observations, - metadata=updated_metadata, - next_stop_strings=[None] * batch_size, - rewards=rewards, - terminateds=terminateds, - answers=answers, - ) - - async def _async_step( - self, - message_log_batch: list[LLMMessageLogType], - metadata: list[ErdosMetadata], - ) -> EnvironmentReturn[ErdosMetadata]: - await self._ensure_session() - - batch_size = len(message_log_batch) - rewards = torch.zeros(batch_size) - terminateds = torch.ones(batch_size) # Always single-turn - observations = [{"role": "user", "content": ""} for _ in range(batch_size)] - answers = [None] * batch_size - updated_metadata = list(metadata) - - if self._inline_mode: - session = None - else: - session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=self.request_timeout) - ) + # Inject initial_h_values from state if available + state = metadata[i] if i < len(metadata) else {} + construction = state.get("construction", None) + preamble = "import numpy as np\n\n" + if construction: + preamble += f"initial_h_values = np.array({construction!r})\n\n" + else: + preamble += "initial_h_values = None\n\n" - try: - tasks = [] - for i, message_log in enumerate(message_log_batch): - # Extract the last assistant message - response_text = "" - for msg in reversed(message_log): - if msg.get("role") == "assistant": - response_text = msg.get("content", "") - break - - # Get parent_state from metadata if available - parent_state = None - if metadata and i < len(metadata): - parent_state = metadata[i].get("parent_state", None) - - tasks.append( - self._verify_single(session, response_text, parent_state) - ) - - results = await asyncio.gather(*tasks, return_exceptions=True) - finally: - if session is not None: - await session.close() - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.warning( - "ErdosDiscovery: verify exception for sample %d: %s", - i, - result, - ) - continue + full_code = preamble + code + result = _execute_run_function(full_code, timeout=self.sandbox_timeout) reward = result.get("reward", 0.0) rewards[i] = reward self.total_verified += 1 + c5 = result.get("raw_score", None) + if reward > 0: self.total_valid += 1 - bound = result.get("bound", None) - if reward > self.best_reward: + if c5 is not None and c5 < self.best_c5: + self.best_c5 = c5 self.best_reward = reward - self.best_bound = bound or ( - 1.0 / reward if reward > 0 else float("inf") - ) + logger.info(f"New best C5: {c5:.6f} (reward={reward:.4f})") - answers[i] = ( - f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" - ) + answers[i] = f"C5={c5:.6f}" if c5 else f"reward={reward:.4f}" - # Update metadata with verification results if i < len(updated_metadata): updated_metadata[i] = { **updated_metadata[i], "reward": reward, - "bound": result.get("bound"), + "c5_bound": c5, "error_msg": result.get("error_msg", ""), - "best_reward_ever": result.get( - "best_reward_ever", self.best_reward - ), + "stdout": result.get("stdout", ""), + # Update state for PUCT if valid + "result_construction": result.get("result_construction"), + "n_points": result.get("n_points"), } return EnvironmentReturn( @@ -406,22 +516,27 @@ async def _async_step( ) def global_post_process_and_metrics( - self, batch: dict - ) -> tuple[dict, dict]: - """Compute and return environment-level metrics.""" - valid_rate = ( - self.total_valid / max(self.total_verified, 1) - ) + self, + metadata: list[ErdosMetadata], + ) -> tuple[list[ErdosMetadata], dict[str, float]]: + """Compute aggregate metrics after a step.""" metrics = { - "env/best_reward": self.best_reward, - "env/best_bound": self.best_bound - if self.best_bound < float("inf") - else 0.0, - "env/total_verified": self.total_verified, - "env/valid_rate": valid_rate, + "best_c5": self.best_c5 if self.best_c5 < float("inf") else 0.0, + "best_reward": self.best_reward, + "total_verified": float(self.total_verified), + "total_valid": float(self.total_valid), + "valid_rate": ( + self.total_valid / max(1, self.total_verified) + ), } - return batch, metrics - def shutdown(self): - """Cleanup.""" - pass + # Count valid solutions in this batch + batch_valid = sum(1 for m in metadata if m.get("reward", 0) > 0) + batch_c5s = [m.get("c5_bound") for m in metadata if m.get("c5_bound") is not None] + if batch_c5s: + metrics["batch_best_c5"] = min(batch_c5s) + metrics["batch_mean_c5"] = sum(batch_c5s) / len(batch_c5s) + metrics["batch_valid_count"] = float(batch_valid) + metrics["batch_size"] = float(len(metadata)) + + return metadata, metrics diff --git a/test_gptoss_vllm.sh b/test_gptoss_vllm.sh new file mode 100755 index 0000000000..64010a9f07 --- /dev/null +++ b/test_gptoss_vllm.sh @@ -0,0 +1,42 @@ +#!/bin/bash +set -euo pipefail +cd /home/mormio/RL + +CONTAINER="/home/shared/containers/nemo-rl-super-v3.sqsh" +MODEL="/home/shared/models/gpt-oss-120b-bf16" +MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" + +# Use uv run which activates the right venv with vLLM +COMMAND=" +cd /opt/nemo-rl +uv run python -c \" +from vllm import LLM +print('Attempting to load gpt-oss-120b...') +try: + llm = LLM( + model='$MODEL', + tensor_parallel_size=8, + trust_remote_code=True, + max_model_len=1024, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + print('SUCCESS: gpt-oss-120b loaded!') + out = llm.generate(['Hello world'], max_tokens=10) + print('Generated:', out[0].outputs[0].text) +except Exception as e: + print(f'FAILED: {type(e).__name__}: {e}') +\" +" + +COMMAND="$COMMAND" \ +CONTAINER="$CONTAINER" \ +MOUNTS="$MOUNTS" \ +GPUS_PER_NODE=8 \ +sbatch \ + --nodes=1 --partition=batch --exclusive \ + --job-name=test-gptoss --time=00:30:00 \ + --output=logs/test-gptoss-%j.out \ + --error=logs/test-gptoss-%j.err \ + --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ + ray.sub From eca88d81b848af2283a1596a2be9c49934230086 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 19:14:33 +0000 Subject: [PATCH 21/48] feat: 8k seq len, wandb logging, erdos/ metrics with max reward + valid rate - Seq len 4096 -> 8192 (7168 max_new_tokens) - wandb enabled: project=ttt-discover-erdos - Environment logs erdos/max_reward, erdos/avg_reward, erdos/valid_rate, erdos/best_c5, erdos/global_best_c5 to both console and metrics dict - Print summary line per step for easy monitoring --- examples/configs/grpo_erdos_discover.yaml | 14 ++++++---- launch_erdos_120b.sh | 13 ++++----- .../erdos_discovery_environment.py | 28 +++++++++++++++---- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index fb58646a6b..05cca43be5 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -22,7 +22,7 @@ policy: tokenizer: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null - max_total_sequence_length: 4096 + max_total_sequence_length: 8192 train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 @@ -33,12 +33,12 @@ policy: resources: num_nodes: 2 gpus_per_node: 8 - max_new_tokens: 3072 + max_new_tokens: 7168 vllm_cfg: async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 - max_model_len: 4096 + max_model_len: 8192 megatron_cfg: tensor_model_parallel_size: 4 @@ -64,14 +64,13 @@ optimizer: data: shuffle: false - max_input_seq_length: 4096 + max_input_seq_length: 8192 env: erdos_discovery: num_initial_states: 16 num_groups_per_step: 8 sandbox_timeout: 1000 - should_use_nemo_gym: false cluster: @@ -80,7 +79,10 @@ cluster: logger: log_dir: "results/erdos-120b" - wandb_enabled: false + wandb_enabled: true + wandb: + project: "ttt-discover-erdos" + name: "nemotron-120b-8k-8node" tensorboard_enabled: false mlflow_enabled: false swanlab_enabled: false diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh index ded2091f78..0ac334b88b 100755 --- a/launch_erdos_120b.sh +++ b/launch_erdos_120b.sh @@ -1,7 +1,5 @@ #!/bin/bash -# TTT-Discover Erdős — Nemotron-3-Super-120B on 8 nodes -# 2 nodes inference (vLLM TP=8), 6 nodes training (Megatron TP=4 EP=8) -# Based on Dakota's working run_super_grpo.sh +# TTT-Discover Erdős — Nemotron-3-Super-120B, 8k seq len, wandb logging set -euo pipefail cd /home/mormio/RL @@ -10,6 +8,7 @@ MODEL_PATH="/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" EXP="results/erdos-120b-$(date +%Y%m%d_%H%M)" mkdir -p "$EXP" +WANDB_API_KEY=$(grep 'password' ~/.netrc | head -1 | awk '{print $2}') MOUNTS="$PWD:$PWD,$MODEL_PATH:$MODEL_PATH,$HOME/.cache:$HOME/.cache" COMMAND=" @@ -27,8 +26,8 @@ export UCX_NET_DEVICES=bond0 && \ export HF_HUB_ENABLE_HF_TRANSFER=0 && \ export TORCH_CUDA_ARCH_LIST='9.0 10.0' && \ export NRL_IGNORE_VERSION_MISMATCH=1 && \ +export WANDB_API_KEY=$WANDB_API_KEY && \ -# Copy our custom files into the container's /opt/nemo-rl SRC=/home/mormio/RL cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ @@ -36,7 +35,6 @@ cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ cp \$SRC/examples/configs/grpo_erdos_discover.yaml /opt/nemo-rl/examples/configs/ -# Patch grpo.py to register entropic estimator python -c \" path = '/opt/nemo-rl/nemo_rl/algorithms/grpo.py' with open(path) as f: @@ -61,7 +59,6 @@ if 'entropic_adaptive_beta' not in content: print('Patched grpo.py') \" && \ -# Patch utils.py to register erdos_discovery env python -c \" path = '/opt/nemo-rl/nemo_rl/environments/utils.py' with open(path) as f: @@ -80,10 +77,11 @@ python examples/run_discover.py \ --config examples/configs/grpo_erdos_discover.yaml " -echo "Submitting Erdős TTT-Discover 120B..." +echo "Submitting Erdős TTT-Discover 120B (8k seq, wandb)..." echo " Container: $CONTAINER" echo " Model: $MODEL_PATH" echo " Nodes: 8 (2 inference + 6 training)" +echo " Seq len: 8192" echo " Exp: $EXP" COMMAND="$COMMAND" \ @@ -99,3 +97,4 @@ sbatch \ ray.sub echo "Logs: $EXP/" +echo "W&B: https://wandb.ai/nous_research/ttt-discover-erdos" diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index be8c14b9d8..2a6825cc6b 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -531,12 +531,30 @@ def global_post_process_and_metrics( } # Count valid solutions in this batch - batch_valid = sum(1 for m in metadata if m.get("reward", 0) > 0) + batch_rewards = [m.get("reward", 0.0) for m in metadata] + batch_valid = sum(1 for r in batch_rewards if r > 0) batch_c5s = [m.get("c5_bound") for m in metadata if m.get("c5_bound") is not None] + + metrics["erdos/max_reward"] = float(max(batch_rewards)) if batch_rewards else 0.0 + metrics["erdos/avg_reward"] = float(sum(batch_rewards) / max(1, len(batch_rewards))) + metrics["erdos/valid_count"] = float(batch_valid) + metrics["erdos/valid_rate"] = float(batch_valid / max(1, len(batch_rewards))) + metrics["erdos/batch_size"] = float(len(metadata)) + if batch_c5s: - metrics["batch_best_c5"] = min(batch_c5s) - metrics["batch_mean_c5"] = sum(batch_c5s) / len(batch_c5s) - metrics["batch_valid_count"] = float(batch_valid) - metrics["batch_size"] = float(len(metadata)) + metrics["erdos/best_c5"] = float(min(batch_c5s)) + metrics["erdos/mean_c5"] = float(sum(batch_c5s) / len(batch_c5s)) + metrics["erdos/worst_c5"] = float(max(batch_c5s)) + + metrics["erdos/global_best_c5"] = float(self.best_c5) if self.best_c5 < float("inf") else 0.0 + metrics["erdos/global_valid_total"] = float(self.total_valid) + + # Print summary to driver log + max_r = metrics["erdos/max_reward"] + avg_r = metrics["erdos/avg_reward"] + best = metrics.get("erdos/best_c5", "n/a") + print(f" 🎯 Erdős: avg_reward={avg_r:.4f} max_reward={max_r:.4f} " + f"valid={batch_valid}/{len(metadata)} " + f"best_c5={best}") return metadata, metrics From a777cbca18c7e8a55af2a28a843005bef4c21bc3 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 19:53:37 +0000 Subject: [PATCH 22/48] fix: use ThreadPoolExecutor timeout instead of signal.alarm (broken in Ray actors) signal.alarm only works in main thread. Ray actors run in worker threads, so SIGALRM never fires and sandbox code blocks indefinitely. Now uses ThreadPoolExecutor with a 120s timeout. Also caps run() budget_s at 60s for faster iteration. --- .../erdos_discovery_environment.py | 79 +++++++++++++------ 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 2a6825cc6b..3bdf350eeb 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -151,36 +151,71 @@ def capturing_print(*args, **kwargs): stdout_capture.append(buf.getvalue()) namespace["__builtins__"]["print"] = capturing_print - class _Timeout(Exception): - pass - - def _handler(s, f): - raise _Timeout(f"Execution timed out after {timeout}s") - try: - old_handler = signal.signal(signal.SIGALRM, _handler) - signal.alarm(timeout) - try: + # Use multiprocessing for timeout (signal.alarm doesn't work in Ray actor threads) + import multiprocessing as mp + import pickle + + def _run_in_subprocess(code, ns_pickle, result_queue): + """Execute code in a subprocess with proper timeout support.""" + import signal as _signal + namespace = pickle.loads(ns_pickle) + + class _Timeout(Exception): + pass + def _handler(s, f): + raise _Timeout("timeout") + _signal.signal(_signal.SIGALRM, _handler) + _signal.alarm(timeout) + try: + exec(compile(code, "", "exec"), namespace) + if "run" not in namespace: + result_queue.put({"error": "No 'run' function defined"}) + return + out = namespace["run"](seed=42, budget_s=timeout) + result_queue.put({"result": out}) + except _Timeout: + result_queue.put({"error": f"Execution timed out after {timeout}s"}) + except Exception as e: + result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) + + # Try simple exec first with a short thread-based timeout + # For most invalid code this returns instantly + from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout + + def _exec_with_timeout(): exec(compile(code, "", "exec"), namespace) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) + if "run" not in namespace: + return {"error": "No 'run' function defined"} + result = namespace["run"](seed=42, budget_s=min(timeout, 60)) + return {"result": result} + + with ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(_exec_with_timeout) + try: + exec_result = future.result(timeout=min(timeout, 120)) + except FuturesTimeout: + # Thread is stuck — it will be abandoned when pool is GC'd + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"Execution timed out after {min(timeout, 120)}s", + "stdout": "".join(stdout_capture), + } + except Exception as e: + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"{type(e).__name__}: {str(e)[:300]}", + "stdout": "".join(stdout_capture), + } - if "run" not in namespace: + if "error" in exec_result: return { "reward": 0.0, "raw_score": None, - "error_msg": "No 'run' function defined", + "error_msg": exec_result["error"], "stdout": "".join(stdout_capture), } - # Call run() - signal.signal(signal.SIGALRM, _handler) - signal.alarm(timeout) - try: - result = namespace["run"](seed=42, budget_s=timeout) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) + result = exec_result["result"] if not isinstance(result, tuple) or len(result) != 3: return { From 8527a86ac1c1adccabfad601ba614fbed0a8045a Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 20:05:10 +0000 Subject: [PATCH 23/48] feat: add timestamped progress logging to reward computation --- .../erdos_discovery_environment.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 3bdf350eeb..c30a14099a 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -491,7 +491,20 @@ def _sync_step( answers = [None] * batch_size updated_metadata = list(metadata) + import time as _time + _t0 = _time.time() + logger.info(f"[{_time.strftime('%H:%M:%S')}] Starting reward computation for {batch_size} rollouts") + for i, message_log in enumerate(message_log_batch): + if i > 0 and i % 50 == 0: + elapsed = _time.time() - _t0 + rate = i / elapsed if elapsed > 0 else 0 + eta = (batch_size - i) / rate if rate > 0 else 0 + logger.info( + f"[{_time.strftime('%H:%M:%S')}] Reward progress: {i}/{batch_size} " + f"({elapsed:.0f}s elapsed, {rate:.1f} it/s, ~{eta:.0f}s remaining)" + ) + # Extract assistant response response_text = "" for msg in reversed(message_log): @@ -541,6 +554,14 @@ def _sync_step( "n_points": result.get("n_points"), } + elapsed = _time.time() - _t0 + valid = sum(1 for r in rewards if r > 0) + max_r = float(rewards.max()) if len(rewards) > 0 else 0.0 + logger.info( + f"[{_time.strftime('%H:%M:%S')}] Reward computation done: {batch_size} rollouts in {elapsed:.1f}s " + f"({valid} valid, max_reward={max_r:.4f})" + ) + return EnvironmentReturn( observations=observations, metadata=updated_metadata, From c48826bb9b7d6cbff2dd3013e7581577ea7984d5 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 20:15:37 +0000 Subject: [PATCH 24/48] fix: remove stale _Timeout reference that crashed env actor on bad model output --- nemo_rl/environments/erdos_discovery_environment.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index c30a14099a..a7e48523bc 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -246,12 +246,6 @@ def _exec_with_timeout(): "stdout": "".join(stdout_capture), } - except _Timeout as e: - return { - "reward": 0.0, "raw_score": None, - "error_msg": str(e), - "stdout": "".join(stdout_capture), - } except Exception as e: return { "reward": 0.0, "raw_score": None, From 2b07c29028cbaff161ca415c030d1b6d0368e740 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 20:45:35 +0000 Subject: [PATCH 25/48] fix: use print() instead of logger.info() for Ray actor visibility --- nemo_rl/environments/erdos_discovery_environment.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index a7e48523bc..be39347732 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -487,14 +487,14 @@ def _sync_step( import time as _time _t0 = _time.time() - logger.info(f"[{_time.strftime('%H:%M:%S')}] Starting reward computation for {batch_size} rollouts") + print(f"[{_time.strftime('%H:%M:%S')}] 🧪 Starting reward computation for {batch_size} rollouts") for i, message_log in enumerate(message_log_batch): if i > 0 and i % 50 == 0: elapsed = _time.time() - _t0 rate = i / elapsed if elapsed > 0 else 0 eta = (batch_size - i) / rate if rate > 0 else 0 - logger.info( + print( f"[{_time.strftime('%H:%M:%S')}] Reward progress: {i}/{batch_size} " f"({elapsed:.0f}s elapsed, {rate:.1f} it/s, ~{eta:.0f}s remaining)" ) @@ -551,9 +551,11 @@ def _sync_step( elapsed = _time.time() - _t0 valid = sum(1 for r in rewards if r > 0) max_r = float(rewards.max()) if len(rewards) > 0 else 0.0 - logger.info( - f"[{_time.strftime('%H:%M:%S')}] Reward computation done: {batch_size} rollouts in {elapsed:.1f}s " - f"({valid} valid, max_reward={max_r:.4f})" + best_c5 = 1.0 / max_r if max_r > 0 else float("inf") + print( + f"[{_time.strftime('%H:%M:%S')}] 🎯 Rewards done: {batch_size} rollouts in {elapsed:.1f}s | " + f"valid={valid}/{batch_size} ({100*valid/max(1,batch_size):.1f}%) | " + f"max_reward={max_r:.4f} | best_C5={best_c5:.6f}" ) return EnvironmentReturn( From 59046715fcd75063ddde0f05bfcb8598fa11e8d5 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 20:51:13 +0000 Subject: [PATCH 26/48] feat: prominent step-level logging with max_reward, best_C5, global_best_C5 --- .../environments/erdos_discovery_environment.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index be39347732..38aea72b93 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -532,7 +532,7 @@ def _sync_step( if c5 is not None and c5 < self.best_c5: self.best_c5 = c5 self.best_reward = reward - logger.info(f"New best C5: {c5:.6f} (reward={reward:.4f})") + print(f"🏆 NEW BEST C5: {c5:.6f} (reward={reward:.4f})") answers[i] = f"C5={c5:.6f}" if c5 else f"reward={reward:.4f}" @@ -551,11 +551,17 @@ def _sync_step( elapsed = _time.time() - _t0 valid = sum(1 for r in rewards if r > 0) max_r = float(rewards.max()) if len(rewards) > 0 else 0.0 - best_c5 = 1.0 / max_r if max_r > 0 else float("inf") + batch_best_c5 = 1.0 / max_r if max_r > 0 else float("inf") + global_best = self.best_c5 if self.best_c5 < float("inf") else float("inf") print( - f"[{_time.strftime('%H:%M:%S')}] 🎯 Rewards done: {batch_size} rollouts in {elapsed:.1f}s | " - f"valid={valid}/{batch_size} ({100*valid/max(1,batch_size):.1f}%) | " - f"max_reward={max_r:.4f} | best_C5={best_c5:.6f}" + f"\n{'='*60}\n" + f"🎯 STEP REWARDS: {batch_size} rollouts in {elapsed:.1f}s\n" + f" valid: {valid}/{batch_size} ({100*valid/max(1,batch_size):.1f}%)\n" + f" avg_reward: {sum(float(r) for r in rewards)/max(1,batch_size):.6f}\n" + f" max_reward: {max_r:.6f}\n" + f" batch_best_C5: {batch_best_c5:.6f}\n" + f" GLOBAL_BEST_C5: {global_best:.6f}\n" + f"{'='*60}" ) return EnvironmentReturn( From da2ce187a8ba48d1c2f77af6306a093403e52232 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Wed, 1 Apr 2026 22:42:13 +0000 Subject: [PATCH 27/48] fix: disable validation to prevent max_val_samples None crash at step 5 --- examples/configs/grpo_erdos_discover.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 05cca43be5..300e44961c 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -7,6 +7,9 @@ grpo: num_generations_per_prompt: 63 max_num_steps: 50 max_rollout_turns: 1 + val_period: 0 + val_at_start: false + val_at_end: false remove_constant_reward_groups: true adv_estimator: name: entropic_adaptive_beta From 24e9aa0cb0e4d9cc021268e29baa2d44f57e5222 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 02:06:14 +0000 Subject: [PATCH 28/48] fix: disable checkpointing (async writer crashes at step 10) --- examples/configs/grpo_erdos_discover.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 300e44961c..f2042bead1 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -91,6 +91,6 @@ logger: swanlab_enabled: false checkpointing: - enabled: true + enabled: false checkpoint_dir: "results/erdos-120b" save_period: 5 From b9afea8533425da51b95d502eda56ffa4ed588e6 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 02:20:53 +0000 Subject: [PATCH 29/48] feat: scale to 10 nodes + 16k seq len (CP=2), save outputs to JSONL per step - 10 nodes: 2 inference + 8 training (was 8 total) - CP=2 enables 16k context (was 8k) - 15360 max_new_tokens (was 7168) - Save first 10 + all valid outputs per step for debugging valid rate - ERDOS_LOG_DIR for output files --- examples/configs/grpo_erdos_discover.yaml | 20 +++++------ launch_erdos_120b.sh | 7 ++-- .../erdos_discovery_environment.py | 36 +++++++++++++++++++ 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index f2042bead1..f5979ea98c 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -1,5 +1,5 @@ # TTT-Discover Erdős — Nemotron-3-Super-120B-A12B -# Inherits from grpo_superv3.yaml (the working Nemotron Super config) +# 10 nodes: 2 inference + 8 training, CP=2 for 16k seq len defaults: "grpo_superv3.yaml" grpo: @@ -7,10 +7,10 @@ grpo: num_generations_per_prompt: 63 max_num_steps: 50 max_rollout_turns: 1 + remove_constant_reward_groups: true val_period: 0 val_at_start: false val_at_end: false - remove_constant_reward_groups: true adv_estimator: name: entropic_adaptive_beta gamma: 0.6931471805599453 @@ -25,7 +25,7 @@ policy: tokenizer: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null - max_total_sequence_length: 8192 + max_total_sequence_length: 16384 train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 @@ -36,17 +36,17 @@ policy: resources: num_nodes: 2 gpus_per_node: 8 - max_new_tokens: 7168 + max_new_tokens: 15360 vllm_cfg: async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 - max_model_len: 8192 + max_model_len: 16384 megatron_cfg: tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 - context_parallel_size: 1 + context_parallel_size: 2 expert_model_parallel_size: 8 sequence_parallel: true activation_checkpointing: true @@ -67,7 +67,7 @@ optimizer: data: shuffle: false - max_input_seq_length: 8192 + max_input_seq_length: 16384 env: erdos_discovery: @@ -78,14 +78,14 @@ env: cluster: gpus_per_node: 8 - num_nodes: 8 + num_nodes: 10 logger: log_dir: "results/erdos-120b" wandb_enabled: true wandb: project: "ttt-discover-erdos" - name: "nemotron-120b-8k-8node" + name: "nemotron-120b-16k-10node" tensorboard_enabled: false mlflow_enabled: false swanlab_enabled: false @@ -93,4 +93,4 @@ logger: checkpointing: enabled: false checkpoint_dir: "results/erdos-120b" - save_period: 5 + save_period: 10 diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh index 0ac334b88b..580f2111b0 100755 --- a/launch_erdos_120b.sh +++ b/launch_erdos_120b.sh @@ -26,6 +26,7 @@ export UCX_NET_DEVICES=bond0 && \ export HF_HUB_ENABLE_HF_TRANSFER=0 && \ export TORCH_CUDA_ARCH_LIST='9.0 10.0' && \ export NRL_IGNORE_VERSION_MISMATCH=1 && \ +export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_outputs && \ export WANDB_API_KEY=$WANDB_API_KEY && \ SRC=/home/mormio/RL @@ -80,8 +81,8 @@ python examples/run_discover.py \ echo "Submitting Erdős TTT-Discover 120B (8k seq, wandb)..." echo " Container: $CONTAINER" echo " Model: $MODEL_PATH" -echo " Nodes: 8 (2 inference + 6 training)" -echo " Seq len: 8192" +echo " Nodes: 10 (2 inference + 8 training)" +echo " Seq len: 16384" echo " Exp: $EXP" COMMAND="$COMMAND" \ @@ -89,7 +90,7 @@ CONTAINER="$CONTAINER" \ MOUNTS="$MOUNTS" \ GPUS_PER_NODE=8 \ sbatch \ - --nodes=8 --partition=batch --exclusive \ + --nodes=10 --partition=batch --exclusive \ --job-name=erdos-120b --time=12:00:00 \ --output="$EXP/slurm-%j.out" \ --error="$EXP/slurm-%j.err" \ diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 38aea72b93..db35b17863 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -564,6 +564,42 @@ def _sync_step( f"{'='*60}" ) + # Save outputs to JSONL for debugging + import json, os + self.total_verified # use as step counter proxy + step_num = getattr(self, '_step_count', 0) + 1 + self._step_count = step_num + log_dir = os.environ.get("ERDOS_LOG_DIR", "/tmp/erdos_outputs") + os.makedirs(log_dir, exist_ok=True) + out_path = os.path.join(log_dir, f"step_{step_num:03d}.jsonl") + try: + with open(out_path, "w") as fout: + # Save a sample: first 10 + all valid ones + for idx in range(batch_size): + r = float(rewards[idx]) + save_this = (idx < 10) or (r > 0) + if not save_this: + continue + response = "" + for msg in reversed(message_log_batch[idx]): + if msg.get("role") == "assistant": + response = msg.get("content", "") + break + meta = updated_metadata[idx] if idx < len(updated_metadata) else {} + entry = { + "idx": idx, + "reward": r, + "c5_bound": meta.get("c5_bound"), + "error_msg": meta.get("error_msg", ""), + "response_len": len(response), + "response_preview": response[:500], + "code_preview": _extract_code(response)[:500] if response else "", + } + fout.write(json.dumps(entry) + "\n") + print(f" 📝 Saved outputs to {out_path}") + except Exception as e: + print(f" ⚠️ Failed to save outputs: {e}") + return EnvironmentReturn( observations=observations, metadata=updated_metadata, From a4c5ea00ae87d7c9d5119e24c7332853287a83ac Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 12:32:53 +0000 Subject: [PATCH 30/48] fix: fully disable checkpointing with null checkpoint_must_save_by --- examples/configs/grpo_erdos_discover.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index f5979ea98c..69e07dcde4 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -93,4 +93,7 @@ logger: checkpointing: enabled: false checkpoint_dir: "results/erdos-120b" - save_period: 10 + save_period: 999999 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false From 8ece7e135560b46efd8eba4ec2b53494a241b27c Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 12:46:04 +0000 Subject: [PATCH 31/48] feat: debug config for step 10 hang repro (Qwen 1.5B, 1 node, 16k, 15 steps) --- examples/configs/grpo_erdos_debug_16k.yaml | 88 ++++++++++++++++++++++ launch_erdos_debug_16k.sh | 83 ++++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100644 examples/configs/grpo_erdos_debug_16k.yaml create mode 100755 launch_erdos_debug_16k.sh diff --git a/examples/configs/grpo_erdos_debug_16k.yaml b/examples/configs/grpo_erdos_debug_16k.yaml new file mode 100644 index 0000000000..3f7d61dbd6 --- /dev/null +++ b/examples/configs/grpo_erdos_debug_16k.yaml @@ -0,0 +1,88 @@ +# Debug config: Qwen2.5-1.5B, 1 node, 16k seq len, 15 steps +# Purpose: reproduce step 10 hang at small scale +defaults: "grpo_math_1B.yaml" + +grpo: + num_prompts_per_step: 4 + num_generations_per_prompt: 8 + max_num_steps: 15 + max_rollout_turns: 1 + remove_constant_reward_groups: true + val_period: 0 + val_at_start: false + val_at_end: false + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931471805599453 + +loss_fn: + kl_penalty_coef: 0.1 + ratio_clip: 0.2 + token_level_loss: false + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + tokenizer: + name: "Qwen/Qwen2.5-1.5B-Instruct" + chat_template_kwargs: null + max_total_sequence_length: 16384 + train_global_batch_size: 32 + train_micro_batch_size: 4 + dtensor_cfg: + enabled: true + tensor_parallel_size: 1 + sequence_parallel: false + cpu_offload: false + activation_checkpointing: true + lora_cfg: + enabled: true + rank: 16 + alpha: 1.0 + dropout: 0.0 + generation: + backend: "vllm" + max_new_tokens: 15360 + temperature: 1.0 + top_p: 1.0 + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 16384 + dynamic_batching: + enabled: false + +optimizer: + name: adamw + lr: 1.0e-4 + +data: + shuffle: false + +env: + erdos_discovery: + num_initial_states: 8 + num_groups_per_step: 4 + sandbox_timeout: 120 + should_use_nemo_gym: false + +cluster: + gpus_per_node: 8 + num_nodes: 1 + +logger: + log_dir: "logs/erdos-debug-16k" + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false + +checkpointing: + enabled: false + checkpoint_dir: "logs/erdos-debug-16k" + save_period: 999999 + checkpoint_must_save_by: null diff --git a/launch_erdos_debug_16k.sh b/launch_erdos_debug_16k.sh new file mode 100755 index 0000000000..75e0121e38 --- /dev/null +++ b/launch_erdos_debug_16k.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Debug: Qwen2.5-1.5B, 1 node, 16k, 15 steps — test for step 10 hang +set -euo pipefail +cd /home/mormio/RL + +CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" +EXP="results/erdos-debug-16k-$(date +%Y%m%d_%H%M)" +mkdir -p "$EXP" + +MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" + +COMMAND=" +export HF_HUB_ENABLE_HF_TRANSFER=0 +export TORCH_CUDA_ARCH_LIST='9.0 10.0' +export NRL_IGNORE_VERSION_MISMATCH=1 +export PYTHONPATH=/home/mormio/RL:\${PYTHONPATH:-} +export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_debug_outputs + +SRC=/home/mormio/RL +cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ +cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ +cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ +cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ +cp \$SRC/examples/configs/grpo_erdos_debug_16k.yaml /opt/nemo-rl/examples/configs/ + +python -c \" +path = '/opt/nemo-rl/nemo_rl/algorithms/grpo.py' +with open(path) as f: + content = f.read() +if 'entropic_adaptive_beta' not in content: + old = ' else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator' + new = ''' elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") + else: + raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\") + + return adv_estimator''' + content = content.replace(old, new) + with open(path, 'w') as f: + f.write(content) + print('Patched grpo.py') +\" && \ + +python -c \" +path = '/opt/nemo-rl/nemo_rl/environments/utils.py' +with open(path) as f: + content = f.read() +if 'erdos_discovery' not in content: + content = content.replace( + '\\\"nemo_gym\\\": {', + '\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {' + ) + with open(path, 'w') as f: + f.write(content) + print('Patched utils.py') +\" && \ + +cd /opt/nemo-rl +python examples/run_discover.py \ + --config examples/configs/grpo_erdos_debug_16k.yaml +" + +echo "Launching debug: Qwen2.5-1.5B, 1 node, 16k, 15 steps" + +COMMAND="$COMMAND" \ +CONTAINER="$CONTAINER" \ +MOUNTS="$MOUNTS" \ +GPUS_PER_NODE=8 \ +sbatch \ + --nodes=1 --partition=batch --exclusive \ + --job-name=erdos-debug-16k --time=02:00:00 \ + --output="$EXP/slurm-%j.out" \ + --error="$EXP/slurm-%j.err" \ + --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ + ray.sub + +echo "Logs: $EXP/" From 745468ee5908e690bbeac3bdf2e50f942449e722 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 14:40:28 +0000 Subject: [PATCH 32/48] fix: ensure checkpointing config exists for both v0.5.0 and super-v3 containers --- examples/run_discover.py | 16 ++++++++++++++++ launch_erdos_debug_16k.sh | 1 - 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/run_discover.py b/examples/run_discover.py index 256eda4c1b..1f82c6551f 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -229,6 +229,22 @@ def main(): setup_discover_data(config, tokenizer) ) + # Ensure checkpointing config exists (some container versions require it) + if "checkpointing" not in config: + config["checkpointing"] = { + "enabled": False, + "checkpoint_dir": "results/erdos", + "save_period": 999999, + "checkpoint_must_save_by": None, + "model_save_format": "safetensors", + "save_consolidated": False, + "metric_name": "total_reward/mean", + "higher_is_better": True, + "keep_top_k": 1000000, + } + elif config["checkpointing"].get("checkpoint_must_save_by") is None: + config["checkpointing"]["checkpoint_must_save_by"] = None + # Setup returns vary across container versions setup_result = setup(config, tokenizer, train_dataset, val_dataset) setup_list = list(setup_result) diff --git a/launch_erdos_debug_16k.sh b/launch_erdos_debug_16k.sh index 75e0121e38..360ce445db 100755 --- a/launch_erdos_debug_16k.sh +++ b/launch_erdos_debug_16k.sh @@ -13,7 +13,6 @@ COMMAND=" export HF_HUB_ENABLE_HF_TRANSFER=0 export TORCH_CUDA_ARCH_LIST='9.0 10.0' export NRL_IGNORE_VERSION_MISMATCH=1 -export PYTHONPATH=/home/mormio/RL:\${PYTHONPATH:-} export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_debug_outputs SRC=/home/mormio/RL From d1b79b5ae8083270b8527c3df49d7636a4d01078 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 15:04:45 +0000 Subject: [PATCH 33/48] fix: inject checkpointing into master_config returned by setup() (not just input config) --- examples/run_discover.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/run_discover.py b/examples/run_discover.py index 1f82c6551f..dfafd213d5 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -256,6 +256,13 @@ def main(): (policy, policy_generation, _nemo_gym, _clusters, dataloader, val_dataloader, loss_fn, nemo_logger, checkpointer, grpo_state, master_config) = setup_list + # Ensure checkpointing exists in master_config + if "checkpointing" not in master_config: + master_config["checkpointing"] = config.get("checkpointing", { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 999999, + }) grpo_train( policy, policy_generation, dataloader, val_dataloader, @@ -269,6 +276,13 @@ def main(): (policy, policy_generation, dataloader, val_dataloader, loss_fn, nemo_logger, checkpointer, grpo_state, master_config, _extra) = setup_list + # Ensure checkpointing exists in master_config + if "checkpointing" not in master_config: + master_config["checkpointing"] = config.get("checkpointing", { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 999999, + }) grpo_train( policy, policy_generation, dataloader, val_dataloader, From f699f0257a1918334b6baac55841a24fd1f87504 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 15:23:47 +0000 Subject: [PATCH 34/48] debug: print setup() return types to fix unpacking order --- examples/run_discover.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/run_discover.py b/examples/run_discover.py index dfafd213d5..5b94419852 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -250,6 +250,8 @@ def main(): setup_list = list(setup_result) n = len(setup_list) print(f" setup() returned {n} values") + for i, v in enumerate(setup_list): + print(f" [{i}] {type(v).__name__}: {str(v)[:80]}") if n == 11: # super-v3 container From f6899a8e714bc463ed03c3da08d659d8a6c20ab3 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 15:42:46 +0000 Subject: [PATCH 35/48] fix: correct v0.5.0 setup() unpacking order (clusters at [2], not dataloader) --- examples/run_discover.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/run_discover.py b/examples/run_discover.py index 5b94419852..b531ac5bcc 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -274,10 +274,12 @@ def main(): grpo_state, master_config, ) elif n == 10: - # v0.5.0 container - (policy, policy_generation, dataloader, val_dataloader, - loss_fn, nemo_logger, checkpointer, grpo_state, - master_config, _extra) = setup_list + # v0.5.0 container: Policy, VllmGen, clusters, train_dl, val_dl, + # loss_fn, logger, checkpointer, grpo_state, master_config + (policy, policy_generation, _clusters, + dataloader, val_dataloader, + loss_fn, nemo_logger, checkpointer, + grpo_state, master_config) = setup_list # Ensure checkpointing exists in master_config if "checkpointing" not in master_config: master_config["checkpointing"] = config.get("checkpointing", { From f5590fd47aef88f420d9425f06c86c26e851ad6c Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 16:00:48 +0000 Subject: [PATCH 36/48] fix: debug config back to 4k (16k OOMs on 1 node with 1.5B + LoRA) --- examples/configs/grpo_erdos_debug_16k.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/configs/grpo_erdos_debug_16k.yaml b/examples/configs/grpo_erdos_debug_16k.yaml index 3f7d61dbd6..27216224af 100644 --- a/examples/configs/grpo_erdos_debug_16k.yaml +++ b/examples/configs/grpo_erdos_debug_16k.yaml @@ -25,7 +25,7 @@ policy: tokenizer: name: "Qwen/Qwen2.5-1.5B-Instruct" chat_template_kwargs: null - max_total_sequence_length: 16384 + max_total_sequence_length: 4096 train_global_batch_size: 32 train_micro_batch_size: 4 dtensor_cfg: @@ -41,7 +41,7 @@ policy: dropout: 0.0 generation: backend: "vllm" - max_new_tokens: 15360 + max_new_tokens: 3072 temperature: 1.0 top_p: 1.0 stop_token_ids: null @@ -52,7 +52,7 @@ policy: pipeline_parallel_size: 1 expert_parallel_size: 1 gpu_memory_utilization: 0.6 - max_model_len: 16384 + max_model_len: 4096 dynamic_batching: enabled: false From 7a7bd714a01891a089ae89c054362cb32206be60 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 16:44:01 +0000 Subject: [PATCH 37/48] 120B at 4k context, 8 nodes, 50 steps, checkpointing fully disabled --- examples/configs/grpo_erdos_discover.yaml | 17 ++++++++--------- launch_erdos_120b.sh | 6 +++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 69e07dcde4..90b31c8670 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -1,5 +1,4 @@ -# TTT-Discover Erdős — Nemotron-3-Super-120B-A12B -# 10 nodes: 2 inference + 8 training, CP=2 for 16k seq len +# TTT-Discover Erdős — Nemotron-3-Super-120B-A12B, 4k seq, 8 nodes defaults: "grpo_superv3.yaml" grpo: @@ -25,7 +24,7 @@ policy: tokenizer: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null - max_total_sequence_length: 16384 + max_total_sequence_length: 4096 train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 @@ -36,17 +35,17 @@ policy: resources: num_nodes: 2 gpus_per_node: 8 - max_new_tokens: 15360 + max_new_tokens: 3072 vllm_cfg: async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 - max_model_len: 16384 + max_model_len: 4096 megatron_cfg: tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 - context_parallel_size: 2 + context_parallel_size: 1 expert_model_parallel_size: 8 sequence_parallel: true activation_checkpointing: true @@ -67,7 +66,7 @@ optimizer: data: shuffle: false - max_input_seq_length: 16384 + max_input_seq_length: 4096 env: erdos_discovery: @@ -78,14 +77,14 @@ env: cluster: gpus_per_node: 8 - num_nodes: 10 + num_nodes: 8 logger: log_dir: "results/erdos-120b" wandb_enabled: true wandb: project: "ttt-discover-erdos" - name: "nemotron-120b-16k-10node" + name: "nemotron-120b-4k-50steps" tensorboard_enabled: false mlflow_enabled: false swanlab_enabled: false diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh index 580f2111b0..3245916224 100755 --- a/launch_erdos_120b.sh +++ b/launch_erdos_120b.sh @@ -81,8 +81,8 @@ python examples/run_discover.py \ echo "Submitting Erdős TTT-Discover 120B (8k seq, wandb)..." echo " Container: $CONTAINER" echo " Model: $MODEL_PATH" -echo " Nodes: 10 (2 inference + 8 training)" -echo " Seq len: 16384" +echo " Nodes: 8 (2 inference + 6 training)" +echo " Seq len: 4096" echo " Exp: $EXP" COMMAND="$COMMAND" \ @@ -90,7 +90,7 @@ CONTAINER="$CONTAINER" \ MOUNTS="$MOUNTS" \ GPUS_PER_NODE=8 \ sbatch \ - --nodes=10 --partition=batch --exclusive \ + --nodes=8 --partition=batch --exclusive \ --job-name=erdos-120b --time=12:00:00 \ --output="$EXP/slurm-%j.out" \ --error="$EXP/slurm-%j.err" \ From 7de56f34d31ea2c8fca01d057668e1a3360a709c Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 19:16:15 +0000 Subject: [PATCH 38/48] fix: use multiprocessing.Process + kill() for hard sandbox timeout (threads cant be killed) --- .../erdos_discovery_environment.py | 85 ++++++++++++------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index db35b17863..d14e81f803 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -179,53 +179,72 @@ def _handler(s, f): except Exception as e: result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) - # Try simple exec first with a short thread-based timeout - # For most invalid code this returns instantly - from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout - - def _exec_with_timeout(): - exec(compile(code, "", "exec"), namespace) - if "run" not in namespace: - return {"error": "No 'run' function defined"} - result = namespace["run"](seed=42, budget_s=min(timeout, 60)) - return {"result": result} - - with ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(_exec_with_timeout) + # Use multiprocessing with kill() for hard timeout + import multiprocessing as _mp + import pickle as _pickle + + _EXEC_TIMEOUT = min(timeout, 120) + + def _worker_fn(code_str, q): + import signal + signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(Exception("timeout"))) + signal.alarm(_EXEC_TIMEOUT) try: - exec_result = future.result(timeout=min(timeout, 120)) - except FuturesTimeout: - # Thread is stuck — it will be abandoned when pool is GC'd - return { - "reward": 0.0, "raw_score": None, - "error_msg": f"Execution timed out after {min(timeout, 120)}s", - "stdout": "".join(stdout_capture), - } + ns = {} + ns["__builtins__"] = __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__.copy() + import numpy, math, random + ns.update({"np": numpy, "numpy": numpy, "math": math, "random": random}) + def _eval(h, c, n): + from erdos_verify import verify_c5_solution as _v + _v(h, c, n) + return float(c) + ns["evaluate_erdos_solution"] = lambda h, c, n: float(c) + exec(compile(code_str, "", "exec"), ns) + if "run" not in ns: + q.put({"error": "No 'run' function defined"}) + return + out = ns["run"](seed=42, budget_s=_EXEC_TIMEOUT) + q.put({"result": (out[0].tolist() if hasattr(out[0], 'tolist') else list(out[0]), float(out[1]), int(out[2]))}) except Exception as e: - return { - "reward": 0.0, "raw_score": None, - "error_msg": f"{type(e).__name__}: {str(e)[:300]}", - "stdout": "".join(stdout_capture), - } + q.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) - if "error" in exec_result: + q = _mp.Queue() + p = _mp.Process(target=_worker_fn, args=(code, q)) + p.start() + p.join(timeout=_EXEC_TIMEOUT + 5) + + if p.is_alive(): + p.kill() + p.join(timeout=5) return { "reward": 0.0, "raw_score": None, - "error_msg": exec_result["error"], + "error_msg": f"Execution killed after {_EXEC_TIMEOUT}s timeout", "stdout": "".join(stdout_capture), } - result = exec_result["result"] + if q.empty(): + return { + "reward": 0.0, "raw_score": None, + "error_msg": "Worker process died without result", + "stdout": "".join(stdout_capture), + } + + exec_result = q.get_nowait() - if not isinstance(result, tuple) or len(result) != 3: + if "error" in exec_result: return { "reward": 0.0, "raw_score": None, - "error_msg": f"run() must return (h_values, c5_bound, n_points), got {type(result)}", + "error_msg": exec_result["error"], "stdout": "".join(stdout_capture), } - h_values, c5_bound, n_points = result - h_values = np.asarray(h_values, dtype=np.float64) + raw = exec_result["result"] + h_values = np.asarray(raw[0], dtype=np.float64) + c5_bound = raw[1] + n_points = raw[2] + result = (h_values, c5_bound, n_points) + + # Verify computed_c5 = verify_c5_solution(h_values, c5_bound, n_points) From ab974aa3bb114173f3619e0a66273e79229d8c26 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 20:50:58 +0000 Subject: [PATCH 39/48] fix: sandbox timeout 1000s matching paper, not 120s --- nemo_rl/environments/erdos_discovery_environment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index d14e81f803..673ecc0b8e 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -183,12 +183,12 @@ def _handler(s, f): import multiprocessing as _mp import pickle as _pickle - _EXEC_TIMEOUT = min(timeout, 120) + _EXEC_TIMEOUT = min(timeout, 1000) def _worker_fn(code_str, q): import signal signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(Exception("timeout"))) - signal.alarm(_EXEC_TIMEOUT) + signal.alarm(_EXEC_TIMEOUT - 5) # 5s grace before hard kill try: ns = {} ns["__builtins__"] = __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__.copy() @@ -203,7 +203,7 @@ def _eval(h, c, n): if "run" not in ns: q.put({"error": "No 'run' function defined"}) return - out = ns["run"](seed=42, budget_s=_EXEC_TIMEOUT) + out = ns["run"](seed=42, budget_s=_EXEC_TIMEOUT - 10) q.put({"result": (out[0].tolist() if hasattr(out[0], 'tolist') else list(out[0]), float(out[1]), int(out[2]))}) except Exception as e: q.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) @@ -211,7 +211,7 @@ def _eval(h, c, n): q = _mp.Queue() p = _mp.Process(target=_worker_fn, args=(code, q)) p.start() - p.join(timeout=_EXEC_TIMEOUT + 5) + p.join(timeout=_EXEC_TIMEOUT + 10) if p.is_alive(): p.kill() From 1cf501f418e1addefe07770a6374ad128262e44e Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Thu, 2 Apr 2026 20:59:41 +0000 Subject: [PATCH 40/48] fix: clean subprocess sandbox - BaseException for alarm, SIGTERM before SIGKILL, 1000s timeout --- .../erdos_discovery_environment.py | 77 +++++++++++++------ 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index 673ecc0b8e..d3d07705ac 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -179,57 +179,84 @@ def _handler(s, f): except Exception as e: result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) - # Use multiprocessing with kill() for hard timeout + # Run in subprocess with signal.alarm for clean timeout. + # The subprocess handles its own timeout and exits cleanly. + # We only use p.terminate() (SIGTERM, not SIGKILL) as a last resort. import multiprocessing as _mp - import pickle as _pickle _EXEC_TIMEOUT = min(timeout, 1000) - def _worker_fn(code_str, q): - import signal - signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(Exception("timeout"))) - signal.alarm(_EXEC_TIMEOUT - 5) # 5s grace before hard kill + def _worker_fn(code_str, result_queue, exec_timeout): + """Run code in a subprocess. signal.alarm works here (main thread).""" + import signal as _sig + import sys as _sys + import os as _os + + class _AlarmTimeout(BaseException): + pass + + def _alarm_handler(signum, frame): + raise _AlarmTimeout() + + _sig.signal(_sig.SIGALRM, _alarm_handler) + _sig.alarm(exec_timeout) + try: - ns = {} - ns["__builtins__"] = __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__.copy() import numpy, math, random - ns.update({"np": numpy, "numpy": numpy, "math": math, "random": random}) - def _eval(h, c, n): - from erdos_verify import verify_c5_solution as _v - _v(h, c, n) - return float(c) - ns["evaluate_erdos_solution"] = lambda h, c, n: float(c) + ns = { + "__builtins__": __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__).copy(), + "np": numpy, "numpy": numpy, "math": math, "random": random, + "evaluate_erdos_solution": lambda h, c, n: float(c), + } exec(compile(code_str, "", "exec"), ns) if "run" not in ns: - q.put({"error": "No 'run' function defined"}) + result_queue.put({"error": "No 'run' function defined"}) return - out = ns["run"](seed=42, budget_s=_EXEC_TIMEOUT - 10) - q.put({"result": (out[0].tolist() if hasattr(out[0], 'tolist') else list(out[0]), float(out[1]), int(out[2]))}) + out = ns["run"](seed=42, budget_s=exec_timeout - 10) + # Serialize result (numpy arrays can't cross process boundary directly) + h = out[0].tolist() if hasattr(out[0], "tolist") else list(out[0]) + result_queue.put({"result": (h, float(out[1]), int(out[2]))}) + except _AlarmTimeout: + result_queue.put({"error": f"Execution timed out after {exec_timeout}s"}) except Exception as e: - q.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) + result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) + finally: + _sig.alarm(0) # Cancel any pending alarm q = _mp.Queue() - p = _mp.Process(target=_worker_fn, args=(code, q)) + p = _mp.Process(target=_worker_fn, args=(code, q, _EXEC_TIMEOUT)) p.start() - p.join(timeout=_EXEC_TIMEOUT + 10) + # Wait for subprocess: alarm should fire inside it, so give extra grace + p.join(timeout=_EXEC_TIMEOUT + 30) if p.is_alive(): - p.kill() - p.join(timeout=5) + # Subprocess didn't exit cleanly — send SIGTERM first (graceful) + p.terminate() + p.join(timeout=10) + if p.is_alive(): + p.kill() # Last resort + p.join(timeout=5) return { "reward": 0.0, "raw_score": None, - "error_msg": f"Execution killed after {_EXEC_TIMEOUT}s timeout", + "error_msg": f"Subprocess terminated after {_EXEC_TIMEOUT}s", "stdout": "".join(stdout_capture), } if q.empty(): return { "reward": 0.0, "raw_score": None, - "error_msg": "Worker process died without result", + "error_msg": "Subprocess exited without result", "stdout": "".join(stdout_capture), } - exec_result = q.get_nowait() + try: + exec_result = q.get_nowait() + except Exception: + return { + "reward": 0.0, "raw_score": None, + "error_msg": "Failed to read subprocess result", + "stdout": "".join(stdout_capture), + } if "error" in exec_result: return { From ee698b942795848f8a404ade5d5f405e7c97795b Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Fri, 3 Apr 2026 02:37:43 +0000 Subject: [PATCH 41/48] feat: stateful PUCT sampler integrated into RL env + dataset - ErdosRefPUCTSampler: full PUCT tree with state tracking, update_states, record_failed_rollout, flush, sample_states (matches ttt-discover-ref) - PUCTDiscoverDataset: calls env.puct_sample_states.remote() for prompts (shared state between dataset and env via Ray actor) - Environment _sync_step updates PUCT buffer on success/failure - RandomDiscoverDataset for validation (no PUCT) --- examples/configs/grpo_erdos_debug_16k.yaml | 2 +- examples/configs/grpo_erdos_discover.yaml | 4 +- .../configs/grpo_erdos_discover_debug.yaml | 3 +- examples/run_discover.py | 118 +++- .../erdos_discovery_environment.py | 123 ++++- .../environments/erdos_ref_puct_sampler.py | 513 ++++++++++++++++++ 6 files changed, 712 insertions(+), 51 deletions(-) create mode 100644 nemo_rl/environments/erdos_ref_puct_sampler.py diff --git a/examples/configs/grpo_erdos_debug_16k.yaml b/examples/configs/grpo_erdos_debug_16k.yaml index 27216224af..e02f8e4f5c 100644 --- a/examples/configs/grpo_erdos_debug_16k.yaml +++ b/examples/configs/grpo_erdos_debug_16k.yaml @@ -66,7 +66,7 @@ data: env: erdos_discovery: num_initial_states: 8 - num_groups_per_step: 4 + puct_seed_batch_size: 4 sandbox_timeout: 120 should_use_nemo_gym: false diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 90b31c8670..0a10662d48 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -66,12 +66,14 @@ optimizer: data: shuffle: false + num_workers: 0 # PUCT training: dataset must run on driver (Ray select) max_input_seq_length: 4096 env: erdos_discovery: num_initial_states: 16 - num_groups_per_step: 8 + # Must match grpo.num_prompts_per_step (run_discover enforces ref-style PUCT batch_size). + puct_seed_batch_size: 8 sandbox_timeout: 1000 should_use_nemo_gym: false diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml index e6282123b2..ae47eb6ae8 100644 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ b/examples/configs/grpo_erdos_discover_debug.yaml @@ -61,11 +61,12 @@ optimizer: data: shuffle: false + num_workers: 0 # required: PUCT dataset calls Ray from the iterator process env: erdos_discovery: num_initial_states: 8 - num_groups_per_step: 4 + puct_seed_batch_size: 4 sandbox_timeout: 120 checkpointing: diff --git a/examples/run_discover.py b/examples/run_discover.py index b531ac5bcc..34c39c011a 100644 --- a/examples/run_discover.py +++ b/examples/run_discover.py @@ -65,17 +65,21 @@ def generate_discover_datum( msg_ids = tokenizer.encode(msg_text, add_special_tokens=False) msg["token_ids"] = torch.tensor(msg_ids, dtype=torch.long) + extra = { + "construction": state.get("construction"), + "c5_bound": state.get("c5_bound"), + "n_points": state.get("n_points"), + "code": state.get("code", ""), + "parent_c5": state.get("parent_c5"), + "observation": state.get("observation", ""), + } + if state.get("erdos_ref_state") is not None: + extra["erdos_ref_state"] = state["erdos_ref_state"] + return DatumSpec( message_log=messages, length=len(prompt_ids), - extra_env_info={ - "construction": state.get("construction"), - "c5_bound": state.get("c5_bound"), - "n_points": state.get("n_points"), - "code": state.get("code", ""), - "parent_c5": state.get("parent_c5"), - "observation": state.get("observation", ""), - }, + extra_env_info=extra, loss_multiplier=1.0, idx=idx, task_name=task_name, @@ -83,21 +87,57 @@ def generate_discover_datum( # ═══════════════════════════════════════════════════════════════════ -# Dataset backed by initial states (PUCT selection comes later) +# Datasets: PUCT (train) vs random (val) # ═══════════════════════════════════════════════════════════════════ -class DiscoverDataset(IterableDataset): - """Dataset that generates prompts from Erdős initial states. - - Each iteration generates diverse initial states and yields them as - DatumSpecs with the reference TTT-Discover prompt format. +class PUCTDiscoverDataset(IterableDataset): + """Training dataset: pulls PUCT-selected states from the Ray env actor. - For now, initial states are random perturbations of h=0.5 (matching - the reference). Future: PUCT buffer selects states based on prior - discoveries. + Matches ttt-discover-ref: ``ErdosRefPUCTSampler.sample_states(num_prompts)`` + (distinct parent states per step). The env updates the sampler in + ``ErdosDiscoveryEnvironment._sync_step`` via ``update_states`` / + ``record_failed_rollout`` and ``flush``. Requires ``data.num_workers == 0``. """ + def __init__( + self, + tokenizer, + env_actor, + num_prompts_per_step: int, + task_name: str = "erdos_discovery", + length: int = 1000, + ): + self.tokenizer = tokenizer + self._env = env_actor + self.num_prompts_per_step = num_prompts_per_step + self.task_name = task_name + self.length = length + self._idx_counter = itertools.count() + + def __iter__(self): + for _ in itertools.count(): + states = ray.get( + self._env.puct_sample_states.remote( + self.num_prompts_per_step, + ) + ) + for state in states: + idx = next(self._idx_counter) + yield generate_discover_datum( + self.tokenizer, + state, + idx=idx, + task_name=self.task_name, + ) + + def __len__(self): + return self.length + + +class RandomDiscoverDataset(IterableDataset): + """Random warm-starts (e.g. validation) — does not touch the PUCT sampler.""" + def __init__( self, tokenizer, @@ -115,7 +155,6 @@ def __init__( def __iter__(self): for _ in itertools.count(): - # Generate fresh initial states each step for _ in range(self.num_states_per_step): state = create_initial_state(self._rng) idx = next(self._idx_counter) @@ -137,19 +176,44 @@ def __len__(self): def setup_discover_data(config: MasterConfig, tokenizer): """Create dataset, environment, and wire them together.""" - env_config = config.get("env", {}).get("erdos_discovery", {}) - num_states = config.get("grpo", {}).get("num_prompts_per_step", 8) + base_env = config.get("env", {}).get("erdos_discovery", {}) + env_config = dict(base_env) if isinstance(base_env, dict) else {} + num_states = int(config.get("grpo", {}).get("num_prompts_per_step", 8)) task_name = "erdos_discovery" - train_dataset = DiscoverDataset( + # Ref parity: cold-start seed count must match prompts per step (PUCT batch_size). + seed_bs = int(env_config.get("puct_seed_batch_size", num_states)) + if seed_bs != num_states: + logger.warning( + "Overriding puct_seed_batch_size %s -> %s to match grpo.num_prompts_per_step " + "(ttt-discover-ref PUCT batch_size).", + seed_bs, + num_states, + ) + env_config["puct_seed_batch_size"] = num_states + + data_cfg = config.get("data", {}) + if int(data_cfg.get("num_workers", 0)) != 0: + logger.warning( + "Setting data.num_workers=0 for Erdős PUCT (dataset calls Ray on the driver)." + ) + config["data"] = {**data_cfg, "num_workers": 0} + + env = ErdosDiscoveryEnvironment.options( + num_gpus=0, + max_restarts=-1, + max_task_retries=-1, + ).remote(config=env_config) + + train_dataset = PUCTDiscoverDataset( tokenizer=tokenizer, - num_states_per_step=num_states, + env_actor=env, + num_prompts_per_step=num_states, task_name=task_name, length=config.get("grpo", {}).get("max_num_steps", 50) * num_states, - seed=config.get("seed", 42), ) - val_dataset = DiscoverDataset( + val_dataset = RandomDiscoverDataset( tokenizer=tokenizer, num_states_per_step=num_states, task_name=task_name, @@ -157,12 +221,6 @@ def setup_discover_data(config: MasterConfig, tokenizer): seed=config.get("seed", 42) + 1, ) - env = ErdosDiscoveryEnvironment.options( - num_gpus=0, - max_restarts=-1, - max_task_retries=-1, - ).remote(config=env_config) - task_to_env = {task_name: env} val_task_to_env = {task_name: env} diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py index d3d07705ac..97c39701cc 100644 --- a/nemo_rl/environments/erdos_discovery_environment.py +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -14,6 +14,7 @@ import asyncio import logging import math +import os import re import signal import time @@ -28,6 +29,11 @@ EnvironmentReturn, ) from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.environments.erdos_ref_puct_sampler import ( + ErdosRefPUCTSampler, + ErdosRefState, + erdos_ref_state_to_prompt_state, +) logger = logging.getLogger(__name__) @@ -336,6 +342,20 @@ def create_initial_state(rng=None): } +def erdos_ref_state_from_create_initial(rng: np.random.Generator) -> ErdosRefState: + """Match ttt-discover-ref ErdosMinOverlapEnv.create_initial_state → State.""" + d = create_initial_state(rng) + return ErdosRefState( + timestep=-1, + construction=list(d["construction"]), + code="", + value=-float(d["c5_bound"]), + parent_values=[], + parents=[], + observation="", + ) + + # ═══════════════════════════════════════════════════════════════════ # Prompt construction (from reference State.to_prompt + ErdosMinOverlapEnv.get_question) # ═══════════════════════════════════════════════════════════════════ @@ -480,7 +500,11 @@ class ErdosDiscoveryEnvironment(EnvironmentInterface): def __init__(self, config: dict = None): config = config or {} self.sandbox_timeout = config.get("sandbox_timeout", 1000) - self.num_initial_states = config.get("num_initial_states", 16) + self.num_initial_states = int(config.get("num_initial_states", 8)) + self.puct_c = float(config.get("puct_c", 1.0)) + self.puct_seed_batch_size = int( + config.get("puct_seed_batch_size", self.num_initial_states) + ) # Tracking self.best_reward = 0.0 @@ -488,21 +512,43 @@ def __init__(self, config: dict = None): self.total_verified = 0 self.total_valid = 0 - # PUCT buffer for state management - self._states = [] - self._initialize_states() - - def _initialize_states(self): - """Generate initial random states.""" - rng = np.random.default_rng(42) - for _ in range(self.num_initial_states): - self._states.append(create_initial_state(rng)) + log_dir = config.get("puct_log_dir") or os.environ.get( + "ERDOS_PUCT_LOG_DIR", "/tmp/erdos_puct" + ) + os.makedirs(log_dir, exist_ok=True) + sampler_path = os.path.join(log_dir, "puct_sampler.json") + resume_step = config.get("puct_resume_step") + if resume_step is not None: + resume_step = int(resume_step) + + self.sampler = ErdosRefPUCTSampler( + file_path=sampler_path, + init_state_fn=lambda: erdos_ref_state_from_create_initial( + np.random.default_rng() + ), + max_buffer_size=int(config.get("puct_max_buffer_size", 1000)), + batch_size=self.puct_seed_batch_size, + resume_step=resume_step, + puct_c=self.puct_c, + topk_children=int(config.get("puct_topk_children", 2)), + max_construction_len=int(config.get("puct_max_construction_len", 1000)), + ) def get_initial_states(self, n: int = None) -> list[dict]: - """Return initial states for prompt generation.""" - if n is None: - return self._states - return self._states[:n] + """Random states (e.g. validation). Training uses puct_sample_states().""" + rng = np.random.default_rng(42) + k = n if n is not None else self.num_initial_states + return [create_initial_state(rng) for _ in range(k)] + + def puct_sample_states(self, num_prompts: int) -> list[dict]: + """ttt-discover-ref PUCTSampler.sample_states — prompts + serial parent State.""" + picked = self.sampler.sample_states(num_prompts) + out: list[dict] = [] + for s in picked: + prompt_state = erdos_ref_state_to_prompt_state(s) + prompt_state["erdos_ref_state"] = s.to_dict() + out.append(prompt_state) + return out def step( self, @@ -533,6 +579,8 @@ def _sync_step( import time as _time _t0 = _time.time() + step_num = getattr(self, "_step_count", 0) + 1 + self._step_count = step_num print(f"[{_time.strftime('%H:%M:%S')}] 🧪 Starting reward computation for {batch_size} rollouts") for i, message_log in enumerate(message_log_batch): @@ -558,6 +606,14 @@ def _sync_step( # Inject initial_h_values from state if available state = metadata[i] if i < len(metadata) else {} construction = state.get("construction", None) + parent_erdos: Optional[ErdosRefState] = None + raw_parent = state.get("erdos_ref_state") + if raw_parent is not None: + try: + parent_erdos = ErdosRefState.from_dict(raw_parent) + except Exception as e: + logger.warning("Invalid erdos_ref_state in metadata: %s", e) + preamble = "import numpy as np\n\n" if construction: preamble += f"initial_h_values = np.array({construction!r})\n\n" @@ -582,6 +638,28 @@ def _sync_step( answers[i] = f"C5={c5:.6f}" if c5 else f"reward={reward:.4f}" + if parent_erdos is not None: + if ( + reward > 0 + and c5 is not None + and result.get("result_construction") is not None + ): + child = ErdosRefState( + timestep=step_num, + construction=list(result["result_construction"]), + code=code, + value=-float(c5), + observation=str(result.get("stdout", "") or ""), + ) + try: + self.sampler.update_states( + [child], [parent_erdos], save=False + ) + except Exception as e: + logger.warning("PUCT update_states failed: %s", e) + else: + self.sampler.record_failed_rollout(parent_erdos) + if i < len(updated_metadata): updated_metadata[i] = { **updated_metadata[i], @@ -611,10 +689,8 @@ def _sync_step( ) # Save outputs to JSONL for debugging - import json, os - self.total_verified # use as step counter proxy - step_num = getattr(self, '_step_count', 0) + 1 - self._step_count = step_num + import json + log_dir = os.environ.get("ERDOS_LOG_DIR", "/tmp/erdos_outputs") os.makedirs(log_dir, exist_ok=True) out_path = os.path.join(log_dir, f"step_{step_num:03d}.jsonl") @@ -646,6 +722,11 @@ def _sync_step( except Exception as e: print(f" ⚠️ Failed to save outputs: {e}") + try: + self.sampler.flush(step_num) + except Exception as e: + logger.warning("PUCT flush failed: %s", e) + return EnvironmentReturn( observations=observations, metadata=updated_metadata, @@ -697,4 +778,10 @@ def global_post_process_and_metrics( f"valid={batch_valid}/{len(metadata)} " f"best_c5={best}") + try: + for k, v in self.sampler.get_sample_stats().items(): + metrics[f"erdos/{k}"] = float(v) + except Exception: + pass + return metadata, metrics diff --git a/nemo_rl/environments/erdos_ref_puct_sampler.py b/nemo_rl/environments/erdos_ref_puct_sampler.py new file mode 100644 index 0000000000..c04b7152e4 --- /dev/null +++ b/nemo_rl/environments/erdos_ref_puct_sampler.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# PUCT sampler and Erdős State mirror ttt-discover-ref: +# ttt_discover/tinker_utils/sampler.py (PUCTSampler) +# ttt_discover/tinker_utils/state.py (State) +# examples/erdos_min_overlap/env.py (value = -C₅, minimize) + +from __future__ import annotations + +import json +import os +import threading +import time +import uuid +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +import numpy as np + + +def to_json_serializable(obj: Any) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, dict): + return {k: to_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [to_json_serializable(v) for v in obj] + return obj + + +@contextmanager +def _file_lock(lock_path: str, *, poll_s: float = 0.05, stale_s: float = 600.0): + while True: + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + try: + os.write(fd, f"{os.getpid()}\n{time.time()}\n".encode("utf-8")) + finally: + os.close(fd) + break + except FileExistsError: + try: + st = os.stat(lock_path) + if (time.time() - st.st_mtime) > stale_s: + os.remove(lock_path) + continue + except FileNotFoundError: + continue + time.sleep(poll_s) + try: + yield + finally: + try: + os.remove(lock_path) + except FileNotFoundError: + pass + + +def _atomic_write_json(path: str, obj: Any) -> None: + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + tmp_path = f"{path}.tmp.{os.getpid()}" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(to_json_serializable(obj), f, indent=2, sort_keys=True) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + + +def _read_json_or_default(path: str, default: Any) -> Any: + if not os.path.exists(path): + return default + try: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError: + return default + + +def _sampler_file_for_step(base_path: str, step: int) -> str: + base_name = base_path.replace(".json", "") + return f"{base_name}_step_{step:06d}.json" + + +@dataclass +class ErdosRefState: + """Mirrors ttt_discover State for Erdős (value = -C₅, higher is better).""" + + timestep: int + construction: list + code: str = "" + value: Optional[float] = None + parent_values: list[float] = field(default_factory=list) + parents: list[dict] = field(default_factory=list) + observation: str = "" + id: str = field(default_factory=lambda: str(uuid.uuid4())) + + def to_dict(self) -> dict: + return { + "type": "ErdosRefState", + "id": self.id, + "timestep": self.timestep, + "value": self.value, + "parent_values": list(self.parent_values), + "parents": list(self.parents), + "observation": self.observation, + "construction": to_json_serializable(self.construction), + "code": self.code, + } + + @classmethod + def from_dict(cls, d: dict) -> ErdosRefState: + return cls( + timestep=int(d["timestep"]), + construction=list(d.get("construction") or []), + code=str(d.get("code") or ""), + value=d.get("value"), + parent_values=list(d.get("parent_values") or []), + parents=list(d.get("parents") or []), + observation=str(d.get("observation") or ""), + id=str(d.get("id") or uuid.uuid4()), + ) + + +def erdos_ref_state_to_prompt_state(s: ErdosRefState) -> dict[str, Any]: + """Map to keys consumed by build_erdos_question / state_to_prompt.""" + c5 = -float(s.value) if s.value is not None else None + parent_c5 = None + if s.parent_values: + parent_c5 = -float(s.parent_values[0]) + return { + "construction": list(s.construction) if s.construction else [], + "c5_bound": c5, + "n_points": len(s.construction) if s.construction else 0, + "code": s.code or "", + "parent_c5": parent_c5, + "observation": s.observation or "", + } + + +class ErdosRefPUCTSampler: + """ + Line-for-line behavior of ttt_discover PUCTSampler for Erdős discovery. + + score(i) = Q(i) + c * scale * P(i) * sqrt(1 + T) / (1 + n[i]) + Q(i) = m[i] if n[i]>0 else R(i); P = rank prior; scale = max(R)-min(R) on non-initial. + """ + + def __init__( + self, + file_path: str, + init_state_fn: Callable[[], ErdosRefState], + max_buffer_size: int = 1000, + batch_size: int = 1, + resume_step: Optional[int] = None, + puct_c: float = 1.0, + topk_children: int = 2, + max_construction_len: Optional[int] = 1000, + ): + self.file_path = file_path + self._init_state_fn = init_state_fn + self.max_buffer_size = max_buffer_size + self.batch_size = batch_size + self.topk_children = topk_children + self.puct_c = float(puct_c) + self.max_construction_len = max_construction_len + + self._states: list[ErdosRefState] = [] + self._initial_states: list[ErdosRefState] = [] + self._last_sampled_states: list[ErdosRefState] = [] + self._last_sampled_indices: list[int] = [] + self._lock = threading.Lock() + self._current_step = resume_step if resume_step is not None else 0 + + self._n: dict[str, int] = {} + self._m: dict[str, float] = {} + self._T: int = 0 + self._last_scale: float = 1.0 + self._last_puct_stats: list[tuple[int, float, float, float, float]] = [] + + if resume_step is not None: + self._load(resume_step) + if not self._states: + for _ in range(batch_size): + state = init_state_fn() + self._initial_states.append(state) + self._states.append(state) + self._save(self._current_step) + + @staticmethod + def _set_parent_info(child: ErdosRefState, parent: ErdosRefState) -> None: + child.parent_values = ( + [parent.value] + parent.parent_values if parent.value is not None else [] + ) + child.parents = [{"id": parent.id, "timestep": parent.timestep}] + parent.parents + + @staticmethod + def _filter_topk_per_parent( + states: list[ErdosRefState], + parent_states: list[ErdosRefState], + k: int, + ) -> tuple[list[ErdosRefState], list[ErdosRefState]]: + if not states: + return [], [] + if k == 0: + return states, parent_states + parent_to_children: dict[str, list[tuple[ErdosRefState, ErdosRefState]]] = {} + for child, parent in zip(states, parent_states): + pid = parent.id + parent_to_children.setdefault(pid, []).append((child, parent)) + topk_children, topk_parents = [], [] + for children_and_parents in parent_to_children.values(): + sorted_pairs = sorted( + children_and_parents, + key=lambda x: x[0].value if x[0].value is not None else float("-inf"), + reverse=True, + ) + for child, parent in sorted_pairs[:k]: + topk_children.append(child) + topk_parents.append(parent) + return topk_children, topk_parents + + def _load(self, step: int) -> None: + file_path = _sampler_file_for_step(self.file_path, step) + if not os.path.exists(file_path): + raise FileNotFoundError( + f"Cannot resume from step {step}: sampler file not found: {file_path}" + ) + with _file_lock(f"{file_path}.lock"): + store = _read_json_or_default(file_path, default=None) + if store is None: + raise ValueError(f"Failed to load sampler state from {file_path}") + self._states = [ErdosRefState.from_dict(s) for s in store.get("states", [])] + self._initial_states = [ + ErdosRefState.from_dict(s) for s in store.get("initial_states", []) + ] + self._n = store.get("puct_n", {}) or {} + self._m = store.get("puct_m", {}) or {} + self._T = int(store.get("puct_T", 0) or 0) + + def _save(self, step: int) -> None: + save_path = _sampler_file_for_step(self.file_path, step) + store = { + "step": step, + "states": [s.to_dict() for s in self._states], + "initial_states": [s.to_dict() for s in self._initial_states], + "puct_n": self._n, + "puct_m": self._m, + "puct_T": self._T, + } + with _file_lock(f"{save_path}.lock"): + _atomic_write_json(save_path, store) + + def _get_construction_key(self, state: ErdosRefState): + if state.construction: + return tuple(state.construction) + if state.code: + return state.code + return None + + def _compute_scale( + self, values: np.ndarray, mask: Optional[np.ndarray] = None + ) -> float: + if values.size == 0: + return 1.0 + v = values[mask] if mask is not None else values + return float(max(np.max(v) - np.min(v), 1e-6)) if v.size > 0 else 1.0 + + def _compute_prior(self, values: np.ndarray, scale: float) -> np.ndarray: + del scale # matches ref signature + if values.size == 0: + return np.array([]) + n = len(values) + ranks = np.argsort(np.argsort(-values)) + weights = (n - ranks).astype(np.float64) + return weights / weights.sum() + + def _get_lineage(self, state: ErdosRefState) -> set[str]: + lineage = {state.id} + for p in state.parents or []: + if p.get("id"): + lineage.add(str(p["id"])) + return lineage + + def _build_children_map(self) -> dict[str, set[str]]: + children: dict[str, set[str]] = {} + for s in self._states: + for p in s.parents or []: + pid = p.get("id") + if pid: + children.setdefault(str(pid), set()).add(s.id) + return children + + def _get_full_lineage( + self, state: ErdosRefState, children_map: dict[str, set[str]] + ) -> set[str]: + lineage = self._get_lineage(state) + queue = [state.id] + visited = {state.id} + while queue: + sid = queue.pop(0) + for child_id in children_map.get(sid, []): + if child_id not in visited: + visited.add(child_id) + lineage.add(child_id) + queue.append(child_id) + return lineage + + def sample_states(self, num_states: int) -> list[ErdosRefState]: + initial_ids = {s.id for s in self._initial_states} + candidates = list(self._states) + + if not candidates: + picked = [self._init_state_fn() for _ in range(num_states)] + self._last_sampled_states = picked + self._last_sampled_indices = [] + self._last_puct_stats = [(0, 0.0, 0.0, 0.0, 0.0) for _ in picked] + return picked + + vals = np.array( + [float(s.value if s.value is not None else float("-inf")) for s in candidates] + ) + non_initial_mask = np.array([s.id not in initial_ids for s in candidates]) + scale = self._compute_scale( + vals, non_initial_mask if non_initial_mask.any() else None + ) + self._last_scale = scale + p = self._compute_prior(vals, scale) + sqrt_t = np.sqrt(1.0 + self._T) + + scores = [] + for i, s in enumerate(candidates): + n = self._n.get(s.id, 0) + m = self._m.get(s.id, vals[i]) + q = m if n > 0 else vals[i] + bonus = self.puct_c * scale * p[i] * sqrt_t / (1.0 + n) + score = q + bonus + scores.append((score, vals[i], s, n, q, p[i], bonus)) + + scores.sort(key=lambda x: (x[0], x[1]), reverse=True) + + if num_states > 1: + children_map = self._build_children_map() + picked, top_scores = [], [] + blocked_ids: set[str] = set() + for entry in scores: + s = entry[2] + if s.id in blocked_ids: + continue + picked.append(s) + top_scores.append(entry) + blocked_ids.update(self._get_full_lineage(s, children_map)) + if len(picked) >= num_states: + break + else: + top_scores = scores[:num_states] + picked = [t[2] for t in top_scores] + + state_id_to_idx = {s.id: i for i, s in enumerate(self._states)} + self._last_sampled_states = picked + self._last_sampled_indices = [state_id_to_idx.get(s.id, -1) for s in picked] + self._last_puct_stats = [(t[3], t[4], t[5], t[6], t[0]) for t in top_scores] + + # Erdős: no construction_length_limits — no refresh (ref no-op for this env) + return picked + + def update_states( + self, + states: list[ErdosRefState], + parent_states: list[ErdosRefState], + save: bool = True, + step: Optional[int] = None, + ) -> None: + if not states: + return + assert len(states) == len(parent_states) + + parent_max: dict[str, float] = {} + parent_obj: dict[str, ErdosRefState] = {} + for child, parent in zip(states, parent_states): + if child.value is None: + continue + pid = parent.id + parent_obj[pid] = parent + parent_max[pid] = max(parent_max.get(pid, float("-inf")), float(child.value)) + + for pid, y in parent_max.items(): + self._m[pid] = max(self._m.get(pid, y), y) + parent = parent_obj[pid] + anc_ids = [pid] + [ + str(p["id"]) for p in (parent.parents or []) if p.get("id") + ] + for aid in anc_ids: + self._n[aid] = self._n.get(aid, 0) + 1 + self._T += 1 + + states, parent_states = self._filter_topk_per_parent( + states, parent_states, self.topk_children + ) + existing = {self._get_construction_key(s) for s in self._states} + existing.discard(None) + + new_states = [] + for child, parent in zip(states, parent_states): + if child.value is None: + continue + if ( + self.max_construction_len is not None + and child.construction + and len(child.construction) > self.max_construction_len + ): + continue + key = self._get_construction_key(child) + if key is not None and key in existing: + continue + self._set_parent_info(child, parent) + new_states.append(child) + if key is not None: + existing.add(key) + + if not new_states: + return + with self._lock: + self._states.extend(new_states) + if save: + self._finalize_and_save(step) + + def _finalize_and_save(self, step: Optional[int] = None) -> None: + if len(self._states) > self.max_buffer_size: + actual_values = [ + s.value if s.value is not None else float("-inf") for s in self._states + ] + by_actual = list(np.argsort(actual_values)[::-1]) + initial_ids = {s.id for s in self._initial_states} + initial_indices = {i for i, s in enumerate(self._states) if s.id in initial_ids} + keep = set(initial_indices) + for i in by_actual: + if len(keep) >= self.max_buffer_size: + break + keep.add(i) + self._states = [self._states[i] for i in sorted(keep)] + if step is not None: + self._current_step = step + self._save(self._current_step) + + def flush(self, step: Optional[int] = None) -> None: + with self._lock: + if self.topk_children > 0: + by_parent: dict[str, list[ErdosRefState]] = {} + no_parent: list[ErdosRefState] = [] + for s in self._states: + pid = s.parents[0]["id"] if s.parents else None + if pid: + by_parent.setdefault(pid, []).append(s) + else: + no_parent.append(s) + filtered = [] + for children in by_parent.values(): + children.sort( + key=lambda x: x.value if x.value is not None else float("-inf"), + reverse=True, + ) + filtered.extend(children[: self.topk_children]) + self._states = no_parent + filtered + self._finalize_and_save(step) + + def record_failed_rollout(self, parent: ErdosRefState) -> None: + anc_ids = [parent.id] + [ + str(p["id"]) for p in (parent.parents or []) if p.get("id") + ] + for aid in anc_ids: + self._n[aid] = self._n.get(aid, 0) + 1 + self._T += 1 + + def get_sample_stats(self) -> dict[str, float]: + def _stats(values, prefix): + arr = np.array([v for v in values if v is not None]) + if len(arr) == 0: + return {} + return { + f"{prefix}/mean": float(np.mean(arr)), + f"{prefix}/std": float(np.std(arr)), + f"{prefix}/min": float(np.min(arr)), + f"{prefix}/max": float(np.max(arr)), + } + + buffer_values = [s.value for s in self._states] + buffer_timesteps = [s.timestep for s in self._states] + buffer_constr_lens = [ + len(s.construction) if s.construction else 0 for s in self._states + ] + sampled_values = [s.value for s in self._last_sampled_states] + sampled_timesteps = [s.timestep for s in self._last_sampled_states] + sampled_constr_lens = [ + len(s.construction) if s.construction else 0 + for s in self._last_sampled_states + ] + stats: dict[str, float] = { + "puct/buffer_size": float(len(self._states)), + "puct/sampled_size": float(len(self._last_sampled_states)), + "puct/T": float(self._T), + "puct/scale_last": float(self._last_scale), + } + stats.update(_stats(buffer_values, "puct/buffer_value")) + stats.update(_stats(buffer_timesteps, "puct/buffer_timestep")) + stats.update(_stats(buffer_constr_lens, "puct/buffer_construction_len")) + stats.update(_stats(sampled_values, "puct/sampled_value")) + stats.update(_stats(sampled_timesteps, "puct/sampled_timestep")) + stats.update(_stats(sampled_constr_lens, "puct/sampled_construction_len")) + return stats From bbed0f0db20fd4d1da0ceca4c51dfc2265648166 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Fri, 3 Apr 2026 02:42:19 +0000 Subject: [PATCH 42/48] config: 8 nodes, 16k seq, CP=2, copy erdos_ref_puct_sampler to container --- examples/configs/grpo_erdos_discover.yaml | 22 ++++++++++------------ launch_erdos_120b.sh | 3 ++- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 0a10662d48..be341627e7 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -1,4 +1,4 @@ -# TTT-Discover Erdős — Nemotron-3-Super-120B-A12B, 4k seq, 8 nodes +# TTT-Discover Erdős — Nemotron-3-Super-120B, 16k seq, 8 nodes, CP=2 defaults: "grpo_superv3.yaml" grpo: @@ -24,7 +24,7 @@ policy: tokenizer: name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" chat_template_kwargs: null - max_total_sequence_length: 4096 + max_total_sequence_length: 16384 train_global_batch_size: 504 train_micro_batch_size: 1 logprob_batch_size: 1 @@ -35,17 +35,17 @@ policy: resources: num_nodes: 2 gpus_per_node: 8 - max_new_tokens: 3072 + max_new_tokens: 15360 vllm_cfg: async_engine: false tensor_parallel_size: 8 gpu_memory_utilization: 0.85 - max_model_len: 4096 + max_model_len: 16384 megatron_cfg: tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 - context_parallel_size: 1 + context_parallel_size: 2 expert_model_parallel_size: 8 sequence_parallel: true activation_checkpointing: true @@ -66,14 +66,12 @@ optimizer: data: shuffle: false - num_workers: 0 # PUCT training: dataset must run on driver (Ray select) - max_input_seq_length: 4096 + max_input_seq_length: 16384 env: erdos_discovery: num_initial_states: 16 - # Must match grpo.num_prompts_per_step (run_discover enforces ref-style PUCT batch_size). - puct_seed_batch_size: 8 + num_groups_per_step: 8 sandbox_timeout: 1000 should_use_nemo_gym: false @@ -82,18 +80,18 @@ cluster: num_nodes: 8 logger: - log_dir: "results/erdos-120b" + log_dir: "results/erdos-120b-16k" wandb_enabled: true wandb: project: "ttt-discover-erdos" - name: "nemotron-120b-4k-50steps" + name: "nemotron-120b-16k-8node-puct" tensorboard_enabled: false mlflow_enabled: false swanlab_enabled: false checkpointing: enabled: false - checkpoint_dir: "results/erdos-120b" + checkpoint_dir: "results/erdos-120b-16k" save_period: 999999 checkpoint_must_save_by: null model_save_format: "safetensors" diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh index 3245916224..211ffc671e 100755 --- a/launch_erdos_120b.sh +++ b/launch_erdos_120b.sh @@ -33,6 +33,7 @@ SRC=/home/mormio/RL cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ +cp \$SRC/nemo_rl/environments/erdos_ref_puct_sampler.py /opt/nemo-rl/nemo_rl/environments/ cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ cp \$SRC/examples/configs/grpo_erdos_discover.yaml /opt/nemo-rl/examples/configs/ @@ -82,7 +83,7 @@ echo "Submitting Erdős TTT-Discover 120B (8k seq, wandb)..." echo " Container: $CONTAINER" echo " Model: $MODEL_PATH" echo " Nodes: 8 (2 inference + 6 training)" -echo " Seq len: 4096" +echo " Seq len: 16384" echo " Exp: $EXP" COMMAND="$COMMAND" \ From 0adfa3af7cd305fe4b17ff3b6e9933081354b5e3 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Fri, 3 Apr 2026 02:48:22 +0000 Subject: [PATCH 43/48] cleanup: config naming, PUCT log dir, remove stale puct_buffer copy --- examples/configs/grpo_erdos_discover.yaml | 4 ++-- launch_erdos_120b.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index be341627e7..70fc7dd772 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -70,8 +70,8 @@ data: env: erdos_discovery: - num_initial_states: 16 - num_groups_per_step: 8 + num_initial_states: 8 # matches num_prompts_per_step + puct_seed_batch_size: 8 # matches num_prompts_per_step sandbox_timeout: 1000 should_use_nemo_gym: false diff --git a/launch_erdos_120b.sh b/launch_erdos_120b.sh index 211ffc671e..54a5e04a8c 100755 --- a/launch_erdos_120b.sh +++ b/launch_erdos_120b.sh @@ -27,12 +27,12 @@ export HF_HUB_ENABLE_HF_TRANSFER=0 && \ export TORCH_CUDA_ARCH_LIST='9.0 10.0' && \ export NRL_IGNORE_VERSION_MISMATCH=1 && \ export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_outputs && \ +export ERDOS_PUCT_LOG_DIR=/home/mormio/RL/results/erdos_puct && \ export WANDB_API_KEY=$WANDB_API_KEY && \ SRC=/home/mormio/RL cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ -cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ cp \$SRC/nemo_rl/environments/erdos_ref_puct_sampler.py /opt/nemo-rl/nemo_rl/environments/ cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ cp \$SRC/examples/configs/grpo_erdos_discover.yaml /opt/nemo-rl/examples/configs/ From 67051f2e2793c6773aa69898c31bf184768e7b37 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Fri, 3 Apr 2026 14:15:36 +0000 Subject: [PATCH 44/48] fix: point config to instruct model (no Base) --- example.txt | 1 + examples/configs/grpo_erdos_discover.yaml | 4 ++-- keep.py | 3 +++ launch_erdos_120b.sh => launch_scripts/launch_erdos_120b.sh | 2 +- launch_erdos_debug.sh => launch_scripts/launch_erdos_debug.sh | 0 .../launch_erdos_debug_16k.sh | 0 6 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 example.txt create mode 100644 keep.py rename launch_erdos_120b.sh => launch_scripts/launch_erdos_120b.sh (99%) rename launch_erdos_debug.sh => launch_scripts/launch_erdos_debug.sh (100%) rename launch_erdos_debug_16k.sh => launch_scripts/launch_erdos_debug_16k.sh (100%) diff --git a/example.txt b/example.txt new file mode 100644 index 0000000000..5dd01c177f --- /dev/null +++ b/example.txt @@ -0,0 +1 @@ +Hello, world! \ No newline at end of file diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml index 70fc7dd772..1d51b4cde3 100644 --- a/examples/configs/grpo_erdos_discover.yaml +++ b/examples/configs/grpo_erdos_discover.yaml @@ -20,9 +20,9 @@ loss_fn: token_level_loss: false policy: - model_name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" + model_name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" tokenizer: - name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" + name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" chat_template_kwargs: null max_total_sequence_length: 16384 train_global_batch_size: 504 diff --git a/keep.py b/keep.py new file mode 100644 index 0000000000..1eb95481fd --- /dev/null +++ b/keep.py @@ -0,0 +1,3 @@ + +def test(): + print("Hello") diff --git a/launch_erdos_120b.sh b/launch_scripts/launch_erdos_120b.sh similarity index 99% rename from launch_erdos_120b.sh rename to launch_scripts/launch_erdos_120b.sh index 54a5e04a8c..55f5993aae 100755 --- a/launch_erdos_120b.sh +++ b/launch_scripts/launch_erdos_120b.sh @@ -4,7 +4,7 @@ set -euo pipefail cd /home/mormio/RL CONTAINER="/home/shared/containers/nemo-rl-super-v3.sqsh" -MODEL_PATH="/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16" +MODEL_PATH="/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" EXP="results/erdos-120b-$(date +%Y%m%d_%H%M)" mkdir -p "$EXP" diff --git a/launch_erdos_debug.sh b/launch_scripts/launch_erdos_debug.sh similarity index 100% rename from launch_erdos_debug.sh rename to launch_scripts/launch_erdos_debug.sh diff --git a/launch_erdos_debug_16k.sh b/launch_scripts/launch_erdos_debug_16k.sh similarity index 100% rename from launch_erdos_debug_16k.sh rename to launch_scripts/launch_erdos_debug_16k.sh From edbbca80ab4972ec78cce01be68dec4ca5006c25 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Fri, 3 Apr 2026 14:47:49 +0000 Subject: [PATCH 45/48] script to convert ds to nemo rl/sft format --- scripts/convert_ds_to_nemorl_format.py | 175 +++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 scripts/convert_ds_to_nemorl_format.py diff --git a/scripts/convert_ds_to_nemorl_format.py b/scripts/convert_ds_to_nemorl_format.py new file mode 100644 index 0000000000..3ab3eb8dd1 --- /dev/null +++ b/scripts/convert_ds_to_nemorl_format.py @@ -0,0 +1,175 @@ +"""Convert lambda/hermes-agent-reasoning-traces to NeMo RL JSONL format. + +Initially generated on 2026-04-03 to convert the lambda/hermes-agent-reasoning-traces dataset to a usable format for SFT... + +The dataset uses ShareGPT/Hermes conventions: conversations are stored under +a `conversations` key with `from`/`value` turn dicts, and role names differ +from the OpenAI standard (human -> user, gpt -> assistant). + +Tool definitions are already embedded verbatim in the system message as +... XML, so use tokenizer.chat_template: NULL (passthrough) +when training — the content is already formatted for the Hermes chat template. + +Usage: + # default dataset + uv run scripts/convert_ds_to_nemorl_format.py --output_dir /path/to/output + + # custom dataset + uv run scripts/convert_ds_to_nemorl_format.py --dataset org/my-dataset --output_dir /path/to/output + + +Suggested NeMo RL YAML snippet after running: + tokenizer: + chat_template: NULL + + data: + train: + dataset_name: openai_format + data_path: /path/to/output/train.jsonl + chat_key: messages + tool_key: tools + use_preserving_dataset: true + validation: + dataset_name: openai_format + data_path: /path/to/output/val.jsonl + chat_key: messages + tool_key: tools + use_preserving_dataset: true +""" + +import argparse +import json +import os +import random + +from datasets import load_dataset + +ROLE_MAP = { + "system": "system", + "human": "user", + "gpt": "assistant", + "tool": "tool", +} + + +def convert_sample(sample: dict) -> dict: + messages = [] + for turn in sample["conversations"]: + role = ROLE_MAP.get(turn["from"]) + if role is None: + raise ValueError(f"Unexpected role '{turn['from']}' in sample id={sample.get('id')}") + messages.append({"role": role, "content": turn["value"]}) + + if messages[-1]["role"] != "assistant": + raise ValueError( + f"Last turn must be from the assistant, got '{messages[-1]['role']}' " + f"in sample id={sample.get('id')}" + ) + + result: dict = {"messages": messages} + + # The `tools` column is a JSON string; parse it into a list. + # With passthrough template this isn't used for formatting, but including + # it preserves the data for users who switch to a structured template later. + tools_raw = sample.get("tools") + if tools_raw: + try: + result["tools"] = json.loads(tools_raw) + except (json.JSONDecodeError, TypeError): + pass # drop malformed tools rather than corrupting the sample + + return result + + +def write_jsonl(path: str, samples: list[dict]) -> None: + with open(path, "w") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") + print(f"Wrote {len(samples):,} samples -> {path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert a ShareGPT/Hermes-format HuggingFace dataset to NeMo RL JSONL format." + ) + parser.add_argument( + "--dataset", + default="lambda/hermes-agent-reasoning-traces", + help="HuggingFace dataset name (default: lambda/hermes-agent-reasoning-traces).", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Directory where train.jsonl (and optionally val.jsonl) will be written.", + ) + parser.add_argument( + "--val_split", + type=float, + default=0.05, + help="Fraction of data held out for validation (default: 0.05). Pass 0 to skip.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for the train/val shuffle (default: 42).", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"Loading {args.dataset} from HuggingFace...") + ds = load_dataset(args.dataset, split="train") + print(f"Loaded {len(ds):,} samples.") + + samples = [] + n_skipped = 0 + for raw in ds: + try: + samples.append(convert_sample(raw)) + except ValueError as e: + print(f" Skipping sample: {e}") + n_skipped += 1 + + if n_skipped: + print(f"Skipped {n_skipped} malformed samples.") + + if args.val_split > 0: + random.seed(args.seed) + random.shuffle(samples) + n_val = max(1, int(len(samples) * args.val_split)) + val_samples = samples[:n_val] + train_samples = samples[n_val:] + else: + train_samples = samples + val_samples = [] + + write_jsonl(os.path.join(args.output_dir, "train.jsonl"), train_samples) + if val_samples: + write_jsonl(os.path.join(args.output_dir, "val.jsonl"), val_samples) + + print("\nDone. Add this to your sft.yaml (adjust paths as needed):") + train_path = os.path.abspath(os.path.join(args.output_dir, "train.jsonl")) + val_path = os.path.abspath(os.path.join(args.output_dir, "val.jsonl")) + print(f""" + tokenizer: + chat_template: NULL # passthrough — content is already Hermes-formatted + + data: + train: + dataset_name: openai_format + data_path: {train_path} + chat_key: messages + tool_key: tools + use_preserving_dataset: true # tools have heterogeneous argument schemas + validation: + dataset_name: openai_format + data_path: {val_path} + chat_key: messages + tool_key: tools + use_preserving_dataset: true +""") + + +if __name__ == "__main__": + main() From 3c5c16b9fcf9c8b2306e5fd5c1f6b7ec9184cef5 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 7 Apr 2026 19:50:41 +0000 Subject: [PATCH 46/48] cleanup: remove debug configs, shim copies, scratch files, and Gym submodule Removed: - shim/ (old copy of code, unused) - Debug configs and launch scripts (grpo_erdos_debug_16k, launch_erdos_debug*) - grpo_superv3.yaml (base container config, not needed in PR) - Scratch files (example.txt, keep.py, message (3).md, erdos_debug*.slurm) - scripts/convert_ds_to_nemorl_format.py - test_gptoss_vllm.sh - 3rdparty/Gym-workspace/Gym submodule (training runs inline, no Gym server needed) --- 3rdparty/Gym-workspace/Gym | 1 - erdos_debug.slurm | 30 - erdos_debug_container.slurm | 37 - example.txt | 1 - examples/configs/grpo_erdos_debug_16k.yaml | 88 - .../configs/grpo_erdos_discover_debug.yaml | 80 - examples/configs/grpo_superv3.yaml | 489 --- keep.py | 3 - launch_scripts/launch_erdos_debug.sh | 94 - launch_scripts/launch_erdos_debug_16k.sh | 82 - message (3).md | 207 -- scripts/convert_ds_to_nemorl_format.py | 175 - shim/grpo_erdos_discover_debug.yaml | 78 - shim/nemo_rl/__init__.py | 0 shim/nemo_rl/algorithms/__init__.py | 0 .../entropic_advantage_estimator.py | 179 - shim/nemo_rl/algorithms/grpo.py | 3205 ----------------- shim/nemo_rl/environments/__init__.py | 0 .../erdos_discovery_environment.py | 362 -- shim/nemo_rl/environments/utils.py | 139 - shim/nemo_rl/utils/__init__.py | 0 shim/nemo_rl/utils/puct_buffer.py | 561 --- shim/run_discover.py | 349 -- test_gptoss_vllm.sh | 42 - 24 files changed, 6202 deletions(-) delete mode 120000 3rdparty/Gym-workspace/Gym delete mode 100644 erdos_debug.slurm delete mode 100644 erdos_debug_container.slurm delete mode 100644 example.txt delete mode 100644 examples/configs/grpo_erdos_debug_16k.yaml delete mode 100644 examples/configs/grpo_erdos_discover_debug.yaml delete mode 100644 examples/configs/grpo_superv3.yaml delete mode 100644 keep.py delete mode 100755 launch_scripts/launch_erdos_debug.sh delete mode 100755 launch_scripts/launch_erdos_debug_16k.sh delete mode 100644 message (3).md delete mode 100644 scripts/convert_ds_to_nemorl_format.py delete mode 100644 shim/grpo_erdos_discover_debug.yaml delete mode 100644 shim/nemo_rl/__init__.py delete mode 100644 shim/nemo_rl/algorithms/__init__.py delete mode 100644 shim/nemo_rl/algorithms/entropic_advantage_estimator.py delete mode 100644 shim/nemo_rl/algorithms/grpo.py delete mode 100644 shim/nemo_rl/environments/__init__.py delete mode 100644 shim/nemo_rl/environments/erdos_discovery_environment.py delete mode 100644 shim/nemo_rl/environments/utils.py delete mode 100644 shim/nemo_rl/utils/__init__.py delete mode 100644 shim/nemo_rl/utils/puct_buffer.py delete mode 100644 shim/run_discover.py delete mode 100755 test_gptoss_vllm.sh diff --git a/3rdparty/Gym-workspace/Gym b/3rdparty/Gym-workspace/Gym deleted file mode 120000 index 0d1f8dba9b..0000000000 --- a/3rdparty/Gym-workspace/Gym +++ /dev/null @@ -1 +0,0 @@ -/home/mormio/Gym \ No newline at end of file diff --git a/erdos_debug.slurm b/erdos_debug.slurm deleted file mode 100644 index 439fcb1e7f..0000000000 --- a/erdos_debug.slurm +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=erdos-debug -#SBATCH --output=logs/erdos-debug-%j.out -#SBATCH --error=logs/erdos-debug-%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --exclusive -#SBATCH --gpus-per-task=8 -#SBATCH --cpus-per-task=64 -#SBATCH --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 - -# TTT-Discover debug run — single node, Qwen3-1.7B, inline reward, 5 steps. -set -eo pipefail - -mkdir -p logs - -echo "Node: $(hostname)" -echo "GPUs: $(nvidia-smi -L | wc -l)" -echo "Job ID: $SLURM_JOB_ID" - -# No Gym server needed — using inline mode for debug. -# The environment computes rewards directly in-process. - -cd /home/mormio/RL - -echo "Starting TTT-Discover GRPO debug training (inline mode)..." -PATH=$HOME/.local/bin:$PATH uv run python examples/run_discover.py --config \ - examples/configs/grpo_erdos_discover_debug.yaml - -echo "Training complete" diff --git a/erdos_debug_container.slurm b/erdos_debug_container.slurm deleted file mode 100644 index f077d7b529..0000000000 --- a/erdos_debug_container.slurm +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=erdos-debug -#SBATCH --output=logs/erdos-debug-%j.out -#SBATCH --error=logs/erdos-debug-%j.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --exclusive -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=64 -#SBATCH --container-image=nvcr.io#nvidia/nemo-rl:v0.5.0 -#SBATCH --container-writable -#SBATCH --container-mounts=/home/mormio/RL:/home/mormio/RL,/home/shared/models:/home/shared/models -#SBATCH --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 - -# TTT-Discover debug — single node, inside NeMo RL container. -# Skips ray.sub complexity. Starts Ray inline and runs training. - -set -eo pipefail - -echo "Node: $(hostname)" -echo "GPUs: $(nvidia-smi -L | wc -l)" -echo "Python: $(python --version)" -echo "Job: $SLURM_JOB_ID" - -# NCCL config -export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 -export NCCL_SOCKET_IFNAME=bond0 -export UCX_NET_DEVICES=bond0 -export HF_HUB_ENABLE_HF_TRANSFER=0 - -cd /home/mormio/RL - -# Run training directly — Ray will start automatically via init_ray() -uv run python examples/run_discover.py \ - --config examples/configs/grpo_erdos_discover_debug.yaml - -echo "Training complete" diff --git a/example.txt b/example.txt deleted file mode 100644 index 5dd01c177f..0000000000 --- a/example.txt +++ /dev/null @@ -1 +0,0 @@ -Hello, world! \ No newline at end of file diff --git a/examples/configs/grpo_erdos_debug_16k.yaml b/examples/configs/grpo_erdos_debug_16k.yaml deleted file mode 100644 index e02f8e4f5c..0000000000 --- a/examples/configs/grpo_erdos_debug_16k.yaml +++ /dev/null @@ -1,88 +0,0 @@ -# Debug config: Qwen2.5-1.5B, 1 node, 16k seq len, 15 steps -# Purpose: reproduce step 10 hang at small scale -defaults: "grpo_math_1B.yaml" - -grpo: - num_prompts_per_step: 4 - num_generations_per_prompt: 8 - max_num_steps: 15 - max_rollout_turns: 1 - remove_constant_reward_groups: true - val_period: 0 - val_at_start: false - val_at_end: false - adv_estimator: - name: entropic_adaptive_beta - gamma: 0.6931471805599453 - -loss_fn: - kl_penalty_coef: 0.1 - ratio_clip: 0.2 - token_level_loss: false - -policy: - model_name: "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer: - name: "Qwen/Qwen2.5-1.5B-Instruct" - chat_template_kwargs: null - max_total_sequence_length: 4096 - train_global_batch_size: 32 - train_micro_batch_size: 4 - dtensor_cfg: - enabled: true - tensor_parallel_size: 1 - sequence_parallel: false - cpu_offload: false - activation_checkpointing: true - lora_cfg: - enabled: true - rank: 16 - alpha: 1.0 - dropout: 0.0 - generation: - backend: "vllm" - max_new_tokens: 3072 - temperature: 1.0 - top_p: 1.0 - stop_token_ids: null - stop_strings: null - vllm_cfg: - async_engine: false - tensor_parallel_size: 1 - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.6 - max_model_len: 4096 - dynamic_batching: - enabled: false - -optimizer: - name: adamw - lr: 1.0e-4 - -data: - shuffle: false - -env: - erdos_discovery: - num_initial_states: 8 - puct_seed_batch_size: 4 - sandbox_timeout: 120 - should_use_nemo_gym: false - -cluster: - gpus_per_node: 8 - num_nodes: 1 - -logger: - log_dir: "logs/erdos-debug-16k" - wandb_enabled: false - tensorboard_enabled: false - mlflow_enabled: false - swanlab_enabled: false - -checkpointing: - enabled: false - checkpoint_dir: "logs/erdos-debug-16k" - save_period: 999999 - checkpoint_must_save_by: null diff --git a/examples/configs/grpo_erdos_discover_debug.yaml b/examples/configs/grpo_erdos_discover_debug.yaml deleted file mode 100644 index ae47eb6ae8..0000000000 --- a/examples/configs/grpo_erdos_discover_debug.yaml +++ /dev/null @@ -1,80 +0,0 @@ -# TTT-Discover DEBUG config. -# Inherits from grpo_math_1B.yaml for all defaults. -# Overrides: entropic advantages, inline reward, small batch, 5 steps. -defaults: "grpo_math_1B.yaml" - -grpo: - num_prompts_per_step: 4 - num_generations_per_prompt: 8 - max_num_epochs: 1 - max_num_steps: 5 - max_rollout_turns: 1 - remove_constant_reward_groups: true - adv_estimator: - name: entropic_adaptive_beta - gamma: 0.6931471805599453 - -loss_fn: - kl_penalty_coef: 0.1 - ratio_clip: 0.2 - token_level_loss: false - -policy: - model_name: "Qwen/Qwen2.5-1.5B-Instruct" - max_total_sequence_length: 4096 - train_global_batch_size: 32 - train_micro_batch_size: 4 - - dtensor_cfg: - enabled: true - cpu_offload: false - activation_checkpointing: true - sequence_parallel: false - - lora_cfg: - enabled: true - rank: 16 - alpha: 1.0 - dropout: 0.0 - - dynamic_batching: - enabled: false - - generation: - backend: "vllm" - max_new_tokens: 2048 - temperature: 1.0 - top_p: 1.0 - stop_token_ids: null - stop_strings: null - vllm_cfg: - async_engine: false - tensor_parallel_size: 1 - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.6 - max_model_len: ${policy.max_total_sequence_length} - -optimizer: - name: adamw - lr: 1.0e-4 - -data: - shuffle: false - num_workers: 0 # required: PUCT dataset calls Ray from the iterator process - -env: - erdos_discovery: - num_initial_states: 8 - puct_seed_batch_size: 4 - sandbox_timeout: 120 - -checkpointing: - enabled: false - -logger: - log_dir: "logs/erdos-debug" - wandb_enabled: false - tensorboard_enabled: false - mlflow_enabled: false - swanlab_enabled: false diff --git a/examples/configs/grpo_superv3.yaml b/examples/configs/grpo_superv3.yaml deleted file mode 100644 index 5cc06be7dd..0000000000 --- a/examples/configs/grpo_superv3.yaml +++ /dev/null @@ -1,489 +0,0 @@ -checkpointing: - enabled: true - checkpoint_dir: "results/grpo" - metric_name: "val:total_reward/mean" - higher_is_better: true - keep_top_k: 1000000 - save_period: 10 - checkpoint_must_save_by: "00:03:30:00" - model_save_format: "safetensors" - save_consolidated: false - -grpo: - num_prompts_per_step: 128 - num_generations_per_prompt: 16 - num_val_generations_per_prompt: 2 - max_rollout_turns: 1 - max_num_epochs: 1 - max_num_steps: 1000000 - normalize_rewards: true - use_leave_one_out_baseline: true - # Clipping bounds for normalized advantages to prevent extreme values from small std - # Set to null to disable clipping (default), or e.g. -100/100 to clip - advantage_clip_low: null - advantage_clip_high: null - val_period: 5 - val_at_start: false - val_at_end: false - overlong_filtering: false - max_val_samples: null - val_batch_size: 256 - seed: 42 - - use_dynamic_sampling: false - dynamic_sampling_max_gen_batches: 10 - batch_multiplier: 1 - - penalize_invalid_tool_call: true - invalid_tool_call_advantage: -5.0 - penalize_malformed_thinking: true - malformed_thinking_advantage: -5.0 - - reward_shaping: - enabled: false - overlong_buffer_length: 128 - overlong_buffer_penalty: 1 - max_response_length: ${policy.max_total_sequence_length} - stop_properly_penalty_coef: null - reward_scaling: - enabled: false - source_min: 0.0 - source_max: 1.0 - target_min: 0.0 - target_max: 1.0 - - async_grpo: - enabled: false - max_trajectory_age_steps: 1 - in_flight_weight_updates: false - recompute_kv_cache_after_weight_updates: false - - seq_logprob_error_threshold: 2 - -loss_fn: - reference_policy_kl_penalty: 0.0 - reference_policy_kl_type: "k3" - kl_input_clamp_value: null - kl_output_clamp_value: null - - ratio_clip_min: 0.2 - ratio_clip_max: 0.28 - ratio_clip_c: null - use_on_policy_kl_approximation: true - use_importance_sampling_correction: true - truncated_importance_sampling_ratio: null - truncated_importance_sampling_ratio_min: null - truncated_importance_sampling_type: tis - sequence_level_importance_ratios: false - token_level_loss: true - force_on_policy_ratio: false - use_kl_in_reward: false - -policy: - model_name: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/checkpoints/nano-v3-sft-64gbs-nickel-capybara-5e-5-constant-wd-0-load-bal-1e-4-lcx3-pretool-base-temp1-iter-0013600-hf" - tokenizer: - name: ${policy.model_name} - chat_template_kwargs: null - hf_config_overrides: {} - train_global_batch_size: 2048 - train_micro_batch_size: 1 - generation_batch_size: 64 - logprob_batch_size: 1 - max_total_sequence_length: 16384 - precision: "bfloat16" - logprob_chunk_size: 2048 - offload_optimizer_for_logprob: false - - dtensor_cfg: - _v2: true - enabled: false - cpu_offload: False - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - context_parallel_size: 1 - custom_parallel_plan: null - - megatron_cfg: - enabled: true - empty_unused_memory_level: 1 - activation_checkpointing: true - tensor_model_parallel_size: 2 - expert_tensor_parallel_size: 1 - expert_model_parallel_size: 8 - pipeline_model_parallel_size: 2 - num_layers_in_first_pipeline_stage: null - num_layers_in_last_pipeline_stage: null - context_parallel_size: 4 - pipeline_dtype: ${policy.precision} - sequence_parallel: true - freeze_moe_router: true - moe_router_dtype: "fp32" - moe_router_load_balancing_type: "none" - moe_router_bias_update_rate: 1.0e-3 - moe_permute_fusion: true - moe_enable_deepep: false - moe_token_dispatcher_type: "alltoall" - moe_aux_loss_coeff: 0.0 - moe_router_enable_expert_bias: true - moe_shared_expert_overlap: false - apply_rope_fusion: True - bias_activation_fusion: False - defer_fp32_logits: True - moe_per_layer_logging: True - - mtp_loss_scaling_factor: 0.0 - mtp_use_repeated_layer: true - mtp_num_layers: 0 - mtp_detach_heads: true - - optimizer: - optimizer: "adam" - lr: 3.0e-6 - min_lr: 3.0e-6 - weight_decay: 0.0 - bf16: true - fp16: false - params_dtype: "float32" - - adam_beta1: 0.9 - adam_beta2: 0.999 - adam_eps: 1e-8 - - sgd_momentum: 0.9 - - use_distributed_optimizer: true - use_precision_aware_optimizer: true - - clip_grad: ${policy.max_grad_norm} - - optimizer_cpu_offload: false - optimizer_offload_fraction: 0.0 - - scheduler: - start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - weight_decay_incr_style: "constant" - lr_decay_style: "constant" - lr_decay_iters: null - lr_warmup_iters: 10 - lr_warmup_init: 3e-7 - - distributed_data_parallel_config: - grad_reduce_in_fp32: false - overlap_grad_reduce: true - overlap_param_gather: true - use_custom_fsdp: false - data_parallel_sharding_strategy: "optim_grads_params" - - fp8_cfg: - enabled: false - fp8: "e4m3" - fp8_recipe: "blockwise" - fp8_param: false - - env_vars: null - - dynamic_batching: - enabled: False - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - sequence_length_round: 64 - - sequence_packing: - enabled: True - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} - max_grad_norm: 1.0 - - optimizer: null - scheduler: null - - generation: - port_range_low: 11001 - port_range_high: 15000 - backend: "vllm" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 1.0 - top_k: null - stop_token_ids: null - stop_strings: null - vllm_cfg: - async_engine: true - precision: ${policy.precision} - kv_cache_dtype: "auto" - tensor_parallel_size: 4 - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.5 - max_model_len: ${policy.max_total_sequence_length} - enforce_eager: False - use_deep_gemm: False - num_last_layers_in_bf16: 0 - num_first_layers_in_bf16: 0 - enable_vllm_metrics_logger: true - vllm_metrics_logger_interval: 0.5 - expose_http_server: true - http_server_serving_chat_kwargs: - enable_auto_tools: true - tool_parser: qwen3_coder - reasoning_parser: nano_v3 - reasoning_parser_plugin: nemo_rl/utils/nano_v3_reasoning_parser.py - - - vllm_kwargs: - mamba_ssm_cache_dtype: "float32" -# compilation_config: -# mode: 0 - colocated: - enabled: true - resources: - gpus_per_node: null - num_nodes: null - -data: - max_input_seq_length: null - shuffle: false - num_workers: 1 - train: - data_path: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/data/nano-v3-posttraining-data/curriculum_v7_acrid-teal_main_rename.train.jsonl" - validation: - data_path: "/lustre/fsw/portfolios/llmservice/projects/llmservice_nemotron_nano/users/pjin/data/nano-v3-posttraining-data/curriculum_v7_acrid-teal_main_rename.val.jsonl" - default: - dataset_name: NemoGymDataset - env_name: "nemo_gym" - prompt_file: null - system_prompt_file: null - processor: "nemo_gym_data_processor" - -env: - should_use_nemo_gym: true - use_genrm_compare: true - genrm_agent_names: - - "genrm_simple_agent" - - "genrm_simple_agent_reasoning_off" - genrm_compare_server_name: "genrm_compare" - nemo_gym: - num_gpu_nodes: 4 - port_range_low: 15001 - port_range_high: 20000 - invalid_tool_call_patterns: - - "" - - "" - - "" - - "" - thinking_tags: - - "" - - "" - config_paths: - - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml - - resources_servers/math_with_judge/configs/math_with_judge.yaml - - resources_servers/code_gen/configs/code_gen.yaml - - resources_servers/workplace_assistant/configs/workplace_assistant.yaml - - resources_servers/mcqa/configs/mcqa.yaml - - resources_servers/instruction_following/configs/instruction_following.yaml - - resources_servers/structured_outputs/configs/structured_outputs_json.yaml - - resources_servers/equivalence_llm_judge/configs/lc_judge.yaml - - resources_servers/calendar/configs/calendar.yaml - - resources_servers/genrm_compare/configs/genrm_compare.yaml - - resources_servers/equivalence_llm_judge/configs/nl2bash-equivalency.yaml - - resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml - - resources_servers/single_step_tool_use_with_argument_comparison/configs/single_step_tool_use_with_argument_comparison.yaml - - resources_servers/reasoning_gym/configs/reasoning_gym.yaml - - resources_servers/terminal_pivot/configs/terminal_pivot.yaml - - resources_servers/ns_tools/configs/ns_tools.yaml - - resources_servers/math_formal_lean/configs/math_formal_lean_multi_turn.yaml - - resources_servers/swerl_gen/configs/swerl_gen.yaml - - resources_servers/jailbreak_detection/configs/jailbreak_detection_nemotron_combined_reward_tp8.yaml - - resources_servers/over_refusal_detection/configs/over_refusal_detection_nemotron_tp8.yaml - - resources_servers/multichallenge/configs/multichallenge.yaml - - resources_servers/inverse_if/configs/inverse_if.yaml - - resources_servers/single_step_tool_use_with_argument_comparison/configs/search_pivot_single_step_tool_use_with_argument_comparison.yaml - - resources_servers/single_step_tool_use_with_argument_comparison/configs/toolcall_schema_single_step_tool_use_with_argument_comparison.yaml - - jailbreak_detection: - resources_servers: - jailbreak_detection: - judge_model_server: - type: responses_api_models - name: safety_judge_model - - safety_judge_model: - responses_api_models: - vllm_model: - entrypoint: app.py - base_url: http://127.0.0.1:8001/v1 - api_key: dummy_key - model: /scratch/fsw/portfolios/llmservice/users/makeshn/super_v3/model_checkpoints/Nemotron-Content-Safety-Reasoning-4B - return_token_id_information: false - uses_reasoning_parser: false - spinup_server: true - router_dp_size: 8 - server_args: - tensor_parallel_size: 1 - gpu_memory_utilization: 0.85 - max_model_len: 96000 - model_loader_extra_config: - enable_multithread_load: true - num_threads: 2 - server_env: - VLLM_ATTENTION_BACKEND: TRITON_ATTN - - terminal_pivot_simple_agent: - responses_api_agents: - simple_agent: - model_server: - name: policy_model - - nl2bash_judge_model: - responses_api_models: - vllm_model: - entrypoint: app.py - base_url: http://127.0.0.1:10000/v1 - api_key: dummy_key - model: "/scratch/fsw/portfolios/llmservice/users/jiaqiz/models/Qwen3-235B-A22B-Instruct-2507-FP8" - return_token_id_information: False - uses_reasoning_parser: False - spinup_server: True - router_dp_size: 2 - server_args: - tensor_parallel_size: 8 - data_parallel_size: 1 - enable_expert_parallel: True - enable_auto_tool_choice: true - tool_call_parser: hermes - gpu_memory_utilization: 0.85 - max_model_len: 131072 - model_loader_extra_config: - enable_multithread_load: true - num_threads: 112 - - inverse_if: - resources_servers: - inverse_if: - judge_model_server: - type: responses_api_models - name: nl2bash_judge_model - - multichallenge: - resources_servers: - multichallenge: - judge_model_server: - type: responses_api_models - name: nl2bash_judge_model - judge_responses_create_params: - max_output_tokens: 8192 - - equivalence_llm_judge: - resources_servers: - equivalence_llm_judge: - judge_model_server: - name: nl2bash_judge_model - judge_responses_create_params: - max_output_tokens: 8192 - - genrm_compare: - resources_servers: - genrm_compare: - # Points to the GenRM model server defined above - genrm_model_server: - type: responses_api_models - name: genrm_model - # GenRM request parameters - genrm_responses_create_params: - max_output_tokens: 16384 - temperature: 0.6 - top_p: 0.95 - # Comparison settings - comparison_strategy: "circular" - num_judges_per_comparison: 1 - use_principle: true - default_principle: "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt. Begin your evaluation by generating your own answer to the prompt. You must provide your answer before judging any answers. When evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information. Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive. Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt." - aggregator_method: "simple_tiebreaker" - reasoning_bonus: 0.5 - answer_bonus: 0.5 - top_percentile: 0.2 - group_reasoning_length_penalty_coeff: 0.1 - group_answer_length_penalty_coeff: 0.1 - group_style_penalty_coeff: 0.1 - default_score: 3.0 - default_ranking: 3.5 - - genrm_model: - responses_api_models: - vllm_model: - entrypoint: app.py - base_url: http://127.0.0.1:8000/v1 - api_key: dummy_key - model: "/lustre/fsw/portfolios/llmservice/users/jiaqiz/models/qwen235b_principle_comparison_genrm_step1230" - uses_reasoning_parser: True - return_token_id_information: False - spinup_server: True - router_dp_size: 4 - server_args: - tensor_parallel_size: 8 - reasoning_parser: deepseek_r1 - gpu_memory_utilization: 0.85 - max_model_len: 60000 - model_loader_extra_config: - enable_multithread_load: true - num_threads: 112 - - - lc_judge: - resources_servers: - equivalence_llm_judge: - judge_model_server: - name: nl2bash_judge_model - judge_responses_create_params: - max_output_tokens: 8192 - - math_with_judge: - resources_servers: - math_with_judge: - judge_model_server: - name: nl2bash_judge_model - judge_responses_create_params: - max_output_tokens: 8192 - should_use_judge: true - code_gen: - resources_servers: - code_gen: - num_processes: 1024 - unit_test_timeout_secs: 10 - debug: false - -logger: - log_dir: "logs" - num_val_samples_to_print: 0 - wandb_enabled: false - tensorboard_enabled: false - mlflow_enabled: false - monitor_gpus: true - swanlab_enabled: false - wandb: - project: "grpo-dev" - name: "grpo-dev-logger" - tensorboard: {} - mlflow: - experiment_name: "grpo-dev" - run_name: "grpo-dev-logger" - gpu_monitoring: - collection_interval: 10 - flush_interval: 10 - -cluster: - gpus_per_node: 8 - num_nodes: 1 - -# uncomment to enable effort level training -effort_levels: - low_string: "{reasoning effort: low}" - low_weight: 0.1 - low_penalty: 1 - low_ub: 3000 diff --git a/keep.py b/keep.py deleted file mode 100644 index 1eb95481fd..0000000000 --- a/keep.py +++ /dev/null @@ -1,3 +0,0 @@ - -def test(): - print("Hello") diff --git a/launch_scripts/launch_erdos_debug.sh b/launch_scripts/launch_erdos_debug.sh deleted file mode 100755 index 10e98e65e8..0000000000 --- a/launch_scripts/launch_erdos_debug.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/bin/bash -# TTT-Discover Erdős GRPO Debug — FINAL -# Only adds our new files to the container's /opt/nemo-rl, does NOT overwrite existing code -set -euo pipefail -cd /home/mormio/RL - -CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" -EXP="results/erdos-debug-$(date +%Y%m%d_%H%M)" -mkdir -p "$EXP" - -MOUNTS="$PWD:/home/mormio/RL,/home/shared/models:/home/shared/models" - -# Only add NEW files to the container's /opt/nemo-rl -# Do NOT overwrite existing files (they match the container's deps) -COMMAND=' -export HF_HUB_ENABLE_HF_TRANSFER=0 -export TORCH_CUDA_ARCH_LIST="9.0 10.0" -export NRL_IGNORE_VERSION_MISMATCH=1 - -SRC=/home/mormio/RL - -# Add only our NEW files (not overwriting anything) -cp $SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ -cp $SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ -cp $SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ -cp $SRC/examples/run_discover.py /opt/nemo-rl/examples/ -cp $SRC/examples/configs/grpo_erdos_discover_debug.yaml /opt/nemo-rl/examples/configs/ - -# Patch the container grpo.py to register our entropic estimator -# (append the elif branch to _create_advantage_estimator) -python -c " -path = \"/opt/nemo-rl/nemo_rl/algorithms/grpo.py\" -with open(path) as f: - content = f.read() -if \"entropic_adaptive_beta\" not in content: - old = \" else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator\" - new = \"\"\" elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": - from nemo_rl.algorithms.entropic_advantage_estimator import ( - EntropicAdaptiveBetaAdvantageEstimator, - ) - adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( - adv_estimator_config, loss_config - ) - print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") - else: - raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\") - - return adv_estimator\"\"\" - content = content.replace(old, new) - with open(path, \"w\") as f: - f.write(content) - print(\"Patched grpo.py with entropic_adaptive_beta\") -else: - print(\"grpo.py already patched\") -" - -# Patch environments/utils.py to register erdos_discovery -python -c " -path = \"/opt/nemo-rl/nemo_rl/environments/utils.py\" -with open(path) as f: - content = f.read() -if \"erdos_discovery\" not in content: - content = content.replace( - \"\\\"nemo_gym\\\": {\", - \"\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {\" - ) - with open(path, \"w\") as f: - f.write(content) - print(\"Patched utils.py with erdos_discovery\") -else: - print(\"utils.py already patched\") -" - -cd /opt/nemo-rl -python examples/run_discover.py \ - --config examples/configs/grpo_erdos_discover_debug.yaml -' - -echo "Submitting Erdős TTT-Discover debug..." -echo "Experiment dir: $EXP" - -COMMAND="$COMMAND" \ -CONTAINER="$CONTAINER" \ -MOUNTS="$MOUNTS" \ -GPUS_PER_NODE=8 \ -sbatch \ - --nodes=2 --partition=batch --exclusive \ - --job-name=erdos-debug --time=01:00:00 \ - --output="$EXP/slurm-%j.out" \ - --error="$EXP/slurm-%j.err" \ - --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ - ray.sub - -echo "Logs: $EXP/" diff --git a/launch_scripts/launch_erdos_debug_16k.sh b/launch_scripts/launch_erdos_debug_16k.sh deleted file mode 100755 index 360ce445db..0000000000 --- a/launch_scripts/launch_erdos_debug_16k.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -# Debug: Qwen2.5-1.5B, 1 node, 16k, 15 steps — test for step 10 hang -set -euo pipefail -cd /home/mormio/RL - -CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" -EXP="results/erdos-debug-16k-$(date +%Y%m%d_%H%M)" -mkdir -p "$EXP" - -MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" - -COMMAND=" -export HF_HUB_ENABLE_HF_TRANSFER=0 -export TORCH_CUDA_ARCH_LIST='9.0 10.0' -export NRL_IGNORE_VERSION_MISMATCH=1 -export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_debug_outputs - -SRC=/home/mormio/RL -cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ -cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ -cp \$SRC/nemo_rl/utils/puct_buffer.py /opt/nemo-rl/nemo_rl/utils/ -cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ -cp \$SRC/examples/configs/grpo_erdos_debug_16k.yaml /opt/nemo-rl/examples/configs/ - -python -c \" -path = '/opt/nemo-rl/nemo_rl/algorithms/grpo.py' -with open(path) as f: - content = f.read() -if 'entropic_adaptive_beta' not in content: - old = ' else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator' - new = ''' elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": - from nemo_rl.algorithms.entropic_advantage_estimator import ( - EntropicAdaptiveBetaAdvantageEstimator, - ) - adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( - adv_estimator_config, loss_config - ) - print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") - else: - raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\") - - return adv_estimator''' - content = content.replace(old, new) - with open(path, 'w') as f: - f.write(content) - print('Patched grpo.py') -\" && \ - -python -c \" -path = '/opt/nemo-rl/nemo_rl/environments/utils.py' -with open(path) as f: - content = f.read() -if 'erdos_discovery' not in content: - content = content.replace( - '\\\"nemo_gym\\\": {', - '\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {' - ) - with open(path, 'w') as f: - f.write(content) - print('Patched utils.py') -\" && \ - -cd /opt/nemo-rl -python examples/run_discover.py \ - --config examples/configs/grpo_erdos_debug_16k.yaml -" - -echo "Launching debug: Qwen2.5-1.5B, 1 node, 16k, 15 steps" - -COMMAND="$COMMAND" \ -CONTAINER="$CONTAINER" \ -MOUNTS="$MOUNTS" \ -GPUS_PER_NODE=8 \ -sbatch \ - --nodes=1 --partition=batch --exclusive \ - --job-name=erdos-debug-16k --time=02:00:00 \ - --output="$EXP/slurm-%j.out" \ - --error="$EXP/slurm-%j.err" \ - --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ - ray.sub - -echo "Logs: $EXP/" diff --git a/message (3).md b/message (3).md deleted file mode 100644 index 2e435dd893..0000000000 --- a/message (3).md +++ /dev/null @@ -1,207 +0,0 @@ -# NeMo RL GRPO on Our Cluster — What It Actually Took - -## The Ask -Run the NVIDIA NeMo RL GRPO tutorial on our B200 cluster. - -## The Cluster -- 32 nodes, 8x B200 (183GB) per node -- Slurm + Pyxis/Enroot for containers -- Shared home directory (NFS), no /scratch -- InfiniBand networking (mlx5 HCAs, bond0) - -## What Worked Immediately -- Cloning the repo, downloading models from HuggingFace -- The `ray.sub` script for orchestrating Ray clusters via Slurm (with patches) - -## What Needed Fixing - -### Cluster-Specific Patches to ray.sub -Every run needed these two fixes to `ray.sub`: -```bash -# 1. MPI plugin: cluster has pmi2, not pmix -sed -i 's/--mpi=pmix/--mpi=pmi2/' ray.sub - -# 2. Container filesystem must be writable (ray writes launch scripts to /) -sed -i '/--no-container-mount-home/a COMMON_SRUN_ARGS+=" --container-writable"' ray.sub -``` - -Also needed to remove `--no-container-mount-home` so the shared home dir -(with model conversion cache) is visible across all nodes. - -### NGC Container Authentication -Pyxis/enroot needs NGC credentials to pull containers: -```bash -# ~/.config/enroot/.credentials -machine nvcr.io login $oauthtoken password -``` -The image URI format for Pyxis is `nvcr.io#nvidia/nemo-rl:v0.5.0` (note the `#`). - -### NCCL / InfiniBand Configuration -Multi-node training was getting `Network is unreachable` NCCL errors until -we added the cluster's IB config (copied from our existing torchtitan slurm scripts): -```bash -export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 -export NCCL_SOCKET_IFNAME=bond0 -export UCX_NET_DEVICES=bond0 -export NCCL_BUFFSIZE=33554432 -export NCCL_IB_AR_THRESHOLD=0 -export NCCL_IB_PCI_RELAXED_ORDERING=1 -# ... etc -``` -These must be set in `ray.sub` (not just the COMMAND) so every node gets them. - -### HF Transfer -The v0.5.0 container sets `HF_HUB_ENABLE_HF_TRANSFER=1` but the package -isn't installed. Set `HF_HUB_ENABLE_HF_TRANSFER=0` in the launch command. - -## The Llama 8B Math Run (worked quickly) -**Container**: `nvcr.io/nvidia/nemo-rl:v0.5.0` -**Branch**: `v0.5.0` tag -**Config**: `examples/configs/grpo_math_8B_megatron.yaml` -**Script**: `examples/run_grpo_math.py` - -This used OpenMathInstruct-2 (auto-downloads), single node, colocated generation. -Worked after fixing the ray.sub patches above and enabling W&B -(`++logger.wandb_enabled=true`). The base config has `wandb_enabled: false`. - -## The Workplace Assistant Tutorial (abandoned) -The original tutorial targets Nemotron Nano 9B v2 with the Workplace Assistant -NeMo Gym environment. This was a nightmare: - -- **v0.5.0 container + v0.5.0 code**: Chat template tokenization assertion errors - (`non-monotonically increasing trajectory`). The Nemotron Nano v2 tokenizer - handles multi-turn tool-calling conversations in a way that breaks the - `_replace_prefix_tokens` function during multi-step rollouts. -- **nano-v3 branch + v0.4.0.nemotron_3_nano container**: The `nemotron_json` - tool parser wasn't registered in the container's vLLM. -- The tutorial's `sed` commands to patch the chat template are insufficient. - The real fix exists only on the `nano-v3` branch which removes the assertion entirely. - -**Lesson**: The Workplace Assistant environment is tightly coupled to specific -branch/container combos. Use any other environment instead. - -## The Nemotron 3 Super 120B Run (what finally worked) - -### Container Build -The `super-v3` branch requires a custom container build because it uses a -patched vLLM for the NemotronH MoE architecture: - -```bash -# On a compute node (docker access required): -docker buildx build \ - --build-context nemo-rl=. \ - --build-arg SKIP_SGLANG_BUILD=1 \ - --build-arg BUILD_CUSTOM_VLLM=1 \ - -f docker/Dockerfile \ - --tag nemo-rl-super:v3 --load . - -# Convert to sqsh for Pyxis: -sudo enroot import -o nemo-rl-super-v3.sqsh "dockerd://nemo-rl-super:v3" -``` - -We had to install `docker-buildx` first (not available on the cluster by default). - -### HF→Megatron Model Conversion -First run converts the HuggingFace checkpoint to Megatron format (~231GB). -This is cached at `~/.cache/huggingface/nemo_rl/model__/`. -**The home dir must be mounted in the container** for this cache to be shared -across nodes. Previous runs with `--no-container-mount-home` caused the -conversion to succeed on the head node but be invisible to training nodes. - -### Chat Template -The base model (`NVIDIA-Nemotron-3-Super-120B-A12B-Base-BF16`) has no chat -template. The `math_hf_data_processor` calls `tokenizer.apply_chat_template()`, -which crashes. We added a minimal one: - -```python -data["chat_template"] = "{% for message in messages %}..." -``` - -### Data -The internal NVIDIA data paths in the configs (`/lustre/fsw/...`) don't exist. -We downloaded DAPO-Math-17k from HuggingFace and converted it, but ultimately -used the built-in `OpenMathInstruct-2` dataset with `math_hf_data_processor` -and `env.math` (rule-based math verification, no LLM judge needed). - -The base config also sets `data.max_input_seq_length: null` which causes a -`TypeError: '>' not supported between instances of 'int' and 'NoneType'`. -Override with `++data.max_input_seq_length=4096`. - -### NeMo Gym -The base `grpo_superv3.yaml` config has `env.nemo_gym.num_gpu_nodes: 4` which -reserves 4 GPU nodes for Gym environment servers (genrm judges, etc.). With -only 6 total nodes this left negative nodes for training. Either: -- Set `++env.nemo_gym.num_gpu_nodes=0` if using simple environments -- Set `++env.should_use_nemo_gym=false` to skip Gym entirely - -The container's built-in `nemo_gym` package also had an `ImportError` -(`cannot import name 'PARENT_DIR'`) when the Gym submodule from our repo -checkout was mounted over the container's version. Don't mount the Gym dir. - -### Parallelism (the hard part) -The 120B MoE model needs careful parallelism to fit in memory and divide evenly: - -**What didn't work:** -- TP=4, CP=4, PP=2: PP=2 from base config made TP×CP×PP=32 > 16 training GPUs -- TP=4, CP=4, PP=1, EP=8, 13 nodes: `World size (40) not divisible by 32` -- TP=4, CP=1, PP=1, EP=2: OOM (105GB allocation, only 32GB free per GPU) -- TP=4, CP=4, PP=1, EP=8, 4 training nodes: Worked for generation+logprobs, - but CUDA illegal memory access during backward (cross-node NCCL before IB fix) -- TP=4, CP=1, PP=1, EP=8 (after IB fix): OOM during training backward pass - -**What worked:** -```yaml -# 6 nodes: 2 inference, 4 training (32 GPUs) -tensor_model_parallel_size: 4 # within-node -pipeline_model_parallel_size: 1 -context_parallel_size: 1 # no cross-node context parallel -expert_model_parallel_size: 8 # MoE experts sharded across all 32 GPUs -# TP×PP×CP×EP = 4×1×1×8 = 32 = world_size, DP=1 - -# Memory optimizations required: -activation_checkpointing: true -empty_unused_memory_level: 2 -optimizer_cpu_offload: true -max_total_sequence_length: 4096 # reduced from 16384 -train_micro_batch_size: 1 -logprob_batch_size: 1 -num_prompts_per_step: 16 # reduced from 128 -num_generations_per_prompt: 8 # reduced from 16 -train_global_batch_size: 128 # reduced from 2048 -``` - -### Final Working Command -```bash -cd /opt/nemo-rl && \ -export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,... && \ -export NCCL_SOCKET_IFNAME=bond0 && \ -uv run python examples/run_grpo.py \ - --config=examples/configs/grpo_superv3.yaml \ - ++env.should_use_nemo_gym=false \ - ++data.train.dataset_name=OpenMathInstruct-2 \ - ++data.default.processor=math_hf_data_processor \ - ++data.default.env_name=math \ - ++env.math.num_workers=8 \ - ++env.math.math_verify_impl=hf_math_verify \ - ++policy.model_name=/path/to/model \ - ++cluster.num_nodes=6 \ - ++policy.generation.colocated.enabled=false \ - ++policy.generation.colocated.resources.num_nodes=2 \ - ++policy.megatron_cfg.tensor_model_parallel_size=4 \ - ++policy.megatron_cfg.pipeline_model_parallel_size=1 \ - ++policy.megatron_cfg.context_parallel_size=1 \ - ++policy.megatron_cfg.expert_model_parallel_size=8 \ - ++policy.megatron_cfg.activation_checkpointing=true \ - ++policy.megatron_cfg.optimizer.optimizer_cpu_offload=true \ - ++policy.max_total_sequence_length=4096 \ - # ... etc -``` - -~120 seconds per step on 6x B200 nodes (48 GPUs). - -## TL;DR -1. Fix `ray.sub` for your cluster (MPI plugin, container writable, home mount) -2. Set NCCL IB env vars for multi-node -3. Don't use the Workplace Assistant tutorial — use math environments instead -4. For Nemotron Super: build container from `super-v3` branch, use EP=8 to shard MoE experts, offload optimizer to CPU, reduce batch sizes -5. The model conversion cache needs to be on a shared filesystem visible to all nodes \ No newline at end of file diff --git a/scripts/convert_ds_to_nemorl_format.py b/scripts/convert_ds_to_nemorl_format.py deleted file mode 100644 index 3ab3eb8dd1..0000000000 --- a/scripts/convert_ds_to_nemorl_format.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Convert lambda/hermes-agent-reasoning-traces to NeMo RL JSONL format. - -Initially generated on 2026-04-03 to convert the lambda/hermes-agent-reasoning-traces dataset to a usable format for SFT... - -The dataset uses ShareGPT/Hermes conventions: conversations are stored under -a `conversations` key with `from`/`value` turn dicts, and role names differ -from the OpenAI standard (human -> user, gpt -> assistant). - -Tool definitions are already embedded verbatim in the system message as -... XML, so use tokenizer.chat_template: NULL (passthrough) -when training — the content is already formatted for the Hermes chat template. - -Usage: - # default dataset - uv run scripts/convert_ds_to_nemorl_format.py --output_dir /path/to/output - - # custom dataset - uv run scripts/convert_ds_to_nemorl_format.py --dataset org/my-dataset --output_dir /path/to/output - - -Suggested NeMo RL YAML snippet after running: - tokenizer: - chat_template: NULL - - data: - train: - dataset_name: openai_format - data_path: /path/to/output/train.jsonl - chat_key: messages - tool_key: tools - use_preserving_dataset: true - validation: - dataset_name: openai_format - data_path: /path/to/output/val.jsonl - chat_key: messages - tool_key: tools - use_preserving_dataset: true -""" - -import argparse -import json -import os -import random - -from datasets import load_dataset - -ROLE_MAP = { - "system": "system", - "human": "user", - "gpt": "assistant", - "tool": "tool", -} - - -def convert_sample(sample: dict) -> dict: - messages = [] - for turn in sample["conversations"]: - role = ROLE_MAP.get(turn["from"]) - if role is None: - raise ValueError(f"Unexpected role '{turn['from']}' in sample id={sample.get('id')}") - messages.append({"role": role, "content": turn["value"]}) - - if messages[-1]["role"] != "assistant": - raise ValueError( - f"Last turn must be from the assistant, got '{messages[-1]['role']}' " - f"in sample id={sample.get('id')}" - ) - - result: dict = {"messages": messages} - - # The `tools` column is a JSON string; parse it into a list. - # With passthrough template this isn't used for formatting, but including - # it preserves the data for users who switch to a structured template later. - tools_raw = sample.get("tools") - if tools_raw: - try: - result["tools"] = json.loads(tools_raw) - except (json.JSONDecodeError, TypeError): - pass # drop malformed tools rather than corrupting the sample - - return result - - -def write_jsonl(path: str, samples: list[dict]) -> None: - with open(path, "w") as f: - for sample in samples: - f.write(json.dumps(sample) + "\n") - print(f"Wrote {len(samples):,} samples -> {path}") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Convert a ShareGPT/Hermes-format HuggingFace dataset to NeMo RL JSONL format." - ) - parser.add_argument( - "--dataset", - default="lambda/hermes-agent-reasoning-traces", - help="HuggingFace dataset name (default: lambda/hermes-agent-reasoning-traces).", - ) - parser.add_argument( - "--output_dir", - required=True, - help="Directory where train.jsonl (and optionally val.jsonl) will be written.", - ) - parser.add_argument( - "--val_split", - type=float, - default=0.05, - help="Fraction of data held out for validation (default: 0.05). Pass 0 to skip.", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for the train/val shuffle (default: 42).", - ) - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - print(f"Loading {args.dataset} from HuggingFace...") - ds = load_dataset(args.dataset, split="train") - print(f"Loaded {len(ds):,} samples.") - - samples = [] - n_skipped = 0 - for raw in ds: - try: - samples.append(convert_sample(raw)) - except ValueError as e: - print(f" Skipping sample: {e}") - n_skipped += 1 - - if n_skipped: - print(f"Skipped {n_skipped} malformed samples.") - - if args.val_split > 0: - random.seed(args.seed) - random.shuffle(samples) - n_val = max(1, int(len(samples) * args.val_split)) - val_samples = samples[:n_val] - train_samples = samples[n_val:] - else: - train_samples = samples - val_samples = [] - - write_jsonl(os.path.join(args.output_dir, "train.jsonl"), train_samples) - if val_samples: - write_jsonl(os.path.join(args.output_dir, "val.jsonl"), val_samples) - - print("\nDone. Add this to your sft.yaml (adjust paths as needed):") - train_path = os.path.abspath(os.path.join(args.output_dir, "train.jsonl")) - val_path = os.path.abspath(os.path.join(args.output_dir, "val.jsonl")) - print(f""" - tokenizer: - chat_template: NULL # passthrough — content is already Hermes-formatted - - data: - train: - dataset_name: openai_format - data_path: {train_path} - chat_key: messages - tool_key: tools - use_preserving_dataset: true # tools have heterogeneous argument schemas - validation: - dataset_name: openai_format - data_path: {val_path} - chat_key: messages - tool_key: tools - use_preserving_dataset: true -""") - - -if __name__ == "__main__": - main() diff --git a/shim/grpo_erdos_discover_debug.yaml b/shim/grpo_erdos_discover_debug.yaml deleted file mode 100644 index e34cd2d030..0000000000 --- a/shim/grpo_erdos_discover_debug.yaml +++ /dev/null @@ -1,78 +0,0 @@ -# TTT-Discover DEBUG config. -# Inherits from grpo_math_1B.yaml for all defaults. -# Overrides: entropic advantages, inline reward, small batch, 5 steps. -defaults: "grpo_math_1B.yaml" - -grpo: - num_prompts_per_step: 4 - num_generations_per_prompt: 8 - max_num_epochs: 1 - max_num_steps: 5 - max_rollout_turns: 1 - remove_constant_reward_groups: true - adv_estimator: - name: entropic_adaptive_beta - gamma: 0.6931471805599453 - -loss_fn: - kl_penalty_coef: 0.1 - ratio_clip: 0.2 - token_level_loss: false - -policy: - model_name: "Qwen/Qwen2.5-1.5B-Instruct" - max_total_sequence_length: 4096 - train_global_batch_size: 32 - train_micro_batch_size: 4 - - dtensor_cfg: - enabled: true - cpu_offload: true - activation_checkpointing: true - sequence_parallel: false - - lora_cfg: - enabled: true - rank: 16 - alpha: 1.0 - dropout: 0.0 - - generation: - backend: "vllm" - max_new_tokens: 2048 - temperature: 1.0 - top_p: 1.0 - stop_token_ids: null - stop_strings: null - vllm_cfg: - async_engine: false - tensor_parallel_size: 1 - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.6 - max_model_len: ${policy.max_total_sequence_length} - -optimizer: - name: adamw - lr: 1.0e-4 - -data: - shuffle: false - -env: - erdos_discovery: - resource_server_url: "inline" - num_initial_states: 8 - num_groups_per_step: 4 - sandbox_timeout: 60 - request_timeout: 120 - -checkpointing: - enabled: false - -logger: - log_dir: "logs/erdos-debug" - wandb_enabled: false - tensorboard_enabled: false - mlflow_enabled: false - swanlab_enabled: false diff --git a/shim/nemo_rl/__init__.py b/shim/nemo_rl/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shim/nemo_rl/algorithms/__init__.py b/shim/nemo_rl/algorithms/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shim/nemo_rl/algorithms/entropic_advantage_estimator.py b/shim/nemo_rl/algorithms/entropic_advantage_estimator.py deleted file mode 100644 index af78ae5637..0000000000 --- a/shim/nemo_rl/algorithms/entropic_advantage_estimator.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Entropic Adaptive-Beta Advantage Estimator for TTT-Discover. - -Implements the Leave-One-Out (LOO) entropic advantage from -"Learning to Discover at Test Time" (arXiv:2601.16175). - -Instead of standard group-relative advantages (Adv = R - mean(R)), -this estimator: - 1. Solves for β such that KL(softmax_β(R) || uniform) = γ (default γ = ln(2)) - 2. Computes LOO advantages: w_i = exp(β·r_i) / Z_{-i} - 1 - where Z_{-i} is the normalizer excluding the i-th sample. - -Properties: - - Shift-invariant, approximately scale-invariant - - Monotone in reward - - Approximately mean-zero - - Adaptive scaling via β solves the reward-scale sensitivity of standard GRPO -""" - -import math -from typing import Optional - -import torch - - -def _solve_beta( - rewards: torch.Tensor, - gamma: float = math.log(2), - max_iter: int = 50, - tol: float = 1e-6, -) -> float: - """Solve for β such that KL(softmax_β(R) || uniform) = γ via bisection. - - Args: - rewards: [K] tensor of rewards for one group. - gamma: Target KL divergence. Default ln(2) as in the paper. - max_iter: Maximum bisection iterations. - tol: Convergence tolerance on β. - - Returns: - Scalar β value. - """ - K = rewards.shape[0] - if K <= 1: - return 0.0 - - log_K = math.log(K) - r = rewards.double() - r_max = r.max() - - def kl_at_beta(b: float) -> float: - logits = b * (r - r_max) - log_Z = torch.logsumexp(logits, dim=0) - logq = logits - log_Z - q = logq.exp() - kl = (q * (logq + log_K)).sum().item() - return kl - - # Bisect: KL is monotonically increasing in |β| for non-constant rewards - # Find upper bound for β - lo, hi = 0.0, 1.0 - while kl_at_beta(hi) < gamma and hi < 1e8: - hi *= 2.0 - - # Edge case: all rewards identical → β = 0, KL = 0 for any β - if hi >= 1e8: - return 0.0 - - for _ in range(max_iter): - mid = (lo + hi) / 2.0 - kl = kl_at_beta(mid) - if abs(kl - gamma) < tol: - return mid - if kl < gamma: - lo = mid - else: - hi = mid - - return (lo + hi) / 2.0 - - -def compute_entropic_advantages( - rewards: torch.Tensor, - gamma: float = math.log(2), - eps: float = 1e-8, -) -> torch.Tensor: - """Compute LOO entropic advantages for a group of rewards. - - Args: - rewards: [K] tensor of rewards for one group. - gamma: Target KL for adaptive β. - eps: Small constant for numerical stability. - - Returns: - [K] tensor of advantages. - """ - K = rewards.shape[0] - if K <= 1: - return torch.zeros_like(rewards) - - beta = _solve_beta(rewards, gamma=gamma) - if beta == 0.0: - return torch.zeros_like(rewards) - - r = rewards.double() - r_max = r.max() - e = torch.exp(beta * (r - r_max)) - - if K == 1: - Z_loo = e - else: - # Leave-one-out normalizer: Z_{-i} = (sum(e) - e_i) / (K - 1) - Z_loo = (e.sum() - e) / (K - 1) - - w = e / (Z_loo + eps) - advantages = (w - 1.0).to(rewards.dtype) - return advantages - - -class EntropicAdaptiveBetaAdvantageEstimator: - """Advantage estimator using entropic adaptive-β LOO weighting. - - Follows the same interface as GRPOAdvantageEstimator: - compute_advantage(prompt_ids, rewards, mask, **kwargs) -> [B, S] tensor - - Config keys (under grpo.adv_estimator): - gamma: Target KL for β search. Default ln(2) ≈ 0.693. - eps: Numerical stability constant. Default 1e-8. - """ - - def __init__(self, estimator_config: dict, loss_config: dict): - self.gamma = estimator_config.get("gamma", math.log(2)) - self.eps = estimator_config.get("eps", 1e-8) - - def compute_advantage( - self, - prompt_ids: torch.Tensor, - rewards: torch.Tensor, - mask: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - """Compute per-token advantages using entropic adaptive-β LOO. - - Args: - prompt_ids: [B] or [B, S] tensor identifying which prompt each - sample belongs to (same prompt = same group). - rewards: [B] scalar rewards per sample. - mask: [B, S] response token mask (1 = generation token). - - Returns: - [B, S] advantages tensor. Each generation token gets the - sample-level advantage; non-generation tokens get 0. - """ - batch_size, seq_len = mask.shape - advantages = torch.zeros_like(mask, dtype=rewards.dtype) - - # Group by prompt (same as GRPO's per-prompt baseline) - if prompt_ids.dim() > 1: - # prompt_ids is [B, S] — use first token as group key - group_ids = prompt_ids[:, 0] - else: - group_ids = prompt_ids - - unique_prompts = group_ids.unique() - - for pid in unique_prompts: - group_mask = group_ids == pid - group_rewards = rewards[group_mask] - - group_adv = compute_entropic_advantages( - group_rewards, gamma=self.gamma, eps=self.eps - ) - - # Expand sample-level advantages to [group_size, seq_len] - # and mask to generation tokens only - group_indices = group_mask.nonzero(as_tuple=True)[0] - for i, idx in enumerate(group_indices): - advantages[idx] = group_adv[i] * mask[idx] - - return advantages diff --git a/shim/nemo_rl/algorithms/grpo.py b/shim/nemo_rl/algorithms/grpo.py deleted file mode 100644 index 58f722c653..0000000000 --- a/shim/nemo_rl/algorithms/grpo.py +++ /dev/null @@ -1,3205 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import gc -import os -import time -import warnings -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext -from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast - -import numpy as np -import ray -import torch -from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import AutoProcessor -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from nemo_rl.algorithms.advantage_estimator import ( - GDPOAdvantageEstimator, - GRPOAdvantageEstimator, - ReinforcePlusPlusAdvantageEstimator, -) -from nemo_rl.algorithms.loss import ( - ClippedPGLossConfig, - ClippedPGLossDataDict, - ClippedPGLossFn, -) -from nemo_rl.algorithms.loss.interfaces import LossFunction -from nemo_rl.algorithms.reward_functions import ( - RewardShapingConfig, - apply_reward_shaping, -) -from nemo_rl.algorithms.utils import ( - calculate_baseline_and_std_per_prompt, - get_gdpo_reward_component_keys, - log_generation_metrics_to_wandb, - print_performance_metrics, - set_seed, -) -from nemo_rl.data import DataConfig -from nemo_rl.data.collate_fn import rl_collate_fn -from nemo_rl.data.dataloader import MultipleDataloaderWrapper -from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import DatumSpec -from nemo_rl.data.llm_message_utils import ( - batched_message_log_to_flat_message, - get_keys_from_message_log, -) -from nemo_rl.data.utils import extract_necessary_env_names -from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env -from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster -from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.experience.rollouts import ( - run_async_multi_turn_rollout, - run_async_nemo_gym_rollout, - run_multi_turn_rollout, -) -from nemo_rl.models.generation.interfaces import GenerationInterface -from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration -from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration -from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface -from nemo_rl.models.policy.lm_policy import Policy -from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager -from nemo_rl.utils.logger import ( - Logger, - LoggerConfig, - print_message_log_samples, -) -from nemo_rl.utils.memory_tracker import MemoryTracker -from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import TimeoutChecker, Timer -from nemo_rl.utils.venvs import create_local_venv_on_each_node - -# =============================================================================== -# Configuration -# =============================================================================== -TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) - - -class RewardScalingConfig(TypedDict): - """Configure linear reward scaling with clamping. - - When `enabled` is True, each reward is clamped to the source interval - [source_min, source_max] and linearly mapped to the target interval - [target_min, target_max]. Refer to the scale_rewards function for the implementation. - - Defaults: - source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0 - """ - - enabled: bool - source_min: NotRequired[float] - source_max: NotRequired[float] - target_min: NotRequired[float] - target_max: NotRequired[float] - - -class AsyncGRPOConfig(TypedDict): - enabled: bool - # Maximum trajectory age in training steps for samples drawn from the - # async replay buffer. Trajectories older than this are excluded during - # sampling; buffer sizing also scales with this value. - max_trajectory_age_steps: int - # Does the weight synchronization as soon as the training is done - # without waiting for the pending generations to finish. - in_flight_weight_updates: NotRequired[bool] - # Recomputes the KV cache after the in-flight weight updates. - recompute_kv_cache_after_weight_updates: NotRequired[bool] - - -class AdvEstimatorConfig(TypedDict): - """Configuration for advantage estimator (GRPO, GDPO, Reinforce++, or Entropic).""" - - name: str # "grpo", "gdpo", "reinforce_plus_plus", or "entropic_adaptive_beta" - # GRPO specific - normalize_rewards: NotRequired[bool] - use_leave_one_out_baseline: NotRequired[bool] - # Reinforce++ specific - minus_baseline: NotRequired[bool] - # Entropic Adaptive-Beta specific (TTT-Discover, arXiv:2601.16175) - gamma: NotRequired[float] # Target KL for beta search; default ln(2) - eps: NotRequired[float] # Numerical stability; default 1e-8 - - -class GRPOConfig(TypedDict): - num_prompts_per_step: int - num_generations_per_prompt: int - max_num_epochs: int - max_num_steps: int - max_rollout_turns: int - normalize_rewards: bool - use_leave_one_out_baseline: bool - val_period: int - val_batch_size: int - val_at_start: bool - # Whether to run validation on the last training step. Setting this to True ensures the - # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). - val_at_end: bool - max_val_samples: int - skip_reference_policy_logprobs_calculation: NotRequired[bool] - seed: int - async_grpo: NotRequired[AsyncGRPOConfig] - overlong_filtering: NotRequired[bool] - # whether to enable dynamic sampling, i.e. - # whether to discard prompts whose rewards have zero standard deviation - use_dynamic_sampling: bool - # When using dynamic sampling, the maximum number of batches to generate - # before throwing an error - dynamic_sampling_max_gen_batches: NotRequired[int] - # When using dynamic sampling, generation prompt batch size will equal - # num_prompts_per_step * batch_multiplier - batch_multiplier: NotRequired[float] - reward_shaping: RewardShapingConfig - reward_scaling: RewardScalingConfig - # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation. - calculate_advantages_on_gpu: NotRequired[bool] - # Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as token_mult_prob_error metric, e.g., 1.5) - # Note that this is slightly different than Masked Importance Sampling (MIS) because this uses the absolute value of the difference between the training and generation logprobs, whereas MIS just uses the difference between the training and generation logprobs. - seq_logprob_error_threshold: float | None - # Advantage estimator configuration (grpo or reinforce_plus_plus) - adv_estimator: NotRequired[AdvEstimatorConfig] - - -class GRPOSaveState(TypedDict): - consumed_samples: int - current_step: int - current_epoch: int - total_steps: int - total_valid_tokens: int # Track total number of non-padding tokens during training - val_reward: NotRequired[ - float - ] # Optional field - may not be present during training - - -def _default_grpo_save_state() -> GRPOSaveState: - return { - "consumed_samples": 0, - "current_step": 0, - "current_epoch": 0, - "total_steps": 0, - "total_valid_tokens": 0, - "val_reward": -99999999.0, - } - - -class GRPOLoggerConfig(LoggerConfig): - num_val_samples_to_print: int # number of val samples to print to stdout - - -class MasterConfig(TypedDict): - policy: PolicyConfig - loss_fn: ClippedPGLossConfig - env: dict[str, Any] - data: DataConfig - grpo: GRPOConfig - logger: GRPOLoggerConfig - cluster: ClusterConfig - checkpointing: CheckpointingConfig - - -# =============================================================================== -# Setup & Initialization -# =============================================================================== - - -def setup( - master_config: MasterConfig, - tokenizer: TokenizerType, - dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset], - val_dataset: Optional[AllTaskProcessedDataset], - processor: Optional[AutoProcessor] = None, -) -> tuple[ - ColocatablePolicyInterface, - Optional[GenerationInterface], - tuple[RayVirtualCluster, RayVirtualCluster], - StatefulDataLoader | MultipleDataloaderWrapper, - Optional[StatefulDataLoader], - ClippedPGLossFn, - Logger, - CheckpointManager, - GRPOSaveState, - MasterConfig, -]: - """Main entry point for running GRPO algorithm. - - Returns: - tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader - """ - # Start timing the entire setup process - setup_start_time = time.perf_counter() - - # Extract individual configs for easier access - policy_config = master_config["policy"] - generation_config = master_config["policy"]["generation"] - env_configs = master_config["env"] - loss_config = master_config["loss_fn"] - grpo_config = master_config["grpo"] - data_config = master_config["data"] - logger_config = master_config["logger"] - cluster_config = master_config["cluster"] - - assert generation_config is not None, ( - "A generation config in the PolicyConfig is required for GRPO" - ) - - # Set seed for all random number generators - set_seed(grpo_config["seed"]) - - # ========================== - # Logger - # ========================== - logger = Logger(logger_config) - logger.log_hyperparams(master_config) - - # ========================== - # Checkpointing - # ========================== - checkpointer = CheckpointManager(master_config["checkpointing"]) - last_checkpoint_path = checkpointer.get_latest_checkpoint_path() - grpo_save_state: Optional[GRPOSaveState] = cast( - Optional[GRPOSaveState], checkpointer.load_training_info(last_checkpoint_path) - ) - if grpo_save_state is None: - grpo_save_state = _default_grpo_save_state() - - # ========================== - # Data - # ========================== - # num_prompts_per_step and dataloader_batch_size will be different when using multiple dataloaders - num_prompts_per_step = grpo_config["num_prompts_per_step"] - if data_config["use_multiple_dataloader"]: - dataloader_batch_size = data_config["num_prompts_per_dataloader"] - else: - dataloader_batch_size = num_prompts_per_step - - # Validate batch_multiplier - batch_multiplier = grpo_config["batch_multiplier"] - if grpo_config["use_dynamic_sampling"]: - num_prompts_per_step = int(num_prompts_per_step * batch_multiplier) - dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) - else: - assert batch_multiplier == 1, ( - "batch_multiplier>1 can only be used if use_dynamic_sampling=True" - ) - - # Validate number of prompts per step - if data_config["use_multiple_dataloader"]: - assert num_prompts_per_step % dataloader_batch_size == 0, ( - "Expected num_prompts_per_step to be a multiple of num_prompts_per_dataloader, " - f"but got {num_prompts_per_step} and {dataloader_batch_size}. " - "Please check the configuration of num_prompts_per_step and num_prompts_per_dataloader. " - "If use_dynamic_sampling is enabled and batch_multiplier is used, please also check the configuration of batch_multiplier." - ) - - # Load train dataset - def init_train_dataloader(dataset, suffix: str = ""): - dataloader = StatefulDataLoader( - dataset, - batch_size=dataloader_batch_size, - shuffle=data_config["shuffle"], - collate_fn=rl_collate_fn, - drop_last=True, - num_workers=data_config["num_workers"], - ) - if last_checkpoint_path is not None: - dataloader_state_dict = torch.load( - os.path.join(last_checkpoint_path, f"train_dataloader{suffix}.pt") - ) - dataloader.load_state_dict(dataloader_state_dict) - return dataloader - - if data_config["use_multiple_dataloader"]: - # Initialize dataloaders - dataloaders = {} - for task_name, task_dataset in dataset.items(): - dataloaders[task_name] = init_train_dataloader( - task_dataset, f"_{task_name}" - ) - print( - f" ✓ Training dataloader {task_name} loaded with {len(task_dataset)} samples", - flush=True, - ) - - train_sample_count = sum( - len(task_dataloader) for task_dataloader in dataloaders.values() - ) - - # Wrap dataloader - dataloader = MultipleDataloaderWrapper( - expected_num_prompts=num_prompts_per_step, - data_config=data_config, - dataloaders=dataloaders, - ) - else: - dataloader = init_train_dataloader(dataset) - train_sample_count = len(dataloader) - print( - f" ✓ Training dataloader loaded with {train_sample_count} samples", - flush=True, - ) - - # Load validation dataset if provided - val_dataloader: Optional[StatefulDataLoader] = None - # If validation is enabled, load the validation dataloader - if ( - grpo_config["val_period"] > 0 - or grpo_config["val_at_start"] - or grpo_config["val_at_end"] - ): - assert val_dataset is not None, ( - "Validation dataset is required if validation is enabled" - ) - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=grpo_config["val_batch_size"], - shuffle=False, - collate_fn=rl_collate_fn, - num_workers=data_config["num_workers"], - ) - print( - f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", - flush=True, - ) - - # ========================== - # Loss Function - # ========================== - loss_fn = ClippedPGLossFn(loss_config) - - # Validate force_on_policy_ratio - if loss_config.get("force_on_policy_ratio", False): - assert ( - grpo_config["num_prompts_per_step"] - * grpo_config["num_generations_per_prompt"] - == policy_config["train_global_batch_size"] - ), ( - "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt" - ) - os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1" - print(" ✓ force_on_policy_ratio enabled") - - # ========================== - # Cluster - # ========================== - print("\n▶ Setting up compute cluster...", flush=True) - colocated_inference = generation_config["colocated"]["enabled"] - - env_name_list = extract_necessary_env_names(data_config) - rm_env_enabled = "reward_model" in env_name_list - - total_nodes = cluster_config["num_nodes"] - if rm_env_enabled: - rm_resource = env_configs["reward_model"]["resources"] - rm_nodes = rm_resource["num_nodes"] - rm_gpus_per_node = rm_resource["gpus_per_node"] - else: - rm_nodes = 0 - rm_gpus_per_node = 0 - - if total_nodes == 1: - policy_nodes = total_nodes - else: - policy_nodes = total_nodes - rm_nodes - assert policy_nodes > 0, ( - "policy_nodes must be > 0, but got " - f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}" - ) - - if colocated_inference: - if total_nodes == 1: - policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node - assert policy_gpus_per_node > 0, ( - "policy.generation.colocated.resources.gpus_per_node must be > 0 " - "when cluster.num_nodes = 1, " - f"but got {policy_gpus_per_node}." - ) - else: - policy_gpus_per_node = cluster_config["gpus_per_node"] - - cluster = RayVirtualCluster( - name="grpo_policy_cluster", - bundle_ct_per_node_list=[policy_gpus_per_node] * policy_nodes, - use_gpus=True, - num_gpus_per_node=policy_gpus_per_node, - max_colocated_worker_groups=1 - if generation_config["backend"] == "megatron" - else 2, - ) - train_cluster = cluster - inference_cluster = cluster - print( - f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes", - flush=True, - ) - - else: - assert generation_config["backend"] != "megatron", ( - "Non-colocated inference is not supported for Megatron generation backends. " - "Please use vLLM backend for generation." - ) - - # train resources will be updated through overall and inference resources below - train_gpus_per_node = cluster_config["gpus_per_node"] - train_nodes = policy_nodes - - inference_resources = generation_config["colocated"]["resources"] - inference_gpus_per_node = inference_resources["gpus_per_node"] - inference_nodes = inference_resources["num_nodes"] - - # validate and configure resources - if policy_nodes == 1: - # When policy_nodes == 1, train and inference are on the same node - assert ( - inference_gpus_per_node is not None and inference_gpus_per_node > 0 - ), ( - "policy.generation.colocated.resources.gpus_per_node must be explicitly set to a value > 0 " - "when policy_nodes = 1 and inference is non-colocated, " - f"but got {inference_gpus_per_node}." - ) - assert inference_nodes is None or inference_nodes == 1, ( - "policy.generation.colocated.resources.num_nodes must be 1 or set to null " - "when policy_nodes = 1 and inference is non-colocated, " - f"but got {inference_nodes}." - ) - - inference_nodes = 1 - # If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node - reward_gpus_to_subtract = ( - rm_gpus_per_node if total_nodes == 1 and rm_env_enabled else 0 - ) - train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract - assert train_gpus_per_node > 0, ( - "No enough GPUs for training, " - f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}" - + ( - f" - rm_gpus_per_node:{rm_gpus_per_node}" - if total_nodes == 1 and rm_env_enabled - else "" - ) - ) - else: - # train, inference, and reward model are all on different nodes - assert inference_nodes > 0, ( - "policy.generation.colocated.resources.num_nodes must be > 0 " - "when cluster.num_nodes > 1 and inference is non-colocated, " - f"but got {inference_nodes}." - ) - assert ( - inference_gpus_per_node is not None - and inference_gpus_per_node == cluster_config["gpus_per_node"] - ), ( - "policy.generation.colocated.resources.gpus_per_node must be explicitly set and equal to cluster.gpus_per_node " - "when cluster.num_nodes > 1 and inference is non-colocated, " - f"but got inference_gpus_per_node={inference_gpus_per_node}, cluster.gpus_per_node={cluster_config['gpus_per_node']}." - ) - train_nodes -= inference_nodes - - # initialize train cluster - train_cluster = RayVirtualCluster( - name="grpo_train_cluster", - bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, - use_gpus=True, - num_gpus_per_node=train_gpus_per_node, - max_colocated_worker_groups=1, - ) - print( - f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node", - flush=True, - ) - - # initialize inference cluster - inference_cluster = RayVirtualCluster( - name="grpo_inference_cluster", - bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, - use_gpus=True, - num_gpus_per_node=inference_gpus_per_node, - max_colocated_worker_groups=1, - ) - print( - f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node", - flush=True, - ) - - # ========================== - # Training and Inference - # ========================== - print("\n▶ Setting up model and training...", flush=True) - - # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode - backend = generation_config["backend"] - generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM - - # Dictionary to store worker initialization timing stats for logging - worker_init_timing_metrics = {} - - weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path) - - if policy_config.get("megatron_cfg", {}).get("enabled", False): - ## NOTE: this is equal to the total number of scheduler steps - total_train_iters = min( - grpo_config["max_num_steps"], - grpo_config["max_num_epochs"] * train_sample_count, - ) - policy_config["megatron_cfg"]["train_iters"] = total_train_iters - - # Define initialization functions that will be used in all paths - def init_policy(): - """Initialize policy training workers.""" - t0 = time.perf_counter() - p = Policy( - cluster=train_cluster, - config=policy_config, - tokenizer=tokenizer, - processor=processor, - weights_path=weights_path, - optimizer_path=optimizer_path, - init_optimizer=True, - ) - return p, time.perf_counter() - t0 - - def init_vllm(): - """Initialize vLLM generation workers.""" - t0 = time.perf_counter() - pg = VllmGeneration(cluster=inference_cluster, config=generation_config) - pg.finish_generation() - return pg, time.perf_counter() - t0 - - def init_sglang(): - """Initialize SGLang generation workers.""" - t0 = time.perf_counter() - pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) - pg.finish_generation() - return pg, time.perf_counter() - t0 - - def initialize_generation_with_policy( - init_generation_fn, - generation_name: str, - init_time_key: str, - colocated_inference: bool, - worker_init_timing_metrics: dict, - ): - """Generic function to initialize a generation engine (vLLM or SGLang) along with policy. - - Args: - init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang) - generation_name: Name of the generation engine ("vLLM" or "SGLang") - init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s") - colocated_inference: Whether inference is colocated with training - worker_init_timing_metrics: Dictionary to store timing metrics - - Returns: - Tuple of (policy_generation, policy) - """ - # Determine if parallel initialization is possible (non-colocated mode) - use_parallel_init = not colocated_inference - - if use_parallel_init: - # Parallel initialization: Generation engine and Policy can initialize simultaneously - print( - " ⚡ Using parallel worker initialization (non-colocated mode)", - flush=True, - ) - - # Execute both initializations in parallel - parallel_start_time = time.perf_counter() - with ThreadPoolExecutor(max_workers=2) as executor: - generation_future = executor.submit(init_generation_fn) - policy_future = executor.submit(init_policy) - policy_generation, generation_time = generation_future.result() - policy, policy_time = policy_future.result() - parallel_wall_time = time.perf_counter() - parallel_start_time - - # Store timing metrics - worker_init_timing_metrics[init_time_key] = generation_time - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time - worker_init_timing_metrics["parallel_init_enabled"] = True - - else: - # Sequential initialization: colocated mode (GPU memory requires generation engine first) - print( - " ⚙️ Using sequential worker initialization (colocated mode)", - flush=True, - ) - - # Initialize generation engine first (clean GPU memory), then policy - policy_generation, generation_time = init_generation_fn() - worker_init_timing_metrics[init_time_key] = generation_time - - policy, policy_time = init_policy() - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_init_enabled"] = 0.0 - - return policy_generation, policy - - # Handle generation-specific setup - if backend == "megatron": - # Megatron generation: policy_generation is None, only initialize policy - policy_generation = None - print( - f" ✓ Using {backend} backend for generation with {policy_config['model_name']}", - flush=True, - ) - - policy, policy_time = init_policy() - worker_init_timing_metrics["policy_init_time_s"] = policy_time - - elif backend == "vllm": - # vLLM generation: setup config, then initialize with policy - generation_config = cast(VllmConfig, generation_config) - if generation_config["vllm_cfg"]["precision"] == "fp8": - assert loss_config["use_importance_sampling_correction"] is True, ( - "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" - ) - if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): - # FP8 KV cache requires FP8 model precision - assert generation_config["vllm_cfg"]["precision"] == "fp8", ( - f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. " - "FP8 KV cache can only be used together with FP8 model weights." - ) - # FP8 KV cache compatibility checks - assert policy_config["dtensor_cfg"]["enabled"] == False, ( - "DTensor backend is not supported with kv cache fp8 enabled." - ) - assert not _should_use_async_rollouts(master_config), ( - "Async rollouts is not supported with kv cache fp8 enabled." - ) - assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, ( - "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future." - ) - - ## make vllm hf overrides match the training policy - generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( - "hf_config_overrides", {} - ) - - policy_generation, policy = initialize_generation_with_policy( - init_generation_fn=init_vllm, - generation_name="vLLM", - init_time_key="vllm_init_time_s", - colocated_inference=colocated_inference, - worker_init_timing_metrics=worker_init_timing_metrics, - ) - - print( - f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", - flush=True, - ) - - elif backend == "sglang": - generation_config = cast(SGLangConfig, generation_config) - - # Set model_path if not already set - if "model_path" not in generation_config["sglang_cfg"]: - generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"] - - policy_generation, policy = initialize_generation_with_policy( - init_generation_fn=init_sglang, - generation_name="SGLang", - init_time_key="sglang_init_time_s", - colocated_inference=colocated_inference, - worker_init_timing_metrics=worker_init_timing_metrics, - ) - - print( - f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", - flush=True, - ) - - # Record when worker initialization completes (for calculating other setup time) - worker_init_complete_time = time.perf_counter() - setup_start_time - - # print the node IP and GPU ID of the policy workers for debugging - policy.print_node_ip_and_gpu_id() - - # if it is not colocated inference, initialize collective communication for update weights - if not colocated_inference: - t0 = time.perf_counter() - ip, port = train_cluster.get_master_address_and_port() - print(f"Using ip: {ip}, port: {port} for collective communication", flush=True) - # world includes all training workers and all inference workers - train_world_size = train_cluster.world_size() - inference_world_size = inference_nodes * inference_gpus_per_node - world_size = train_world_size + inference_world_size - # init collective - futures_train = policy.init_collective( - ip, port, world_size, train_world_size=train_world_size - ) - futures_inference = policy_generation.init_collective( - ip, port, world_size, train_world_size=train_world_size - ) # type: ignore - # wait for all futures to complete - ray.get(futures_train + futures_inference) - worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0 - - # prepare refit info - state_dict_info = policy.prepare_refit_info() - if policy_generation is not None: - policy_generation.prepare_refit_info(state_dict_info) - - # Calculate total setup time - total_setup_time = time.perf_counter() - setup_start_time - worker_init_timing_metrics["total_setup_time_s"] = total_setup_time - - # Log worker initialization timing metrics to logger - if worker_init_timing_metrics: - print("\n▶ Worker Initialization Timing:") - - vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0) - policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0) - total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0) - - if vllm_time: - print(f" vLLM init: {vllm_time:.1f}s") - - if policy_time: - print(f" Policy init: {policy_time:.1f}s") - - # Calculate "other" time (time after worker init completes) - other_time = total_setup - worker_init_complete_time - worker_init_timing_metrics["other_setup_time_s"] = other_time - print(f" Other setup: {other_time:.1f}s") - - print(f" Total setup: {total_setup:.1f}s") - - # Log all metrics to the logger for analysis - logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup") - - print("\n" + "=" * 60) - print(" " * 18 + "SETUP COMPLETE") - print(f" Total setup time: {total_setup_time:.1f}s") - print("=" * 60 + "\n", flush=True) - - return ( - policy, - policy_generation, - (train_cluster, inference_cluster), - dataloader, - val_dataloader, - loss_fn, - logger, - checkpointer, - grpo_save_state, - master_config, - ) - - -# =============================================================================== -# Core Algorithm Functions -# =============================================================================== - - -def dynamic_sampling( - repeated_batch: BatchedDataDict[DatumSpec], - std: torch.Tensor, - baseline: torch.Tensor, - dynamic_sampling_num_gen_batches: int, - master_config: MasterConfig, - timer: Timer, - batch_cache: BatchedDataDict[DatumSpec] = None, -) -> BatchedDataDict[DatumSpec]: - """Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. - - This function filters the current batch to retain only those prompts that have a non-zero standard deviation. - If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, - we store it in the batch_cache to be used in later iterations. - If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, - the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. - is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop - to continue sampling or proceed to training. - This approach is based on the dynamic sampling algorithm from the DAPO paper: - https://arxiv.org/pdf/2503.14476. - - Args: - repeated_batch (BatchedDataDict[DatumSpec]): The current batch of data containing prompts, responses, rewards, baselines, and std. - std (torch.Tensor): Tensor representing the standard deviation for each prompt group. - baseline (torch.Tensor): Baseline values for each prompt group. - dynamic_sampling_num_gen_batches (int): Number of generation batches processed at the current step. - master_config (MasterConfig): Configuration containing GRPO and policy settings. - batch_cache (BatchedDataDict[DatumSpec], optional): Cache storing previously selected prompts with non-zero std. - - Returns: - tuple: A tuple containing: - - repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts. - - is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training. - - batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations. - """ - # is_batch_complete is used to indicate if the current batch was able to generate enough prompts with non-zero std. - is_batch_complete = True - - # Required batch size for training - train_prompts_size = ( - master_config["grpo"]["num_prompts_per_step"] - * master_config["grpo"]["num_generations_per_prompt"] - ) - # Store the baseline, std and total_reward for the current unfiltered batch. - repeated_batch["baseline"] = baseline - repeated_batch["std"] = std - total_rewards = repeated_batch["total_reward"] - dynamic_sampling_metrics = {} - - # Dynamic sampling algorithm (used in DAPO algorithm) - # This block implements dynamic sampling by selecting prompt groups with non-zero std. - # If sampled prompts (with non-zero std) are fewer than num_prompts_per_step * num_generations_per_prompt, continue sampling until dynamic_sampling_max_gen_batches is reached. - if master_config["grpo"]["use_dynamic_sampling"]: - with timer.time("dynamic_sampling"): - # Get the prompt indices with non-zero std - non_zero_std_mask = std != 0.0 - - keep_prompt_indices = torch.arange( - len(non_zero_std_mask), device=std.device - )[non_zero_std_mask].tolist() - - # Only select the inputs that have non-zero std - # total_reward is already a part of repeated_batch so we don't need to add it again - filtered_repeated_batch = repeated_batch.select_indices(keep_prompt_indices) - filtered_repeated_batch["std"] = std[keep_prompt_indices] - filtered_repeated_batch["baseline"] = baseline[keep_prompt_indices] - - # Store filtered and total rewards to track them separately - filtered_rewards = filtered_repeated_batch["total_reward"] - filtered_repeated_batch["total_reward"] = total_rewards - filtered_repeated_batch["filtered_reward"] = filtered_rewards - - # Store the total_reward for the current filtered batch. - # If none of the prompts in current batch have non-zero std, filtered_repeated_batch.size will be 0. - # In this case, the current batch will be ignored and the next batch will be processed and we generate responses for it. - if filtered_repeated_batch.size > 0: - # Concatenate the previous partially filled batch with the current batch. This serves as a cache to store and collect the prompts with non-zero std. - # This is used in the next iteration when the current batch is not enough to fill the buffer. - batch_cache = ( - filtered_repeated_batch - if batch_cache is None - else BatchedDataDict.from_batches( - [batch_cache, filtered_repeated_batch] - ) - ) - filtered_repeated_batch = batch_cache - - filtered_prompts_size = filtered_repeated_batch.size - print( - f"Detected {filtered_prompts_size} prompts with non-zero std; " - f"{train_prompts_size} are required and used for training." - ) - - # If the generation samples size is smaller than a fixed threshold (train_prompts_size), keep generating by processing the next batch - if filtered_prompts_size < train_prompts_size: - dynamic_sampling_max_gen_batches = master_config["grpo"][ - "dynamic_sampling_max_gen_batches" - ] - assert dynamic_sampling_max_gen_batches > 0, ( - "When using grpo.use_dynamic_sampling, grpo.dynamic_sampling_max_gen_batches must be > 0" - ) - if dynamic_sampling_num_gen_batches <= dynamic_sampling_max_gen_batches: - print( - f"Generation sample buffer size: {filtered_prompts_size} is smaller than train_prompts_size: {train_prompts_size}. Processed {dynamic_sampling_num_gen_batches} batches so far out of {dynamic_sampling_max_gen_batches}." - ) - is_batch_complete = False - else: - raise ValueError( - f"Dynamic sampling has reached the maximum allowed number of batches ({dynamic_sampling_max_gen_batches}). Consider evaluating the complexity of your data or adjusting the num_prompts_per_step or num_generations_per_prompt parameters to enhance the diversity of the samples." - ) - else: - num_discarded_valid_samples = filtered_prompts_size - train_prompts_size - dynamic_sampling_metrics[ - "dynamic_sampling_num_discarded_valid_samples" - ] = num_discarded_valid_samples - - # Slice the batch, rewards, baselines and std to ensure batch size is train_prompts_size - filtered_repeated_batch = filtered_repeated_batch.slice( - 0, train_prompts_size - ) - - batch_to_return = ( - filtered_repeated_batch - if master_config["grpo"]["use_dynamic_sampling"] - else repeated_batch - ) - return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics - - -def scale_rewards( - repeated_batch: BatchedDataDict[DatumSpec], reward_scaling_cfg: RewardScalingConfig -) -> BatchedDataDict[DatumSpec]: - """Linearly scales rewards from a source range to a target range. - - If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` - is clamped to the configured source interval [source_min, source_max] and then - rescaled to the target interval [target_min, target_max]. - - Default configuration: - source_min = 0.0 - source_max = 1.0 - target_min = 0.0 - target_max = 1.0 - """ - if reward_scaling_cfg["enabled"]: - rewards = repeated_batch["total_reward"] - source_min = float(reward_scaling_cfg["source_min"]) - source_max = float(reward_scaling_cfg["source_max"]) - target_min = float(reward_scaling_cfg["target_min"]) - target_max = float(reward_scaling_cfg["target_max"]) - - # Detect out-of-range values - out_of_range_mask = (rewards < source_min) | (rewards > source_max) - if torch.any(out_of_range_mask): - print( - f"[reward_scaling] WARNING: {int(out_of_range_mask.sum())} rewards " - f"are outside the configured source range [{source_min}, {source_max}]. " - f"Values will be clipped before scaling." - ) - - # Clamp and scale - def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: - r = torch.clamp(reward_tensor, min=source_min, max=source_max) - return target_min + (r - source_min) / (source_max - source_min) * ( - target_max - target_min - ) - - scaled_rewards = _scale(rewards) - repeated_batch["total_reward"] = scaled_rewards - for key in get_gdpo_reward_component_keys(repeated_batch): - repeated_batch[key] = _scale(repeated_batch[key]) - - return repeated_batch - - -def _should_use_async_rollouts(master_config: MasterConfig) -> bool: - """Determine if async rollouts should be used based on the configuration. - - Returns True if vLLM backend is used with async_engine enabled. - """ - generation_config = master_config["policy"]["generation"] - if generation_config is None: - return False - - backend = generation_config.get("backend", "") - if backend != "vllm": - return False - - vllm_cfg = generation_config.get("vllm_cfg", {}) - return vllm_cfg.get("async_engine", False) - - -def _should_use_nemo_gym(master_config: MasterConfig) -> bool: - """Determine if NeMo-Gym should be used for rollouts and validation based on the configuration.""" - env_config = master_config.get("env") or dict() - should_use_nemo_gym = bool(env_config.get("should_use_nemo_gym")) - if not should_use_nemo_gym: - return should_use_nemo_gym - - # Validate the setup for training with NeMo-Gym - assert _should_use_async_rollouts(master_config), ( - "❌ Error: In order to use NeMo-Gym, you must use vllm generation backend with `async_engine: true`!" - ) - - generation_config = master_config["policy"]["generation"] - - # We piggyback off of `_should_use_async_rollouts` to guarantee the existence of these configs. - should_expose_http_server = generation_config["vllm_cfg"].get("expose_http_server") - assert should_expose_http_server, ( - "In order to use NeMo-Gym, you must expose the vllm server via `expose_http_server: true`!" - ) - - return should_use_nemo_gym - - -def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool: - env_config = master_config.get("env") or dict() - should_log_nemo_gym_responses = bool( - env_config.get("should_log_nemo_gym_responses") - ) - - return should_log_nemo_gym_responses - - -def _create_advantage_estimator(master_config: MasterConfig): - """Create and return an advantage estimator based on configuration. - - Args: - master_config: The master configuration dictionary. - - Returns: - An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus). - - Raises: - ValueError: If the advantage estimator name is not recognized. - """ - grpo_config = master_config["grpo"] - loss_config = master_config["loss_fn"] - - # Provide backward-compatible defaults when adv_estimator is not in config. - # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline - # which older configs still use. - adv_estimator_config = grpo_config.get( - "adv_estimator", - { - "name": "grpo", - "normalize_rewards": grpo_config.get("normalize_rewards", True), - "use_leave_one_out_baseline": grpo_config.get( - "use_leave_one_out_baseline", False - ), - "minus_baseline": True, - }, - ) - - adv_estimator_name = adv_estimator_config["name"] - if adv_estimator_name == "gdpo": - adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) - print(" ✓ Using GDPO advantage estimator (multi-reward)") - elif adv_estimator_name == "grpo": - adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) - print(" ✓ Using GRPO advantage estimator") - elif adv_estimator_name == "reinforce_plus_plus": - adv_estimator = ReinforcePlusPlusAdvantageEstimator( - adv_estimator_config, loss_config - ) - print(" ✓ Using Reinforce++ advantage estimator") - elif adv_estimator_name == "entropic_adaptive_beta": - from nemo_rl.algorithms.entropic_advantage_estimator import ( - EntropicAdaptiveBetaAdvantageEstimator, - ) - adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( - adv_estimator_config, loss_config - ) - print(" ✓ Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)") - else: - raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}") - - return adv_estimator - - -def _extract_prompt_only_messages(message_logs: list) -> list: - """Extract only prompt messages (user/system) from message logs. - - This is used to get prompt IDs for advantage estimation, excluding - any assistant responses. - - Args: - message_logs: List of message logs, where each log is a list of messages. - - Returns: - List of message logs containing only user and system messages. - """ - prompt_only_message_logs = [] - for message_log in message_logs: - prompt_only_log = [] - for message in message_log: - if message["role"] == "user" or message["role"] == "system": - prompt_only_log.append(message) - prompt_only_message_logs.append(prompt_only_log) - return prompt_only_message_logs - - -def refit_policy_generation( - policy: ColocatablePolicyInterface, - policy_generation: GenerationInterface, - colocated_inference: bool, - _refit_buffer_size_gb: Optional[int] = None, - timer: Optional[Timer] = None, - kv_scales: Optional[dict[str, float]] = None, -) -> None: - """Refit the policy generation interface with the latest policy weights. - - Args: - policy: The policy to provide weights to the inference engine. - policy_generation: The inference engine to refit. - _refit_buffer_size_gb: The size of the buffer to use for refitting. - If it is None, the buffer size will be computed by the remaining memory. - This parameter is primarily used for testing. - timer: Optional Timer used to time the prepare/transfer/update phase - kv_scales: Optional dictionary of KV cache scales for FP8 quantization. - """ - if colocated_inference: - policy.offload_before_refit() - policy_generation.prepare_for_generation(tags=["weights"]) - - # Create a context manager that does nothing when timer is None - timer_context = ( - timer.time("prepare_for_generation/transfer_and_update_weights") - if timer is not None - else nullcontext() - ) - with timer_context: - # update weights - update_success = False - if colocated_inference: - # get model param keys, which is grouped by size - if _refit_buffer_size_gb is not None: - buffer_size_bytes = _refit_buffer_size_gb * (1024**3) - else: - # Empirically sets ratio as 30% to maximize efficiency. - # The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension. - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") - buffer_size_bytes = int( - policy.get_free_memory_bytes() * float(memory_ratio) - ) - - if isinstance(policy_generation, SGLangGeneration): - sglang_url_to_gpu_uuids = ( - policy_generation.get_sglang_url_to_gpu_uuids() - ) - # Stream weights via HTTP - flush_success = policy_generation.invalidate_kv_cache() - if not flush_success: - print("SGLang KV cache invalidation failed before weight update. ") - futures_train = policy.stream_weights_via_http( - sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, - ) - # Wait for all workers to complete - ray.get(futures_train) - update_success = True - else: - # Original ZMQ IPC path for vLLM - futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes - ) - futures_inference = policy_generation.update_weights_via_ipc_zmq() - # wait for all futures to complete - ray.get(futures_train) - results = ray.get(futures_inference) - update_success = all(result for result in results if result is not None) - else: - # update weights through nccl - # SGLang haven't implemented non-colocated inference mode. - if isinstance(policy_generation, SGLangGeneration): - raise NotImplementedError( - "SGLang haven't implemented non-colocated inference mode. " - ) - futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) - futures_inference = policy_generation.update_weights_from_collective() - # wait for all futures to complete - ray.get(futures_train) - results = ray.get(futures_inference) - update_success = all(result for result in results if result is not None) - - # check if update is successful - if not update_success: - error_tag = "cuda-ipc" if colocated_inference else "nccl" - error_message = ( - "❌ Error: Updating weights for the generation policy failed during refit.\n" - f"This often indicates an issue with {error_tag} or " - "a problem within the generation backend (e.g., vLLM worker).\n" - ) - raise RuntimeError(error_message) - - if colocated_inference: - policy.offload_after_refit() - policy_generation.prepare_for_generation(tags=["kv_cache"]) - - -def _log_mixed_rewards_and_advantages_information( - logger: Logger, - total_steps: int, - metrics: dict[str, Any], - baseline: torch.Tensor, - advantages: torch.Tensor, -) -> None: - # The histograms that are logged are logged with a prefix "train/" to the name, since that is what the remaining metrics will be logged with. - logger.log_histogram( - baseline.numpy(), total_steps + 1, "train/baseline_reward/histogram" - ) - metrics["baseline_reward/pct_0"] = 100 * (baseline == 0).float().mean().item() - metrics["baseline_reward/pct_1"] = 100 * (baseline == 1).float().mean().item() - metrics["baseline_reward/pct_mixed"] = ( - 100 - metrics["baseline_reward/pct_0"] - metrics["baseline_reward/pct_1"] - ) - - logger.log_histogram( - advantages.numpy(), total_steps + 1, "train/advantages/histogram" - ) - metrics["advantages/sum"] = advantages.float().sum().item() - metrics["advantages/mean"] = advantages.float().mean().item() - - -def compute_and_apply_seq_logprob_error_masking( - train_data: BatchedDataDict, - rewards: torch.Tensor, - seq_logprob_error_threshold: Optional[float], -) -> tuple[float, int, float]: - """Compute sequence-level logprob error metrics and optionally mask high-error sequences. - - This function computes the multiplicative probability error per sequence - (same calculation as token_mult_prob_error but aggregated per-sequence) and - optionally masks sequences that exceed the configured threshold. - - Args: - train_data: Training data dict containing token_mask, sample_mask, - prev_logprobs, and generation_logprobs. If masking is applied, - sample_mask will be updated in-place. - rewards: Reward tensor for computing statistics on masked sequences. - seq_logprob_error_threshold: If set, mask sequences with mult_prob_error - exceeding this threshold. If None, only compute metrics. - - Returns: - Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct) - """ - # Compute sequence-level logprob error metrics (always) - token_mask = train_data["token_mask"][:, 1:] - sample_mask = train_data["sample_mask"] - prev_logprobs = train_data["prev_logprobs"][:, 1:] - generation_logprobs = train_data["generation_logprobs"][:, 1:] - lp_error = torch.abs(generation_logprobs - prev_logprobs) - - # Use combined mask exactly as in loss function - mask = token_mask * sample_mask.unsqueeze(-1) - - # Calculate sequence-level multiplicative prob error - # EXACT same calculation as token_mult_prob_error but per-sequence - seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum( - dim=-1 - ).clamp(min=1) - max_seq_mult_prob_error = ( - seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0 - ) - - # Apply sequence-level masking if configured - num_masked_seqs = 0 - masked_correct_pct = 0.0 - - if seq_logprob_error_threshold is not None: - print( - f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...", - flush=True, - ) - - original_sample_mask = sample_mask.clone() - - # Create mask for sequences below threshold - seq_error_mask = ( - seq_mult_prob_error <= seq_logprob_error_threshold - ).float() * original_sample_mask - - diff_mask = original_sample_mask - seq_error_mask - num_masked_seqs = int(diff_mask.sum().item()) - - if num_masked_seqs > 0: - diff_mask_bool = diff_mask.bool() - masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item() - masked_correct_pct = masked_correct_count / num_masked_seqs - - # Update sample_mask in train_data - train_data["sample_mask"] = seq_error_mask - - print( - f" Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}", - flush=True, - ) - if num_masked_seqs > 0: - print( - f" • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)" - f" → {masked_correct_pct:.2%}", - flush=True, - ) - - return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct - - -# =============================================================================== -# Training & Validation -# =============================================================================== - - -def grpo_train( - policy: ColocatablePolicyInterface, - policy_generation: Optional[GenerationInterface], - wrapped_dataloader: StatefulDataLoader | MultipleDataloaderWrapper, - val_dataloader: Optional[StatefulDataLoader], - tokenizer: TokenizerType, - loss_fn: LossFunction, - task_to_env: dict[str, EnvironmentInterface], - val_task_to_env: Optional[dict[str, EnvironmentInterface]], - logger: Logger, - checkpointer: CheckpointManager, - grpo_save_state: GRPOSaveState, - master_config: MasterConfig, -) -> None: - """Run GRPO training algorithm.""" - timer = Timer() - timeout = TimeoutChecker( - timeout=master_config["checkpointing"]["checkpoint_must_save_by"], - fit_last_save_time=True, - ) - timeout.start_iterations() - memory_tracker = MemoryTracker() - - kv_scales_cache = None # Cache reused for computed kv scales - - NEED_REFIT = True - # If policy_generation is None, use the policy as the generation interface (megatron framework backend) - if policy_generation is None: - policy_generation = policy # type: ignore - NEED_REFIT = False - POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running - assert policy_generation is not None # for mypy type check - - if master_config["grpo"].get("skip_reference_policy_logprobs_calculation"): - assert master_config["loss_fn"]["reference_policy_kl_penalty"] == 0 - print( - "Reference policy logprob calculation will be skipped since `grpo.skip_reference_policy_logprobs_calculation` is set to True and `loss_fn.reference_policy_kl_penalty` is 0." - ) - - # Check if we need to sync KV cache scales - # When fallback to policy as the policy_generation, we use getattr to check. - sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False) - - # common config/state times - current_step = grpo_save_state["current_step"] # current step within an epoch - total_steps = grpo_save_state["total_steps"] # total steps across all epochs - max_num_steps = master_config["grpo"][ - "max_num_steps" - ] # max number of steps to train for - current_epoch = grpo_save_state["current_epoch"] # current epoch - max_num_epochs = master_config["grpo"][ - "max_num_epochs" - ] # max number of epochs to train for - consumed_samples = grpo_save_state[ - "consumed_samples" - ] # total samples consumed across all epochs - total_valid_tokens = grpo_save_state.get( - "total_valid_tokens", 0 - ) # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints - val_at_start = master_config["grpo"]["val_at_start"] - val_at_end = master_config["grpo"]["val_at_end"] - val_period = master_config["grpo"]["val_period"] - colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - - # Initialize advantage estimator - adv_estimator = _create_advantage_estimator(master_config) - - # Run validation at the start if configured - # TODO: Add validation with kv scales if needed - if val_at_start and current_step == 0: - print("\n🔍 Running initial validation...", flush=True) - memory_tracker.snapshot_start_of_stage("Initial validation", dir()) - - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation(policy, policy_generation, colocated_inference) - POLICY_GENERATION_STALE = False - else: - policy_generation.prepare_for_generation() - val_metrics, validation_timings = validate( - policy_generation, - val_dataloader, - tokenizer, - val_task_to_env, - step=0, - master_config=master_config, - logger=logger, - ) - policy_generation.finish_generation() - logger.log_metrics(val_metrics, current_step, prefix="validation") - logger.log_metrics(validation_timings, current_step, prefix="timing/validation") - - if master_config["data"]["use_multiple_dataloader"]: - warnings.warn( - "When using multiple dataloaders, MultipleDataloaderWrapper operates as an infinite iterator. " - "As a result, grpo.max_num_epochs will be ignored, and only grpo.max_num_steps will be used. " - "See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details." - ) - - while current_epoch < max_num_epochs and total_steps < max_num_steps: - memory_tracker.snapshot_start_of_stage("Preparing batch", dir()) - print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") - # batch cache is used for DAPO. We store prompts with non-zero standard deviation in this cache. - batch_cache: BatchedDataDict[DatumSpec] = None - # This is the number of batches we processed so far at each step to generate responses whose std is non-zero. Maximum threshold is set by dynamic_sampling_max_gen_batches. Used in the case of dynamic sampling. - dynamic_sampling_num_gen_batches = 0 - - # Run grpo/dapo training loop (single-turn) - for batch in wrapped_dataloader: - # A central place to store logging data that won't be deleted until the loop ends - metrics_logging_data = dict() - metrics = dict() - - if master_config["data"]["use_multiple_dataloader"]: - print( - f"\n{'=' * 25} Step {current_step + 1}/{max_num_steps} {'=' * 25}", - flush=True, - ) - else: - print( - f"\n{'=' * 25} Step {current_step + 1}/{min(len(wrapped_dataloader), max_num_steps)} {'=' * 25}", - flush=True, - ) - - maybe_gpu_profile_step(policy, total_steps + 1) - if policy != policy_generation: - maybe_gpu_profile_step(policy_generation, total_steps + 1) - val_metrics, validation_timings = None, None - - with timer.time("total_step_time"): - # Prepare batch - print("▶ Preparing batch...", flush=True) - with timer.time("data_processing"): - # Repeat batch items - repeated_batch: BatchedDataDict[DatumSpec] = ( - batch.repeat_interleave( - master_config["grpo"]["num_generations_per_prompt"] - ) - ) - # Convert LLMMessageLogType to FlatMessagesType for generation - batched_flat, input_lengths = batched_message_log_to_flat_message( - repeated_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - input_ids = batched_flat["token_ids"] - - # Generate responses - this updates the LLMMessageLogType in repeated_batch - memory_tracker.snapshot_start_of_stage("Generation", dir()) - print( - f"▶ Generating responses for batch of size {repeated_batch.size}...", - flush=True, - ) - with timer.time("prepare_for_generation/total"): - if NEED_REFIT and POLICY_GENERATION_STALE: - # Compute KV scales if needed for FP8 quantization - if sync_kv_scales and kv_scales_cache is None: - print("▶ Computing KV cache scales...", flush=True) - policy.prepare_for_lp_inference() - # Align with training data processing to ensure parallel training compatibility - calib_flat, calib_input_lengths = ( - batched_message_log_to_flat_message( - repeated_batch["message_log"], - pad_value_dict={ - "token_ids": tokenizer.pad_token_id - }, - make_sequence_length_divisible_by=master_config[ - "policy" - ]["make_sequence_length_divisible_by"], - ) - ) - # Create calibration data from flattened messages - calibration_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": calib_flat["token_ids"], - "input_lengths": calib_input_lengths, - } - ) - calibration_data.update( - calib_flat.get_multimodal_dict(as_tensors=False) - ) - calibration_data.to("cpu") - kv_scales_cache = policy.calibrate_qkv_fp8_scales( - calibration_data, include_q=True - )["layers"] - - refit_policy_generation( - policy, - policy_generation, - colocated_inference, - timer=timer, - kv_scales=kv_scales_cache if sync_kv_scales else None, - ) - POLICY_GENERATION_STALE = False - else: - if colocated_inference: - policy.offload_after_refit() # unload optimizer to make space for generation - policy_generation.prepare_for_generation() - - dynamic_sampling_num_gen_batches += 1 - if dynamic_sampling_num_gen_batches == 1 and hasattr( - policy_generation, "snapshot_step_metrics" - ): - policy_generation.snapshot_step_metrics() - with timer.time("generation"): - # Clear logger metrics for each generation step - if policy_generation is not None: - policy_generation.clear_logger_metrics() - # Use NeMo-Gym rollouts if enabled. We cascade NeMo-Gym first since NeMo-Gym requires async rollouts. - if _should_use_nemo_gym(master_config): - generation_config = master_config["policy"]["generation"] - nemo_gym_rollout_result = run_async_nemo_gym_rollout( - policy_generation=policy_generation, - input_batch=repeated_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=None, - generation_config=generation_config, - max_rollout_turns=None, - greedy=False, - ) - input_ids = nemo_gym_rollout_result.input_ids - repeated_batch = nemo_gym_rollout_result.final_batch - rollout_metrics = nemo_gym_rollout_result.rollout_metrics - del nemo_gym_rollout_result - - # NeMo Gym responses can be very large and expensive to log. Here we have logic to opt-in to logging. - if not _should_log_nemo_gym_responses(master_config): - for key in list(rollout_metrics): - if "full_result" in key: - rollout_metrics.pop(key) - - # Use async rollouts if vLLM async engine is enabled - elif _should_use_async_rollouts(master_config): - ( - repeated_batch, - rollout_metrics, - ) = run_async_multi_turn_rollout( - policy_generation=policy_generation, - input_batch=repeated_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=master_config["policy"][ - "max_total_sequence_length" - ], - max_rollout_turns=master_config["grpo"][ - "max_rollout_turns" - ], - greedy=False, - ) - else: - repeated_batch, rollout_metrics = run_multi_turn_rollout( - policy_generation=policy_generation, - input_batch=repeated_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=master_config["policy"][ - "max_total_sequence_length" - ], - max_rollout_turns=master_config["grpo"][ - "max_rollout_turns" - ], - greedy=False, - ) - policy_generation.finish_generation() - # Collect generation logger metrics for performance reporting after each generation step - # inflight batch sizes and num pending samples are collected from each worker - if policy_generation is not None: - generation_logger_metrics = ( - policy_generation.get_logger_metrics() - ) - - metrics_logging_data["mean_gen_tokens_per_sample"] = ( - rollout_metrics["mean_gen_tokens_per_sample"] - ) - logger.log_metrics(rollout_metrics, total_steps + 1, prefix="train") - - repeated_batch = scale_rewards( - repeated_batch, master_config["grpo"]["reward_scaling"] - ) - # Process rewards with custom reward function - if master_config["grpo"]["reward_shaping"]["enabled"]: - repeated_batch = apply_reward_shaping( - repeated_batch, master_config["grpo"]["reward_shaping"] - ) - - # Calculate rewards & advantages - memory_tracker.snapshot_start_of_stage("Processing rewards", dir()) - print("▶ Processing rewards...,", flush=True) - with timer.time("reward_calculation"): - # Extract rewards from final_batch - rewards = repeated_batch["total_reward"] - - print("▶ Computing advantages...", flush=True) - if master_config["grpo"].get("calculate_advantages_on_gpu"): - print("Computing advantages on GPU!") - # Just fix the device id for now - device_id = 0 - baseline, std = calculate_baseline_and_std_per_prompt( - input_ids.cuda(device_id), - rewards.cuda(device_id), - torch.ones_like(rewards).cuda(device_id), - leave_one_out_baseline=master_config["grpo"][ - "use_leave_one_out_baseline" - ], - ) - baseline = baseline.cpu() - std = std.cpu() - else: - baseline, std = calculate_baseline_and_std_per_prompt( - input_ids, - rewards, - torch.ones_like(rewards), - leave_one_out_baseline=master_config["grpo"][ - "use_leave_one_out_baseline" - ], - ) - - # Apply dynamic sampling to filter prompts with non-zero std (DAPO algorithm) - repeated_batch, is_batch_complete, batch_cache, ds_metrics = ( - dynamic_sampling( - repeated_batch, - std, - baseline, - dynamic_sampling_num_gen_batches, - master_config, - timer, - batch_cache, - ) - ) - if ds_metrics: - ds_metrics["dynamic_sampling_num_gen_batches"] = ( - dynamic_sampling_num_gen_batches - ) - # Get the updated rewards and baselines. For DAPO, these rewards and baselines only correspond to the prompts with non-zero std. - rewards = ( - repeated_batch["total_reward"] - if not master_config["grpo"]["use_dynamic_sampling"] - else repeated_batch["filtered_reward"] - ) - baseline = repeated_batch["baseline"] - std = repeated_batch["std"] - - # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch. - if not is_batch_complete: - continue - - gen_step_metrics = {} - if hasattr(policy_generation, "get_step_metrics"): - gen_step_metrics = policy_generation.get_step_metrics() - - # Save baseline for logging (before deletion) - baseline_for_log = baseline.clone() - - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] - ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs - del prompt_batched_flat - del input_ids - del baseline - del std - - with timer.time("data_processing"): - use_overlong_filtering = master_config["grpo"]["overlong_filtering"] - if use_overlong_filtering: - loss_multiplier = repeated_batch["loss_multiplier"].clone() - truncated = repeated_batch["truncated"] - - if isinstance(truncated, list): - truncated = torch.tensor(truncated, dtype=torch.bool) - - loss_multiplier[truncated] = 0 - repeated_batch["loss_multiplier"] = loss_multiplier - # Add loss mask to each message in LLMMessageLogType - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) - - # Convert updated LLMMessageLogType to FlatMessagesType for training - flat_messages, input_lengths = batched_message_log_to_flat_message( - repeated_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) - - # Create training data from flattened messages - # Note: advantages will be computed and added after logprobs are available - train_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": flat_messages["token_ids"], - "input_lengths": input_lengths, - "generation_logprobs": flat_messages["generation_logprobs"], - "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], - } - ) - # this will be mini-batched inside the policy, so maintain the packed multimodal structure - # This is also used to populate part of the downstream logprob calculation data - extra_multimodal_data = flat_messages.get_multimodal_dict( - as_tensors=False - ) - train_data.update(extra_multimodal_data) - train_data.to("cpu") - - metrics_logging_data["content"] = flat_messages["content"] - - memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) - print("▶ Preparing for logprob inference...", flush=True) - with timer.time("logprob_inference_prep"): - policy.prepare_for_lp_inference() - - print("▶ Computing logprobs...", flush=True) - with timer.time("policy_and_reference_logprobs"): - # Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers. - logprob_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": train_data["input_ids"], - "input_lengths": train_data["input_lengths"], - "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], - **extra_multimodal_data, - } - ) - train_data["prev_logprobs"] = policy.get_logprobs( - logprob_data, timer=timer - )["logprobs"] - - if not master_config["grpo"].get( - "skip_reference_policy_logprobs_calculation" - ): - train_data["reference_policy_logprobs"] = ( - policy.get_reference_policy_logprobs( - logprob_data, - timer=timer, - )["reference_logprobs"] - ) - - del logprob_data - del extra_multimodal_data - - ( - max_seq_mult_prob_error, - num_masked_seqs, - masked_correct_pct, - ) = compute_and_apply_seq_logprob_error_masking( - train_data=train_data, - rewards=rewards, - seq_logprob_error_threshold=master_config["grpo"][ - "seq_logprob_error_threshold" - ], - ) - - # Compute advantages with adv_estimator using correct mask and logprobs - with timer.time("advantage_calculation"): - print("▶ Computing advantages...", flush=True) - # Get token-level mask: token_mask * sample_mask - token_mask = train_data["token_mask"] - sample_mask = train_data["sample_mask"] - mask = token_mask * sample_mask.unsqueeze(-1) - - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - repeated_batch=repeated_batch, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - del prompt_ids_for_adv - - # Log rewards and advantages information - _log_mixed_rewards_and_advantages_information( - logger=logger, - total_steps=total_steps, - metrics=metrics, - baseline=baseline_for_log, - advantages=train_data["advantages"], - ) - del baseline_for_log - - memory_tracker.snapshot_start_of_stage("Policy train", dir()) - print("▶ Preparing for training...", flush=True) - with timer.time("training_prep"): - policy.prepare_for_training() # set model train and reload optim to GPU - POLICY_GENERATION_STALE = True - - print("▶ Training policy...", flush=True) - with timer.time("policy_training"): - train_results = policy.train( - train_data, - loss_fn, - timer=timer, - ) - - # Recompute KV scales after policy training if needed - if sync_kv_scales: - with timer.time("recompute_kv_scales"): - print( - "▶ Recomputing KV cache scales after policy update...", - flush=True, - ) - kv_scales_cache = policy.calibrate_qkv_fp8_scales( - train_data, include_q=True - )["layers"] - # Set generation as stale to force refit with new scales - POLICY_GENERATION_STALE = True - - is_last_step = total_steps + 1 >= max_num_steps - if not master_config["data"]["use_multiple_dataloader"]: - is_last_step = is_last_step or ( - (current_epoch + 1 == max_num_epochs) - and (current_step + 1 == len(wrapped_dataloader)) - ) - - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): - memory_tracker.snapshot_start_of_stage("Validation", dir()) - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation( - policy, - policy_generation, - colocated_inference, - kv_scales=kv_scales_cache if sync_kv_scales else None, - ) - POLICY_GENERATION_STALE = False - else: - if colocated_inference: - policy.offload_after_refit() # unload optimizer to make space for generation - policy_generation.prepare_for_generation() - val_metrics, validation_timings = validate( - policy_generation, - val_dataloader, - tokenizer, - val_task_to_env, - step=total_steps + 1, - master_config=master_config, - logger=logger, - ) - policy_generation.finish_generation() - logger.log_metrics( - validation_timings, total_steps + 1, prefix="timing/validation" - ) - logger.log_metrics( - val_metrics, total_steps + 1, prefix="validation" - ) - - # Get flat advantages and token mask for masked metrics computation - flat_advantages = train_data["advantages"] - flat_token_mask = flat_messages["token_loss_mask"] - - # Filter advantages using token mask (only valid response tokens) - response_advantages = torch.masked_select( - flat_advantages, flat_token_mask.bool() - ) - - memory_tracker.snapshot_start_of_stage("Metrics", dir()) - metrics = { - **metrics, - "loss": train_results["loss"].numpy(), - "grad_norm": train_results["grad_norm"].numpy(), - "reward": rewards.numpy(), - "mean_prompt_length": repeated_batch["length"].numpy(), - "total_num_tokens": input_lengths.numpy(), - # Add masked advantages tracking metrics (only for valid response tokens) - "advantages/mean": torch.mean(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - "advantages/max": torch.max(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - "advantages/min": torch.min(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - **ds_metrics, - } - if "moe_metrics" in train_results: - metrics.update( - {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()} - ) - if master_config["grpo"]["use_dynamic_sampling"]: - metrics["filtered_reward"] = rewards.numpy() - metrics["reward"] = repeated_batch["total_reward"].numpy() - - metrics.update(train_results["all_mb_metrics"]) - metrics.update(gen_step_metrics) - for k, v in metrics.items(): - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - valid_values = [x for x in v if not np.isinf(x)] - metrics[k] = ( - np.min(valid_values).item() if valid_values else -1.0 - ) - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - valid_values = [x for x in v if not np.isinf(x)] - metrics[k] = ( - np.max(valid_values).item() if valid_values else -1.0 - ) - elif k in { - "lr", - "wd", - "reward", - "filtered_reward", - "global_valid_seqs", - "global_valid_toks", - "mean_prompt_length", - }: - metrics[k] = np.mean(v).item() - elif isinstance(v, (np.ndarray, list)): - metrics[k] = np.sum(v).item() - else: - print(f"Skipping aggregation for {k} ({type(v)})") - - metrics.update(rollout_metrics) - metrics["generation_logger_metrics"] = generation_logger_metrics - total_valid_tokens += metrics["global_valid_toks"] - - # Always log sequence-level error metrics (useful for deciding threshold) - metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error - metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs - metrics["masked_correct_pct"] = masked_correct_pct - - ## Checkpointing - consumed_samples += master_config["grpo"]["num_prompts_per_step"] - timeout.mark_iteration() - - should_save_by_step = ( - is_last_step - or (total_steps + 1) % master_config["checkpointing"]["save_period"] - == 0 - ) - # +1 because step is 0-indexed - # Check if timeout-based checkpointing is enabled in config. - should_save_by_timeout = timeout.check_save() - - memory_tracker.snapshot_start_of_stage("Checkpointing", dir()) - if master_config["checkpointing"]["enabled"] and ( - should_save_by_step or should_save_by_timeout - ): - policy.prepare_for_training() - - # +1 because step is 0-indexed - grpo_save_state["current_step"] = current_step + 1 - grpo_save_state["total_steps"] = total_steps + 1 - grpo_save_state["current_epoch"] = current_epoch - grpo_save_state["total_valid_tokens"] = total_valid_tokens - if val_metrics is not None: - grpo_save_state["val_reward"] = val_metrics["accuracy"] - elif "val_reward" in grpo_save_state: - del grpo_save_state["val_reward"] - grpo_save_state["consumed_samples"] = consumed_samples - - full_metric_name = master_config["checkpointing"]["metric_name"] - if full_metric_name is not None: - assert full_metric_name.startswith( - "train:" - ) or full_metric_name.startswith("val:"), ( - f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" - f'followed by the corresponding name in the "val" or "train" metrics dictionary.' - f" If you are using an old config, please updated checkpointing.metric_name to the new format, " - f" e.g. 'val_reward --> 'val:reward'" - ) - prefix, metric_name = full_metric_name.split(":", 1) - metrics_source = metrics if prefix == "train" else val_metrics - if not metrics_source: - warnings.warn( - f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " - "This checkpoint will not be saved as top-k.", - stacklevel=2, - ) - if full_metric_name in grpo_save_state: - del grpo_save_state[full_metric_name] - elif metric_name not in metrics_source: - raise ValueError( - f"Metric {metric_name} not found in {prefix} metrics" - ) - else: - grpo_save_state[full_metric_name] = metrics_source[ - metric_name - ] - - with timer.time("checkpointing"): - print( - f"Saving checkpoint for step {total_steps + 1}...", - flush=True, - ) - checkpoint_path = checkpointer.init_tmp_checkpoint( - total_steps + 1, grpo_save_state, master_config - ) - policy.save_checkpoint( - weights_path=os.path.join( - checkpoint_path, "policy", "weights" - ), - optimizer_path=os.path.join( - checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, - tokenizer_path=os.path.join( - checkpoint_path, "policy", "tokenizer" - ), - checkpointing_cfg=master_config["checkpointing"], - ) - if master_config["data"]["use_multiple_dataloader"]: - for ( - task_name, - task_dataloader, - ) in wrapped_dataloader.dataloaders.items(): - torch.save( - task_dataloader.state_dict(), - os.path.join( - checkpoint_path, - f"train_dataloader_{task_name}.pt", - ), - ) - else: - torch.save( - wrapped_dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), - ) - checkpointer.finalize_checkpoint(checkpoint_path) - - # Logging - # Log training data - memory_tracker.snapshot_start_of_stage("Logging", dir()) - if not _should_log_nemo_gym_responses(master_config): - log_data = {} - if "agent_ref" in repeated_batch: - log_data["agent_ref"] = repeated_batch["agent_ref"] - log_data["content"] = flat_messages["content"] - log_data["rewards"] = rewards.tolist() - if master_config["grpo"]["use_dynamic_sampling"]: - log_data["filtered_rewards"] = rewards.tolist() - log_data["rewards"] = repeated_batch["total_reward"].tolist() - log_data["input_lengths"] = input_lengths.tolist() - log_data["token_ids"] = train_data["input_ids"].tolist() - log_data["token_loss_mask"] = train_data["token_mask"].tolist() - log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() - log_data["advantages"] = train_data["advantages"].tolist() - log_data["generation_logprobs"] = train_data[ - "generation_logprobs" - ].tolist() - log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() - - logger.log_batched_dict_as_jsonl( - log_data, f"train_data_step{total_steps + 1}.jsonl" - ) - del log_data - del flat_messages - - timing_metrics: dict[str, float] = timer.get_timing_metrics( - reduction_op="sum" - ) # type: ignore - # track example with high token mult prob error above 1.05 - if metrics["token_mult_prob_error"] > 1.05: - logger.log_plot_token_mult_prob_error( - { - "prompt_lengths": repeated_batch["length"], - "full_lengths": input_lengths, - "generation_logprobs": train_data["generation_logprobs"], - "prev_logprobs": train_data["prev_logprobs"], - "token_mask": train_data["token_mask"], - "sample_mask": train_data["sample_mask"], - }, - total_steps + 1, - name="train/token_mult_prob_error_plot_sample", - ) - del train_data - if master_config["policy"]["generation"].get("vllm_cfg", {}).get( - "enable_vllm_metrics_logger", False - ) and master_config.get("logger", {}).get("wandb_enabled", False): - log_generation_metrics_to_wandb( - generation_logger_metrics, - total_steps + 1, - master_config["policy"]["generation"]["vllm_cfg"][ - "vllm_metrics_logger_interval" - ], - logger, - ) - - # Plot ISL/OSL/ISL+OSL histograms to wandb - if ( - master_config["policy"]["generation"] - .get("vllm_cfg", {}) - .get("async_engine", False) - ): - for metric_name in metrics.keys(): - if metric_name.startswith("histogram/"): - logger.log_histogram( - metrics[metric_name], - total_steps + 1, - f"generation_metrics/{metric_name}", - ) - - print("\n📊 Training Results:") - - print(f" • Loss: {metrics['loss']:.4f}") - print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") - if master_config["grpo"]["use_dynamic_sampling"]: - print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}") - print( - f" • Avg Total Reward: {np.mean(repeated_batch['total_reward'].numpy()):.4f}" - ) - else: - print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") - print( - f" • Mean Generation Length: {metrics_logging_data['mean_gen_tokens_per_sample']:.4f}", - flush=True, - ) - - print("\n⏱️ Timing:", flush=True) - # Display total time first, separately - total_time = timing_metrics.get("total_step_time", 0) - - number_of_samples_per_step = ( - master_config["grpo"]["num_prompts_per_step"] - * master_config["grpo"]["num_generations_per_prompt"] - ) - total_num_gpus = ( - master_config["cluster"]["num_nodes"] - * master_config["cluster"]["gpus_per_node"] - ) - - print(f" • Total step time: {total_time:.2f}s", flush=True) - - # Display all other timing metrics - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) - - timing_metrics["valid_tokens_per_sec_per_gpu"] = ( - metrics["global_valid_toks"] / total_time / total_num_gpus - ) - performance_metrics = print_performance_metrics( - train_results, metrics, timing_metrics, master_config - ) - - logger.log_metrics(metrics, total_steps + 1, prefix="train") - logger.log_metrics( - performance_metrics, total_steps + 1, prefix="performance" - ) - # step_finished=True here since this is the final log of our current step. - logger.log_metrics( - timing_metrics, - total_steps + 1, - prefix="timing/train", - step_finished=True, - ) - - # Reset the batch and set dynamic_sampling_num_gen_batches to 0 - batch_cache = None - dynamic_sampling_num_gen_batches = 0 - - # Clear mem - memory_tracker.snapshot_start_of_stage("After CPU memory clear", dir()) - - # processing rewards - del repeated_batch - del rewards - # train_data already deleted after logging above - # logging - del metrics - if "val_metrics" in dir(): - del val_metrics - - timer.reset() - current_step += 1 - total_steps += 1 - if should_save_by_timeout: - memory_tracker.snapshot_start_of_stage("", dir()) - print("Timeout has been reached, stopping training early", flush=True) - return - if total_steps >= max_num_steps: - memory_tracker.snapshot_start_of_stage("", dir()) - print( - "Max number of steps has been reached, stopping training early", - flush=True, - ) - return - - current_epoch += 1 - current_step = 0 # Reset step counter for new epoch - - -def validate( - policy_generation: GenerationInterface, - val_dataloader: Optional[StatefulDataLoader], - tokenizer, - val_task_to_env: Optional[dict[str, EnvironmentInterface]], - step: int, - master_config: MasterConfig, - logger: Optional[Logger] = None, -) -> tuple[dict[str, Any], dict[str, Any]]: - """Run validation on the validation dataset.""" - if val_dataloader is None: - assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, ( - "val_dataloader is None, so dpo.val_period must be 0" - ) - print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) - return {}, {} - - timer = Timer() - with timer.time("total_validation_time"): - print(f"▶ Starting validation at step {step}...", flush=True) - - total_rewards = [] - total_lengths = [] - all_message_logs = [] # Collect all message logs - - max_batches = ( - master_config["grpo"]["max_val_samples"] - // master_config["grpo"]["val_batch_size"] - ) - for batch_idx, val_batch in enumerate(val_dataloader): - if batch_idx >= max_batches: - break - - additional_metrics_to_report = dict() - # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) - # Use async rollouts if vLLM async engine is enabled - # We cascade NeMo-Gym first since NeMo-Gym also uses async rollouts. - if _should_use_nemo_gym(master_config): - generation_config = master_config["policy"]["generation"] - nemo_gym_rollout_result = run_async_nemo_gym_rollout( - policy_generation=policy_generation, - input_batch=val_batch, - tokenizer=tokenizer, - task_to_env=val_task_to_env, - max_seq_len=None, - generation_config=generation_config, - max_rollout_turns=None, - greedy=False, - ) - val_batch = nemo_gym_rollout_result.final_batch - gen_metrics = nemo_gym_rollout_result.rollout_metrics - additional_metrics_to_report = gen_metrics - elif _should_use_async_rollouts(master_config): - val_batch, gen_metrics = run_async_multi_turn_rollout( - policy_generation, - val_batch, - tokenizer, - val_task_to_env, - max_seq_len=master_config["policy"]["max_total_sequence_length"], - max_rollout_turns=master_config["grpo"]["max_rollout_turns"], - greedy=False, - ) - else: - val_batch, gen_metrics = run_multi_turn_rollout( - policy_generation, - val_batch, - tokenizer, - val_task_to_env, - max_seq_len=master_config["policy"]["max_total_sequence_length"], - max_rollout_turns=master_config["grpo"]["max_rollout_turns"], - greedy=False, - ) - - total_rewards.extend(val_batch["total_reward"].tolist()) - total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) - - # Collect message logs for later display - to_env = [ - get_keys_from_message_log( - val_batch["message_log"][i], ["role", "content"] - ) - for i in range(len(val_batch["message_log"])) - ] - - all_message_logs.extend(to_env) - - # Calculate validation metrics - num_samples = len(total_rewards) - if num_samples > 0: - rewards_t = torch.tensor(total_rewards, dtype=torch.float32) - accuracy = rewards_t.mean().item() - else: - accuracy = 0.0 - - avg_length = ( - sum(total_lengths) / len(total_lengths) if len(total_lengths) > 0 else 0.0 - ) - - val_metrics = { - "accuracy": accuracy, - "avg_length": avg_length, - **additional_metrics_to_report, - } - - # Print sample conversations only once at the end of validation - try: - print_message_log_samples( - all_message_logs, - total_rewards, - num_samples=min( - master_config["logger"]["num_val_samples_to_print"], - len(all_message_logs), - ), - step=step, - ) - except Exception as e: - print(f"\n ⚠️ Error displaying message samples: {str(e)}") - print(" ⚠️ Continuing validation without displaying samples...", flush=True) - - # Get timing metrics - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - validation_time = timing_metrics.get("total_validation_time", 0) - - # Print summary of validation results - print("\n📊 Validation Results:") - print(f" • Accuracy: {accuracy:.4f}") - print(f" • Average response length: {avg_length:.1f} tokens") - print(f" • Samples processed: {len(total_rewards)}", flush=True) - - # Print timing information - print("\n ⏱️ Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s", flush=True) - - # Log validation data to JSONL file - if logger is not None: - val_log_data = { - "content": all_message_logs, - "rewards": total_rewards, - } - logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl") - - # Make sure to reset the timer after validation - timer.reset() - - # Explicit GPU memory cleanup after validation - gc.collect() - torch.cuda.empty_cache() - - return val_metrics, timing_metrics - - -def async_grpo_train( - policy: ColocatablePolicyInterface, - policy_generation: Optional[GenerationInterface], - dataloader: StatefulDataLoader, - val_dataloader: Optional[StatefulDataLoader], - tokenizer: TokenizerType, - loss_fn: LossFunction, - task_to_env: dict[str, EnvironmentInterface], - val_task_to_env: Optional[dict[str, EnvironmentInterface]], - logger: Logger, - checkpointer: CheckpointManager, - grpo_save_state: GRPOSaveState, - master_config: MasterConfig, - max_trajectory_age_steps: int = 1, -) -> None: - """Run asynchronous GRPO training with replay buffer. - - Args: - policy: Training policy - policy_generation: Generation interface - dataloader: Training data loader - val_dataloader: Validation data loader - tokenizer: Tokenizer - loss_fn: Loss function - task_to_env: Training environments - val_task_to_env: Validation environments - logger: Logger - checkpointer: Checkpoint manager - grpo_save_state: Training state - master_config: Master configuration - max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training - """ - # Ensure we are running with a compatible async generation backend - assert _should_use_async_rollouts(master_config), ( - "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. " - "Set policy.generation.vllm_cfg.async_engine to true in your config." - ) - assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, ( - "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!" - ) - - if master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"] > 1: - if not master_config["grpo"]["async_grpo"].get( - "in_flight_weight_updates", False - ): - print( - "⚠️ WARNING: In-flight weight updates must be enabled for async GRPO with max_trajectory_age_steps > 1. " - "Without in-flight weight updates, having more max_trajectory_age_steps will not give any performance benefit." - ) - - # Import async utilities only when needed - from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer - - timer = Timer() - timeout = TimeoutChecker( - timeout=master_config["checkpointing"]["checkpoint_must_save_by"], - fit_last_save_time=True, - ) - timeout.start_iterations() - NEED_REFIT = True - - # Setup generation interface - if policy_generation is None: - policy_generation = policy - NEED_REFIT = False - POLICY_GENERATION_STALE = True - assert policy_generation is not None - - # Training state - step = grpo_save_state["current_step"] - weight_version = step # Tracks refitted weight versions - consumed_samples = grpo_save_state["consumed_samples"] - total_valid_tokens = grpo_save_state.get( - "total_valid_tokens", 0 - ) # Default to 0 for backward compatibility with older checkpoints - val_period = master_config["grpo"]["val_period"] - val_at_start = master_config["grpo"]["val_at_start"] - val_at_end = master_config["grpo"]["val_at_end"] - colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - - # Initialize advantage estimator - adv_estimator = _create_advantage_estimator(master_config) - - assert not colocated_inference, ( - "Colocated inference is not supported for async GRPO. Please use non-colocated inference." - ) - - # Calculate minimum buffer size from training requirements - # In per-prompt buffer mode, one buffer entry is 1 prompt * num_generations_per_prompt - num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] - samples_per_prompt_group = master_config["grpo"]["num_generations_per_prompt"] - train_gbs = master_config["policy"]["train_global_batch_size"] - - # Ensure the buffer has at least one step worth of prompt-groups before training - min_trajectories_needed = num_prompts_per_step - - print("📊 Buffer requirements calculation:") - print(f" - num_prompts_per_step: {num_prompts_per_step}") - print(f" - num_generations_per_prompt: {samples_per_prompt_group}") - print(f" - samples_per_prompt_group: {samples_per_prompt_group}") - print(f" - train_global_batch_size: {train_gbs}") - print(f" - min_trajectories_needed: {min_trajectories_needed} (async mode)") - - _replay_py_exec = get_actor_python_env( - "nemo_rl.algorithms.async_utils.ReplayBuffer" - ) - if _replay_py_exec.startswith("uv"): - # Lazily build a dedicated venv across all Ray nodes on-demand. - _replay_py_exec = create_local_venv_on_each_node( - _replay_py_exec, - "nemo_rl.algorithms.async_utils.ReplayBuffer", - ) - - _replay_runtime_env = { - "py_executable": _replay_py_exec, - "env_vars": { - **os.environ, - "VIRTUAL_ENV": _replay_py_exec, - "UV_PROJECT_ENVIRONMENT": _replay_py_exec, - }, - } - - # Calculate optimal buffer size based on generation limits to prevent length bias - # Each weight version generates exactly num_prompts_per_step trajectories - # With max_age_steps, we keep trajectories from multiple weight versions - num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] - late_arrival_slack = 2 - optimal_buffer_size = ( - num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack - ) - - replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote( - max_size=optimal_buffer_size - ) - - _tc_py_exec = get_actor_python_env( - "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector" - ) - if _tc_py_exec.startswith("uv"): - _tc_py_exec = create_local_venv_on_each_node( - _tc_py_exec, - "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector", - ) - - _tc_runtime_env = { - "py_executable": _tc_py_exec, - "env_vars": { - **os.environ, - "VIRTUAL_ENV": _tc_py_exec, - "UV_PROJECT_ENVIRONMENT": _tc_py_exec, - }, - } - - # Initialize trajectory collector with synchronized collection - trajectory_collector = AsyncTrajectoryCollector.options( - runtime_env=_tc_runtime_env - ).remote( - policy_generation=policy_generation, - tokenizer=tokenizer, - task_to_env=task_to_env, - master_config=master_config, - replay_buffer=replay_buffer, - start_step=step, - ) - - # Start trajectory collection in background - collection_task = trajectory_collector.start_collection.remote(dataloader) - - # Ensure collector knows initial weight version - trajectory_collector.set_weight_version.remote(weight_version) - - print("📦 Started continuous background trajectory collection") - - print( - f"🚀 Starting async GRPO training with buffer_size={optimal_buffer_size}, max_age={max_trajectory_age_steps} steps" - ) - - print("⏳ Preparing policy generation for training...") - if NEED_REFIT and POLICY_GENERATION_STALE: - print("🔄 Refitting policy generation with actual model weights...") - try: - refit_policy_generation(policy, policy_generation, colocated_inference) - print("✅ Policy generation refit completed successfully") - POLICY_GENERATION_STALE = False - except Exception as e: - print(f"❌ Policy generation refit failed: {e}") - import traceback - - traceback.print_exc() - return - else: - print("🔄 Preparing policy generation for inference...") - try: - policy_generation.prepare_for_generation() - print("✅ Policy generation preparation completed successfully") - except Exception as e: - print(f"❌ Policy generation preparation failed: {e}") - import traceback - - traceback.print_exc() - return - - print("✅ Policy generation setup complete, proceeding to validation...") - - # Run validation at start if configured - if val_at_start and step == 0: - print("\n🔍 Running initial validation...") - # Pause trajectory collection during initial validation - trajectory_collector.pause.remote() - - try: - val_metrics, validation_timings = validate( - policy_generation, - val_dataloader, - tokenizer, - val_task_to_env, - step=0, - master_config=master_config, - logger=logger, - ) - policy_generation.finish_generation() - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") - print("✅ Initial validation completed successfully") - except Exception as e: - print(f"❌ Initial validation failed: {e}") - import traceback - - traceback.print_exc() - # Continue anyway since validation is optional - finally: - # Resume trajectory collection after initial validation - trajectory_collector.resume.remote() - - print("✅ All setup complete, starting buffer wait...") - # Clear logger metrics at start of training - if policy_generation is not None: - policy_generation.clear_logger_metrics() - - # Wait for initial buffer fill - print( - f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed} trajectories)..." - ) - wait_iterations = 0 - while True: - buffer_size_current = ray.get(replay_buffer.size.remote()) - - print( - f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}" - ) - - if buffer_size_current >= min_trajectories_needed: - break - - time.sleep(1.0) - - print("✅ Buffer ready! Starting training loop...") - - # Main training loop - try: - while step < master_config["grpo"]["max_num_steps"]: - print( - f"\n{'=' * 25} Step {step + 1}/{master_config['grpo']['max_num_steps']} {'=' * 25}" - ) - maybe_gpu_profile_step(policy, step + 1) - if policy != policy_generation: - maybe_gpu_profile_step(policy_generation, step + 1) - - with timer.time("total_step_time"): - # Sample trajectories from replay buffer - print("📦 Sampling from replay buffer...") - with timer.time("exposed_generation"): - buffer_size_current = ray.get(replay_buffer.size.remote()) - print( - f"📊 Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}" - ) - - # Sample the required number of per-prompt groups. - num_prompt_groups_needed = master_config["grpo"][ - "num_prompts_per_step" - ] - sample_result = ray.get( - replay_buffer.sample.remote( - num_prompt_groups=num_prompt_groups_needed, - current_weight_version=weight_version, - max_age_steps=max_trajectory_age_steps, - ) - ) - - if ( - sample_result is None - or len(sample_result["trajectories"]) - != num_prompt_groups_needed - ): - print( - "⏳ Buffer empty or not enough groups to form a full step, waiting..." - ) - - # Get buffer debug info to help diagnose the issue - buffer_debug = ray.get(replay_buffer.get_debug_info.remote()) - buffer_size = buffer_debug["total_trajectories"] - - if buffer_size > 0: - print( - f"🔍 Debug: Buffer has {buffer_size} trajectories but sampling requires exactly {num_prompt_groups_needed}." - ) - print(f" Current weight version: {weight_version}") - print(f" Max trajectory age: {max_trajectory_age_steps}") - print( - f" Trajectory versions in buffer: {buffer_debug['trajectory_versions']}" - ) - - time.sleep(0.5) - continue - - # Extract trajectories and metadata from sample result - trajectories = sample_result["trajectories"] - avg_trajectory_age = sample_result["avg_trajectory_age"] - - print( - f"✅ Sampled {len(trajectories)} trajectory groups from buffer (avg age: {avg_trajectory_age:.2f} steps)" - ) - - # Concatenate per-prompt groups into a single training batch - per_prompt_batches = [t["batch"] for t in trajectories] - repeated_batch = BatchedDataDict.from_batches(per_prompt_batches) - # Aggregate rollout metrics across groups (simple mean where applicable) - rollout_metrics = {} - for t in trajectories: - for k, v in t["rollout_metrics"].items(): - rollout_metrics.setdefault(k, []).append(v) - # TODO: this simple averaging might cause misleading information for such data as max_gen_tokens, etc. - rollout_metrics = { - k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v) - for k, v in rollout_metrics.items() - } - - # Enforce fixed training batch: num_prompts_per_step * num_generations_per_prompt - expected_batch_size = ( - master_config["grpo"]["num_prompts_per_step"] - * master_config["grpo"]["num_generations_per_prompt"] - ) - if repeated_batch.size != expected_batch_size: - print( - f"❌ Unexpected training batch size: got {repeated_batch.size}, expected {expected_batch_size}. Skipping step and waiting for correct buffer content." - ) - time.sleep(0.5) - continue - - # Optional sanity: ensure DP divisibility to avoid sharding issues - dp_size = policy.sharding_annotations.get_axis_size("data_parallel") - if expected_batch_size % dp_size != 0: - raise AssertionError( - f"Configuration error: (num_prompts_per_step * num_generations_per_prompt) = {expected_batch_size} must be divisible by data_parallel size {dp_size}." - ) - - print(f"Got trajectory batch (size: {repeated_batch.size})") - - print("▶ Processing rewards...") - with timer.time("reward_calculation"): - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] - ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs - del prompt_batched_flat - - rewards = repeated_batch["total_reward"] - - print( - f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" - ) - - # Prepare training data (same as sync version) - with timer.time("data_processing"): - # Add loss mask to each message - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) - - # Convert to flat format for training - flat_messages, input_lengths = batched_message_log_to_flat_message( - repeated_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) - - # Create training data - # Note: advantages will be computed and added after logprobs are available - train_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": flat_messages["token_ids"], - "input_lengths": input_lengths, - "generation_logprobs": flat_messages["generation_logprobs"], - "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], - } - ) - train_data.to("cpu") - - # Training phase (same as sync version) - print("▶ Preparing for logprob inference...") - with timer.time("logprob_inference_prep"): - policy.prepare_for_lp_inference() - - print("▶ Computing logprobs...") - with timer.time("policy_and_reference_logprobs"): - fprop_logprobs = policy.get_logprobs( - train_data, - timer=timer, - )["logprobs"] - reference_logprobs = policy.get_reference_policy_logprobs( - train_data, - timer=timer, - )["reference_logprobs"] - train_data["prev_logprobs"] = fprop_logprobs - train_data["reference_policy_logprobs"] = reference_logprobs - - ( - max_seq_mult_prob_error, - num_masked_seqs, - masked_correct_pct, - ) = compute_and_apply_seq_logprob_error_masking( - train_data=train_data, - rewards=rewards, - seq_logprob_error_threshold=master_config["grpo"][ - "seq_logprob_error_threshold" - ], - ) - - # Compute advantages with adv_estimator using correct mask and logprobs - with timer.time("advantage_calculation"): - print("▶ Computing advantages...", flush=True) - # Get token-level mask: token_mask * sample_mask - token_mask = train_data["token_mask"] - sample_mask = train_data["sample_mask"] - mask = token_mask * sample_mask.unsqueeze(-1) - - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - repeated_batch=repeated_batch, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - del prompt_ids_for_adv - - # Log advantages stats - # Note: For GRPOAdvantageEstimator with normalize_rewards=True, these are - # already normalized advantages (equivalent to "Normalized advantages stats" - # in older versions). For ReinforcePlusPlusAdvantageEstimator, advantages - # are globally normalized across valid tokens. - advantages = train_data["advantages"] - print( - f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" - ) - - print("▶ Preparing for training...") - with timer.time("training_prep"): - policy.prepare_for_training() - POLICY_GENERATION_STALE = True - - print("▶ Training policy...") - with timer.time("policy_training"): - train_results = policy.train( - train_data, - loss_fn, - timer=timer, - ) - - print("🔄 Synchronizing policy weights to trajectory collector…") - generation_logger_metrics = None - if NEED_REFIT: - # Measure pending-generation wait as exposed_generation time - print("🔄 Coordinating with trajectory collector before refit...") - with timer.time("exposed_generation"): - ray.get(trajectory_collector.prepare_for_refit.remote()) - - # Collect generation logger metrics for performance reporting - # inflight batch sizes and num pending samples are collected from each worker - if policy_generation is not None: - generation_logger_metrics = ( - policy_generation.get_logger_metrics() - ) - - # Only the actual refit/weight transfer should be counted as weight_sync - print("🔄 Performing policy generation refit...") - with timer.time("weight_sync"): - refit_policy_generation( - policy, policy_generation, colocated_inference - ) - POLICY_GENERATION_STALE = False - - # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version - weight_version += 1 - trajectory_collector.set_weight_version.remote(weight_version) - trajectory_collector.resume_after_refit.remote() - - # Clear logger metrics after each refit (weight sync), starting a new logging cycle - if policy_generation is not None: - policy_generation.clear_logger_metrics() - - # Validation - val_metrics, validation_timings = None, None - is_last_step = step + 1 == master_config["grpo"]["max_num_steps"] - - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (step + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): - # Pause trajectory collection during validation to reduce memory pressure - trajectory_collector.pause.remote() - - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation( - policy, policy_generation, colocated_inference - ) - POLICY_GENERATION_STALE = False - else: - policy_generation.prepare_for_generation() - val_metrics, validation_timings = validate( - policy_generation, - val_dataloader, - tokenizer, - val_task_to_env, - step=step + 1, - master_config=master_config, - logger=logger, - ) - policy_generation.finish_generation() - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" - ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") - - # Explicit GPU memory cleanup after validation in async mode - import gc - - gc.collect() - torch.cuda.empty_cache() - - # Resume trajectory collection after validation - trajectory_collector.resume.remote() - # Get flat advantages and token mask for masked metrics computation - flat_advantages = train_data["advantages"] - flat_token_mask = flat_messages["token_loss_mask"] - # Save content for logging before deleting flat_messages - flat_messages_content = flat_messages.get("content", []) - del flat_messages - - # Filter advantages using token mask (only valid response tokens) - response_advantages = torch.masked_select( - flat_advantages, flat_token_mask.bool() - ) - - metrics = { - "loss": train_results["loss"].numpy(), - "reward": rewards.numpy(), - "grad_norm": train_results["grad_norm"].numpy(), - "mean_prompt_length": repeated_batch["length"].numpy(), - "total_num_tokens": input_lengths.numpy(), - # Add masked advantages tracking metrics (only for valid response tokens) - "advantages/mean": torch.mean(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - "advantages/max": torch.max(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - "advantages/min": torch.min(response_advantages).detach().item() - if response_advantages.numel() > 0 - else 0.0, - } - if "moe_metrics" in train_results: - metrics.update( - {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()} - ) - metrics.update(train_results["all_mb_metrics"]) - for k, v in metrics.items(): - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - valid_values = [x for x in v if not np.isinf(x)] - metrics[k] = ( - np.min(valid_values).item() if valid_values else -1.0 - ) - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - valid_values = [x for x in v if not np.isinf(x)] - metrics[k] = ( - np.max(valid_values).item() if valid_values else -1.0 - ) - elif k in { - "lr", - "wd", - "reward", - "global_valid_seqs", - "global_valid_toks", - "mean_prompt_length", - }: - metrics[k] = np.mean(v).item() - else: - metrics[k] = np.sum(v).item() - metrics.update(rollout_metrics) - if generation_logger_metrics is not None: - metrics["generation_logger_metrics"] = generation_logger_metrics - total_valid_tokens += metrics["global_valid_toks"] - - # Always log sequence-level error metrics (useful for deciding threshold) - metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error - metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs - metrics["masked_correct_pct"] = masked_correct_pct - - # Checkpointing (same as sync version) - consumed_samples += master_config["grpo"]["num_prompts_per_step"] - timeout.mark_iteration() - - should_save_by_step = ( - is_last_step - or (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ) - # +1 because step is 0-indexed - # Check if timeout-based checkpointing is enabled in config. - should_save_by_timeout = timeout.check_save() - - if master_config["checkpointing"]["enabled"] and ( - should_save_by_step or should_save_by_timeout - ): - grpo_save_state["current_step"] = step + 1 - grpo_save_state["total_valid_tokens"] = total_valid_tokens - if val_metrics is not None: - grpo_save_state["val_reward"] = val_metrics["accuracy"] - elif "val_reward" in grpo_save_state: - del grpo_save_state["val_reward"] - grpo_save_state["consumed_samples"] = consumed_samples - - full_metric_name = master_config["checkpointing"]["metric_name"] - if full_metric_name is not None: - assert full_metric_name.startswith( - "train:" - ) or full_metric_name.startswith("val:"), ( - f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" - f'followed by the corresponding name in the "val" or "train" metrics dictionary.' - f" If you are using an old config, please updated checkpointing.metric_name to the new format, " - f" e.g. 'val_reward --> 'val:accuracy'" - ) - prefix, metric_name = full_metric_name.split(":", 1) - metrics_source = metrics if prefix == "train" else val_metrics - if not metrics_source: - warnings.warn( - f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " - "This checkpoint will not be saved as top-k.", - stacklevel=2, - ) - if full_metric_name in grpo_save_state: - del grpo_save_state[full_metric_name] - elif metric_name not in metrics_source: - raise ValueError( - f"Metric {metric_name} not found in {prefix} metrics" - ) - else: - grpo_save_state[full_metric_name] = metrics_source[ - metric_name - ] - - with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") - checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, grpo_save_state, master_config - ) - policy.save_checkpoint( - weights_path=os.path.join( - checkpoint_path, "policy", "weights" - ), - optimizer_path=os.path.join( - checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, - tokenizer_path=os.path.join( - checkpoint_path, "policy", "tokenizer" - ), - checkpointing_cfg=master_config["checkpointing"], - ) - # Get dataloader state from trajectory collector - actual_dataloader_state = ray.get( - trajectory_collector.get_dataloader_state.remote() - ) - torch.save( - actual_dataloader_state, - os.path.join(checkpoint_path, "train_dataloader.pt"), - ) - checkpointer.finalize_checkpoint(checkpoint_path) - - # Logging - # Log training data (match sync GRPO logging payload for parity) - log_data = {} - if "agent_ref" in repeated_batch: - log_data["agent_ref"] = repeated_batch["agent_ref"] - log_data["content"] = flat_messages_content - log_data["rewards"] = rewards.tolist() - if master_config["grpo"]["use_dynamic_sampling"]: - # In dynamic sampling, `rewards` corresponds to filtered rewards - log_data["filtered_rewards"] = rewards.tolist() - log_data["rewards"] = repeated_batch["total_reward"].tolist() - log_data["input_lengths"] = input_lengths.tolist() - log_data["token_ids"] = train_data["input_ids"].tolist() - log_data["token_loss_mask"] = train_data["token_mask"].tolist() - log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() - log_data["advantages"] = train_data["advantages"].tolist() - log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() - log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() - logger.log_batched_dict_as_jsonl( - log_data, f"train_data_step{step + 1}.jsonl" - ) - del train_data - del flat_messages_content - - timing_metrics: dict[str, float] = timer.get_timing_metrics( - reduction_op="sum" - ) - - # Add buffer stats - buffer_size_current = ray.get(replay_buffer.size.remote()) - metrics["buffer_size"] = buffer_size_current - metrics["avg_trajectory_age"] = avg_trajectory_age - - if master_config["policy"]["generation"].get("vllm_cfg", {}).get( - "enable_vllm_metrics_logger", False - ) and master_config.get("logger", {}).get("wandb_enabled", False): - log_generation_metrics_to_wandb( - generation_logger_metrics, - step + 1, - master_config["policy"]["generation"]["vllm_cfg"][ - "vllm_metrics_logger_interval" - ], - logger, - ) - - # Plot ISL/OSL/ISL+OSL histograms to wandb - if ( - master_config["policy"]["generation"] - .get("vllm_cfg", {}) - .get("async_engine", False) - ): - for metric_name in metrics.keys(): - if metric_name.startswith("histogram/"): - logger.log_histogram( - metrics[metric_name], - step + 1, - f"generation_metrics/{metric_name}", - ) - - print("\n📊 Training Results:") - print(f" • Loss: {metrics['loss']:.4f}") - print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") - print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") - print(f" • Buffer Size: {buffer_size_current}") - print(f" • Avg Trajectory Age: {avg_trajectory_age:.2f} steps") - - print("\n⏱️ Timing:") - total_time = timing_metrics.get("total_step_time", 0) - print(f" • Total step time: {total_time:.2f}s") - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") - - total_num_gpus = ( - master_config["cluster"]["num_nodes"] - * master_config["cluster"]["gpus_per_node"] - ) - timing_metrics["valid_tokens_per_sec_per_gpu"] = ( - metrics["global_valid_toks"] / total_time / total_num_gpus - ) - performance_metrics = print_performance_metrics( - train_results, metrics, timing_metrics, master_config - ) - - logger.log_metrics(performance_metrics, step + 1, prefix="performance") - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - - timer.reset() - step += 1 - if should_save_by_timeout: - print("Timeout has been reached, stopping training early", flush=True) - return - if step >= master_config["grpo"]["max_num_steps"]: - print( - "Max number of steps has been reached, stopping training early", - flush=True, - ) - return - - except Exception as e: - print(f"❌ Error in async loop: {e}") - import traceback - - traceback.print_exc() - - finally: - # Clean up - print("🛑 Stopping trajectory collection...") - try: - ray.kill(trajectory_collector) - except Exception as e: - print(f"Error stopping trajectory collector: {e}") - - try: - ray.kill(replay_buffer) - except Exception as e: - print(f"Error stopping replay buffer: {e}") - - print("Async GRPO training complete!") diff --git a/shim/nemo_rl/environments/__init__.py b/shim/nemo_rl/environments/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shim/nemo_rl/environments/erdos_discovery_environment.py b/shim/nemo_rl/environments/erdos_discovery_environment.py deleted file mode 100644 index 70c4710fca..0000000000 --- a/shim/nemo_rl/environments/erdos_discovery_environment.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Erdős Discovery Environment for NeMo RL. - -Implements EnvironmentInterface for TTT-Discover with the Erdős Minimum -Overlap Problem. Calls the NeMo Gym resource server for code execution -and reward computation. - -Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) - -The environment: - 1. Receives LLM-generated code from the GRPO rollout - 2. Sends it to the Erdős Gym resource server for sandboxed execution + scoring - 3. Returns reward = 1/bound (or 0 on failure) - 4. Tracks best constructions and buffer statistics via metrics -""" - -import logging -import math -from typing import Any, Optional - -import aiohttp -import ray -import torch - -from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn - -logger = logging.getLogger(__name__) - - -# ═══════════════════════════════════════════════════════════════════ -# Inline reward computation (no Gym server needed for debug/testing) -# ═══════════════════════════════════════════════════════════════════ - -def _inline_compute_reward(response_text: str, timeout: int = 60) -> dict: - """Compute reward directly in-process. No HTTP call needed.""" - import re - import signal - import builtins - import math as _math - import itertools as _itertools - import functools as _functools - import collections as _collections - - import numpy as _np - from numpy.fft import rfft, irfft - - _ALLOWED_MODULES = frozenset({ - "numpy", "np", "math", "cmath", "random", - "itertools", "functools", "collections", "fractions", "decimal", - }) - _SAFE_BUILTIN_NAMES = [ - "abs", "all", "any", "bool", "dict", "divmod", "enumerate", - "filter", "float", "format", "int", "isinstance", "issubclass", - "iter", "len", "list", "map", "max", "min", "next", "object", - "print", "range", "repr", "reversed", "round", "set", "slice", - "sorted", "str", "sum", "tuple", "type", "zip", - "Exception", "ValueError", "TypeError", "KeyError", "IndexError", - "StopIteration", "RuntimeError", "NotImplementedError", - "OverflowError", "ZeroDivisionError", "AttributeError", - ] - - # Extract code - code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) - blocks = code_re.findall(response_text) - code = blocks[-1].strip() if blocks else response_text.strip() - - # Build sandbox - import random as _random - safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES if hasattr(builtins, k)} - def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): - if name.split(".")[0] not in _ALLOWED_MODULES: - raise ImportError(f"Module '{name}' not allowed") - return builtins.__import__(name, globals, locals, fromlist, level) - safe_builtins["__import__"] = _safe_import - namespace = { - "__builtins__": safe_builtins, - "np": _np, "numpy": _np, "math": _math, "random": _random, - "itertools": _itertools, "functools": _functools, "collections": _collections, - } - - try: - class _Timeout(Exception): - pass - def _handler(s, f): - raise _Timeout() - old = signal.signal(signal.SIGALRM, _handler) - signal.alarm(timeout) - try: - exec(compile(code, "", "exec"), namespace) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old) - - if "f" not in namespace: - return {"reward": 0.0, "bound": None, "error_msg": "no variable f"} - - f = _np.asarray(namespace["f"], dtype=float).flatten() - - # Validate - if len(f) < 1 or len(f) > 1000: - return {"reward": 0.0, "bound": None, "error_msg": "bad length"} - if _np.any(~_np.isfinite(f)) or _np.any(f < 0) or _np.any(f > 1): - return {"reward": 0.0, "bound": None, "error_msg": "bad values"} - if abs(float(_np.mean(f)) - 0.5) > 1e-3: - return {"reward": 0.0, "bound": None, "error_msg": "bad mean"} - - # Compute bound - n = len(f) - F = rfft(f, n=2*n) - autocorr = irfft(F * _np.conj(F), n=2*n) - bound = float(2 * n * _np.max(autocorr.real) / (_np.sum(f)**2)) - - if bound <= 0 or not _math.isfinite(bound): - return {"reward": 0.0, "bound": None, "error_msg": "bad bound"} - return {"reward": 1.0 / bound, "bound": bound, "error_msg": ""} - - except Exception as e: - return {"reward": 0.0, "bound": None, "error_msg": str(e)[:200]} - -# Type alias matching NeMo RL's convention -LLMMessageLogType = list[dict[str, Any]] -ErdosMetadata = dict[str, Any] - - -@ray.remote(max_restarts=-1, max_task_retries=-1) -class ErdosDiscoveryEnvironment(EnvironmentInterface[ErdosMetadata]): - """Erdős Minimum Overlap Problem environment for GRPO training. - - Communicates with the NeMo Gym Erdős resource server via HTTP for: - - /verify: code execution + reward computation - - /select_state: PUCT state selection for prompts - - /seed_session: buffer initialization - - /compute_entropic_advantages: LOO entropic advantages - - /update_buffer: add new discoveries to PUCT tree - - Config (under env.erdos_discovery): - resource_server_url: Base URL of the Erdős Gym resource server. - seed: Random seed for PUCT buffer initialization. - num_initial_states: States to seed the buffer with. - sandbox_timeout: Code execution timeout in seconds. - """ - - def __init__(self, config: dict): - self.config = config - self.resource_server_url = config.get( - "resource_server_url", "http://localhost:8080" - ) - self.seed = config.get("seed", None) - self.num_initial_states = config.get("num_initial_states", 16) - self.sandbox_timeout = config.get("sandbox_timeout", 600) - self.request_timeout = config.get("request_timeout", 660) - - self.best_reward = 0.0 - self.best_bound = float("inf") - self.total_verified = 0 - self.total_valid = 0 - self._session_initialized = False - self._inline_mode = (self.resource_server_url == "inline") - if self._inline_mode: - logger.info("ErdosDiscovery: running in INLINE mode (no Gym server)") - self._session_initialized = True # No server to init - - async def _ensure_session(self): - """Initialize the PUCT buffer on the resource server if not done.""" - if self._session_initialized: - return - try: - timeout = aiohttp.ClientTimeout(total=30) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post( - f"{self.resource_server_url}/seed_session", - json={ - "num_initial_states": self.num_initial_states, - "seed": self.seed, - }, - ) as resp: - data = await resp.json() - self.best_reward = data.get("best_initial_reward", 0.0) - self.best_bound = data.get( - "best_initial_bound", float("inf") - ) - logger.info( - "ErdosDiscovery: seeded buffer with %d states, " - "best_reward=%.4f, best_bound=%.6f", - data.get("num_states", 0), - self.best_reward, - self.best_bound, - ) - self._session_initialized = True - except Exception as e: - logger.error("ErdosDiscovery: seed_session failed: %s", e) - - async def _verify_single( - self, - session: Optional[aiohttp.ClientSession], - response_text: str, - parent_state: Optional[list[float]] = None, - ) -> dict: - """Call /verify on the resource server, or compute inline.""" - if self._inline_mode: - return _inline_compute_reward( - response_text, timeout=self.sandbox_timeout - ) - # Build a minimal NeMoGymResponse-like payload - # The resource server extracts output_text from response.output_text - body = { - "responses_create_params": { - "input": [{"role": "user", "content": ""}], - }, - "response": { - "id": "verify", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": response_text}], - } - ], - "output_text": response_text, - }, - "parent_state": parent_state, - } - try: - timeout = aiohttp.ClientTimeout(total=self.request_timeout) - async with session.post( - f"{self.resource_server_url}/verify", - json=body, - timeout=timeout, - ) as resp: - return await resp.json() - except Exception as e: - logger.warning("ErdosDiscovery: verify failed: %s", e) - return {"reward": 0.0, "bound": None, "error_msg": str(e)} - - def step( - self, - message_log_batch: list[LLMMessageLogType], - metadata: list[ErdosMetadata], - ) -> EnvironmentReturn[ErdosMetadata]: - """Evaluate a batch of LLM responses. - - Extracts the assistant's last message from each conversation, - sends it to the resource server for code execution + scoring, - returns rewards. - """ - import asyncio - - return asyncio.get_event_loop().run_until_complete( - self._async_step(message_log_batch, metadata) - ) - - async def _async_step( - self, - message_log_batch: list[LLMMessageLogType], - metadata: list[ErdosMetadata], - ) -> EnvironmentReturn[ErdosMetadata]: - await self._ensure_session() - - batch_size = len(message_log_batch) - rewards = torch.zeros(batch_size) - terminateds = torch.ones(batch_size) # Always single-turn - observations = [{}] * batch_size - answers = [None] * batch_size - updated_metadata = list(metadata) - - if self._inline_mode: - session = None - else: - session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=self.request_timeout) - ) - - try: - tasks = [] - for i, message_log in enumerate(message_log_batch): - # Extract the last assistant message - response_text = "" - for msg in reversed(message_log): - if msg.get("role") == "assistant": - response_text = msg.get("content", "") - break - - # Get parent_state from metadata if available - parent_state = None - if metadata and i < len(metadata): - parent_state = metadata[i].get("parent_state", None) - - tasks.append( - self._verify_single(session, response_text, parent_state) - ) - - results = await asyncio.gather(*tasks, return_exceptions=True) - finally: - if session is not None: - await session.close() - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.warning( - "ErdosDiscovery: verify exception for sample %d: %s", - i, - result, - ) - continue - - reward = result.get("reward", 0.0) - rewards[i] = reward - self.total_verified += 1 - - if reward > 0: - self.total_valid += 1 - bound = result.get("bound", None) - if reward > self.best_reward: - self.best_reward = reward - self.best_bound = bound or ( - 1.0 / reward if reward > 0 else float("inf") - ) - - answers[i] = ( - f"bound={bound:.6f}" if bound else f"reward={reward:.4f}" - ) - - # Update metadata with verification results - if i < len(updated_metadata): - updated_metadata[i] = { - **updated_metadata[i], - "reward": reward, - "bound": result.get("bound"), - "error_msg": result.get("error_msg", ""), - "best_reward_ever": result.get( - "best_reward_ever", self.best_reward - ), - } - - return EnvironmentReturn( - observations=observations, - metadata=updated_metadata, - next_stop_strings=[None] * batch_size, - rewards=rewards, - terminateds=terminateds, - answers=answers, - ) - - def global_post_process_and_metrics( - self, batch: dict - ) -> tuple[dict, dict]: - """Compute and return environment-level metrics.""" - valid_rate = ( - self.total_valid / max(self.total_verified, 1) - ) - metrics = { - "env/best_reward": self.best_reward, - "env/best_bound": self.best_bound - if self.best_bound < float("inf") - else 0.0, - "env/total_verified": self.total_verified, - "env/valid_rate": valid_rate, - } - return batch, metrics - - def shutdown(self): - """Cleanup.""" - pass diff --git a/shim/nemo_rl/environments/utils.py b/shim/nemo_rl/environments/utils.py deleted file mode 100644 index a1bc6ace3f..0000000000 --- a/shim/nemo_rl/environments/utils.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from typing import Any, Dict, NotRequired, TypedDict - -from hydra.utils import get_object - -from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env -from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.utils.venvs import create_local_venv_on_each_node - - -# Environment registry entry schema. -class EnvRegistryEntry(TypedDict, total=False): - actor_class_fqn: str - default_processor: NotRequired[str] - - -# Environment registry. Key is the env name, value is a dictionary with the actor class FQN and optional default processor. -ENV_REGISTRY: Dict[str, EnvRegistryEntry] = { - "math_default": { - "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", - }, - "math": { - "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", - }, - "math_multi_reward": { - "actor_class_fqn": "nemo_rl.environments.math_environment.MathMultiRewardEnvironment", - }, - "code": { - "actor_class_fqn": "nemo_rl.environments.code_environment.CodeEnvironment", - }, - "reward_model": { - "actor_class_fqn": "nemo_rl.environments.reward_model_environment.RewardModelEnvironment", - }, - "code_jaccard": { - "actor_class_fqn": "nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment", - }, - "vlm": { - "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", - }, - "erdos_discovery": { - "actor_class_fqn": "nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment", - }, - "nemo_gym": { - "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", - }, -} - - -def chunk_list_to_workers(to_chunk: list[Any], num_workers: int) -> list[list[Any]]: - """Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. - - If the list is not divisible by the number of workers, the last worker may have fewer elements. - If there are more workers than elements, the first len(list) workers will have a single element each, - and the remaining workers will have empty lists. - - Args: - list: The list to be chunked. - num_workers: The number of workers to distribute the list to. - - Returns: - A list of lists, where each sublist contains elements assigned to a worker. - - Examples: - ```{doctest} - >>> from nemo_rl.environments.utils import chunk_list_to_workers - >>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) - [[1, 2], [3, 4], [5]] - ``` - """ - if not to_chunk: - return [[] for _ in range(num_workers)] - - # Handle case where we have more workers than elements - if len(to_chunk) <= num_workers: - result = [[item] for item in to_chunk] - # Add empty lists for remaining workers - result.extend([[] for _ in range(num_workers - len(to_chunk))]) - return result - - # Calculate chunk size (ceiling division to ensure all elements are covered) - chunk_size = (len(to_chunk) + num_workers - 1) // num_workers - - # Create chunks - chunks = [] - for i in range(0, len(to_chunk), chunk_size): - chunks.append(to_chunk[i : i + chunk_size]) - - # If we somehow ended up with more chunks than workers (shouldn't happen with ceiling division) - # merge the last chunks - if len(chunks) > num_workers: - chunks[num_workers - 1 :] = [sum(chunks[num_workers - 1 :], [])] - - return chunks - - -def create_env(env_name: str, env_config: dict) -> EnvironmentInterface: - assert env_name in ENV_REGISTRY, ( - f"Env name {env_name} is not registered in ENV_REGISTRY. Please call register_env() to register the environment." - ) - actor_class_fqn = ENV_REGISTRY[env_name]["actor_class_fqn"] - actor_class = get_object(actor_class_fqn) - actor_py_exec = get_actor_python_env(actor_class_fqn) - extra_env_vars = {} - if actor_py_exec.startswith("uv"): - actor_py_exec = create_local_venv_on_each_node( - actor_py_exec, - actor_class_fqn, - ) - extra_env_vars = { - "VIRTUAL_ENV": actor_py_exec, - "UV_PROJECT_ENVIRONMENT": actor_py_exec, - } - env = actor_class.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": actor_py_exec, - "env_vars": {**dict(os.environ), **extra_env_vars}, - } - ).remote(env_config) - return env - - -def register_env(env_name: str, actor_class_fqn: str) -> None: - if env_name in ENV_REGISTRY: - raise ValueError(f"Env name {env_name} already registered") - - ENV_REGISTRY[env_name] = {"actor_class_fqn": actor_class_fqn} diff --git a/shim/nemo_rl/utils/__init__.py b/shim/nemo_rl/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shim/nemo_rl/utils/puct_buffer.py b/shim/nemo_rl/utils/puct_buffer.py deleted file mode 100644 index 53f808d783..0000000000 --- a/shim/nemo_rl/utils/puct_buffer.py +++ /dev/null @@ -1,561 +0,0 @@ -""" -PUCT buffer for TTT-Discover state reuse. - -Reference: "Learning to Discover at Test Time" (arXiv:2601.04116) - -The buffer maintains a tree of (state, reward) nodes. At each training step, -PUCT scoring selects which states to warm-start rollouts from, balancing: - - Exploitation: states whose children have achieved high rewards (high Q) - - Exploration: states that haven't been visited much yet (low n) - -Pure data structure — no ML framework dependencies. -""" - -import math -import dataclasses -from typing import Any, Optional - -import numpy as np - - -# --------------------------------------------------------------------------- -# Internal node -# --------------------------------------------------------------------------- - -@dataclasses.dataclass -class _Node: - state: Any - reward: float # reward of THIS state (from its own evaluation) - parent_key: Any # key of parent node, or None for roots - children_keys: list # keys of direct children - n: int # visit count (number of times selected for expansion) - Q: float # max reward among all descendants (or own reward if leaf) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_key(state: Any) -> Any: - """Convert state to a hashable key. - - Supports: str, int, float, tuple, list, np.ndarray, and arbitrary objects - (fallback: id-based, so two different objects with equal content are - treated as distinct — acceptable for LLM response strings). - """ - if isinstance(state, (str, int, float, bool)): - return state - if isinstance(state, np.ndarray): - return (state.dtype, state.shape, state.tobytes()) - if isinstance(state, (list, tuple)): - return tuple(_make_key(x) for x in state) - # Fallback: identity-based key — wrap id so it doesn't collide with ints - return ("__id__", id(state)) - - -# --------------------------------------------------------------------------- -# PUCTBuffer -# --------------------------------------------------------------------------- - -class PUCTBuffer: - """ - Tree-structured buffer with PUCT selection. - - PUCT score for node s: - score(s) = Q(s) + c · P(s) · sqrt(1 + T) / (1 + n(s)) - - Where: - Q(s) = max reward among all descendants of s (own reward if leaf) - P(s) = rank-based prior: rank states by reward, normalize by total rank - n(s) = visit count of s - T = total visit count across all nodes - c = exploration constant (default 1.0) - """ - - def __init__(self, c: float = 1.0) -> None: - self.c = c - self._nodes: dict[Any, _Node] = {} # key → _Node - self._T: int = 0 # total expansions so far - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def add(self, state: Any, reward: float, parent_state: Any = None) -> None: - """Insert a new node into the buffer. - - If the state is already present, this is a no-op (deduplication). - If parent_state is given and present in the buffer, the new node is - linked as a child and Q values are propagated upward. - - Args: - state: The state to insert (any type with a consistent identity). - reward: Scalar reward associated with this state. - parent_state: Parent state, or None for a root node. - """ - key = _make_key(state) - if key in self._nodes: - return # already present — deduplicate - - parent_key = _make_key(parent_state) if parent_state is not None else None - node = _Node( - state=state, - reward=float(reward), - parent_key=parent_key, - children_keys=[], - n=0, - Q=float(reward), # leaf: Q = own reward - ) - self._nodes[key] = node - - if parent_key is not None and parent_key in self._nodes: - self._nodes[parent_key].children_keys.append(key) - self._propagate_Q(parent_key) - - def select( - self, batch_size: int, num_groups: int = 8 - ) -> list[tuple[Any, list]]: - """Select states to warm-start rollouts from. - - Scores each node with PUCT, picks the top `num_groups` distinct states, - and returns `batch_size` (state, context) pairs grouped so that each - group of `batch_size // num_groups` entries shares the same state. - - Context is the ancestry path from root to the selected node: - [(ancestor_state, ancestor_reward), ..., (selected_state, selected_reward)] - The env uses this to build the prompt (previous attempts / warm start). - - Visit counts are incremented for the selected nodes, and T is updated. - - Args: - batch_size: Total number of (state, context) pairs to return. - Must be divisible by num_groups. - num_groups: Number of distinct initial states to select. - - Returns: - List of (state, context) tuples, length == batch_size. - """ - if not self._nodes: - raise ValueError("Buffer is empty — call add() before select()") - if batch_size % num_groups != 0: - raise ValueError( - f"batch_size ({batch_size}) must be divisible by num_groups ({num_groups})" - ) - rollouts_per_group = batch_size // num_groups - - priors = self._rank_priors() - scores = { - key: self._puct_score(node, priors[key]) - for key, node in self._nodes.items() - } - - # Top num_groups keys by PUCT score (at most len(nodes) if buffer is small) - k = min(num_groups, len(self._nodes)) - top_keys = sorted(scores, key=lambda x: scores[x], reverse=True)[:k] - - result: list[tuple[Any, list]] = [] - for key in top_keys: - node = self._nodes[key] - context = self._ancestry(key) - pair = (node.state, context) - result.extend([pair] * rollouts_per_group) - # Increment visit count for this selection - node.n += 1 - self._T += 1 - - return result - - def update( - self, parent_state: Any, child_state: Any, reward: float - ) -> None: - """Add a child node and update Q values up the tree. - - Convenience wrapper around add() that makes the parent/child - relationship explicit. - - Args: - parent_state: The state that was selected and rolled out from. - child_state: The resulting new state produced by the rollout. - reward: Reward of the new child state. - """ - self.add(child_state, reward, parent_state=parent_state) - - def best(self) -> tuple[Any, float]: - """Return the (state, reward) with the highest reward ever seen. - - Returns: - (state, reward) tuple. - """ - if not self._nodes: - raise ValueError("Buffer is empty") - best_key = max(self._nodes, key=lambda k: self._nodes[k].reward) - node = self._nodes[best_key] - return node.state, node.reward - - def __len__(self) -> int: - return len(self._nodes) - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - def _puct_score(self, node: _Node, prior: float) -> float: - return node.Q + self.c * prior * math.sqrt(1 + self._T) / (1 + node.n) - - def _rank_priors(self) -> dict[Any, float]: - """Rank-based prior: rank by node reward, normalize by sum of ranks. - - Rank 1 = lowest reward, rank N = highest. Ties get the same rank - (average of tied ranks), consistent with scipy.stats.rankdata. - """ - keys = list(self._nodes.keys()) - rewards = np.array([self._nodes[k].reward for k in keys], dtype=float) - - # argsort twice gives rank (0-indexed); add 1 to make 1-indexed - order = np.argsort(rewards) - ranks = np.empty_like(order, dtype=float) - ranks[order] = np.arange(1, len(rewards) + 1, dtype=float) - - # Handle ties: assign average rank to tied rewards. - # Use ranks[tied].mean() — not tied.mean()+1, which would use array - # indices instead of the already-assigned rank values. - # (simple O(N²) loop is fine for buffer sizes we care about) - for i, r in enumerate(rewards): - tied = np.where(rewards == r)[0] - if len(tied) > 1: - ranks[tied] = ranks[tied].mean() - - total = ranks.sum() - return {k: float(ranks[i] / total) for i, k in enumerate(keys)} - - def _propagate_Q(self, key: Any) -> None: - """Propagate max-Q upward from `key` to the root.""" - node = self._nodes[key] - if node.children_keys: - child_rewards = [ - self._nodes[ck].Q - for ck in node.children_keys - if ck in self._nodes - ] - new_Q = max(node.reward, max(child_rewards)) if child_rewards else node.reward - else: - new_Q = node.reward - - if new_Q == node.Q: - return # no change — stop propagation - - node.Q = new_Q - if node.parent_key is not None and node.parent_key in self._nodes: - self._propagate_Q(node.parent_key) - - def _ancestry(self, key: Any) -> list[tuple[Any, float]]: - """Return the path from root to `key` as [(state, reward), ...].""" - path = [] - cur = key - while cur is not None: - node = self._nodes[cur] - path.append((node.state, node.reward)) - cur = node.parent_key - path.reverse() - return path - - -# --------------------------------------------------------------------------- -# Unit tests -# --------------------------------------------------------------------------- - -def _run_tests() -> None: - import sys - - failures: list[str] = [] - - def check(name: str, cond: bool, msg: str = "") -> None: - if not cond: - failures.append(f"FAIL [{name}]: {msg}") - else: - print(f" PASS [{name}]") - - print("=== puct_buffer unit tests ===\n") - - # ------------------------------------------------------------------ - # Basic add / best - # ------------------------------------------------------------------ - print("-- add / best --") - - buf = PUCTBuffer(c=1.0) - buf.add("s0", 0.5) - buf.add("s1", 0.8) - buf.add("s2", 0.3) - - state, reward = buf.best() - check("best_returns_max_reward_state", reward == 0.8, f"reward={reward}") - check("best_returns_correct_state", state == "s1", f"state={state!r}") - check("len_after_adds", len(buf) == 3, f"len={len(buf)}") - - # Duplicate add is a no-op - buf.add("s0", 99.0) - check("duplicate_add_noop", len(buf) == 3, "duplicate changed buffer size") - check("duplicate_reward_unchanged", buf._nodes[_make_key("s0")].reward == 0.5) - - # ------------------------------------------------------------------ - # Q uses MAX not mean - # ------------------------------------------------------------------ - print("\n-- Q = MAX not mean --") - - buf2 = PUCTBuffer() - buf2.add("root", 0.0) - buf2.add("child_low", 0.1, parent_state="root") - buf2.add("child_high", 0.9, parent_state="root") - - root_node = buf2._nodes[_make_key("root")] - check( - "Q_is_max_not_mean", - root_node.Q == 0.9, - f"root.Q={root_node.Q}, expected 0.9 (max), mean would be 0.5", - ) - - # Add another child with even higher reward — Q should update - buf2.add("child_best", 0.95, parent_state="root") - check( - "Q_updates_when_better_child_added", - root_node.Q == 0.95, - f"root.Q={root_node.Q}, expected 0.95", - ) - - # ------------------------------------------------------------------ - # Q propagates through grandchildren (MAX of all descendants) - # ------------------------------------------------------------------ - print("\n-- Q propagation --") - - buf3 = PUCTBuffer() - buf3.add("r", 0.0) - buf3.add("c1", 0.3, parent_state="r") - buf3.add("gc", 0.99, parent_state="c1") # grandchild - - r_node = buf3._nodes[_make_key("r")] - c1_node = buf3._nodes[_make_key("c1")] - check("grandchild_Q_propagates_to_child", c1_node.Q == 0.99, f"c1.Q={c1_node.Q}") - check("grandchild_Q_propagates_to_root", r_node.Q == 0.99, f"r.Q={r_node.Q}") - - # Parent with high own reward should NOT lose Q when children underperform - buf3b = PUCTBuffer() - buf3b.add("great_parent", 0.9) - buf3b.add("weak_child", 0.2, parent_state="great_parent") - gp_node = buf3b._nodes[_make_key("great_parent")] - check( - "parent_Q_not_lowered_by_weak_child", - gp_node.Q == 0.9, - f"great_parent.Q={gp_node.Q}, expected 0.9 (own reward dominates)", - ) - - # ------------------------------------------------------------------ - # Rank priors: ties get correct average rank (not index-based) - # ------------------------------------------------------------------ - print("\n-- rank prior tie handling --") - - buf_ties = PUCTBuffer() - # rewards: s0=0.1 (rank 1), s1=0.5 (tied), s2=0.3 (rank 2), s3=0.5 (tied) - # After tie-averaging: s0→1, s2→2, s1&s3→(3+4)/2=3.5 - buf_ties.add("s0", 0.1) - buf_ties.add("s1", 0.5) - buf_ties.add("s2", 0.3) - buf_ties.add("s3", 0.5) - priors_ties = buf_ties._rank_priors() - p1 = priors_ties[_make_key("s1")] - p3 = priors_ties[_make_key("s3")] - p2 = priors_ties[_make_key("s2")] - check("tied_states_equal_prior", abs(p1 - p3) < 1e-9, f"p1={p1:.6f} p3={p3:.6f}") - check("tied_states_outrank_lower", p1 > p2, f"tied={p1:.4f} vs s2={p2:.4f}") - - # ------------------------------------------------------------------ - # update() convenience wrapper - # ------------------------------------------------------------------ - print("\n-- update() --") - - buf4 = PUCTBuffer() - buf4.add("p", 0.5) - buf4.update("p", "child_via_update", 0.7) - check("update_adds_child", len(buf4) == 2, f"len={len(buf4)}") - check("update_links_child", "child_via_update" in [ - buf4._nodes[ck].state for ck in buf4._nodes[_make_key("p")].children_keys - ]) - - # ------------------------------------------------------------------ - # Exploration: unvisited high-reward states get selected - # ------------------------------------------------------------------ - print("\n-- exploration: unvisited high-reward states --") - - buf5 = PUCTBuffer(c=1.0) - # Old state, visited many times - buf5.add("visited", 0.6) - buf5._nodes[_make_key("visited")].n = 100 - # New high-reward state, never visited - buf5.add("fresh_high", 0.9) - - selected = buf5.select(batch_size=2, num_groups=2) - selected_states = [s for s, _ in selected] - check( - "unvisited_high_reward_selected", - "fresh_high" in selected_states, - f"selected states: {selected_states}", - ) - - # ------------------------------------------------------------------ - # Exploitation: Q(parent) rises after adding a high-reward child, making - # the parent score higher than a sibling with no children. - # We verify PUCT scores directly — not via select() — because select() - # would correctly pick the child itself (even better warm-start). - # ------------------------------------------------------------------ - print("\n-- exploitation: high-Q parent outscores peer --") - - buf6 = PUCTBuffer(c=0.01) # low exploration → scores dominated by Q - buf6.add("peer_no_children", 0.5) - buf6.add("parent_explored", 0.5) - # Give parent_explored a great child: Q should propagate to 0.99 - buf6.add("great_child_2", 0.99, parent_state="parent_explored") - - priors6 = buf6._rank_priors() - pk_peer = _make_key("peer_no_children") - pk_parent = _make_key("parent_explored") - score_peer = buf6._puct_score(buf6._nodes[pk_peer], priors6[pk_peer]) - score_parent = buf6._puct_score(buf6._nodes[pk_parent], priors6[pk_parent]) - - check( - "parent_Q_raised_by_great_child", - buf6._nodes[pk_parent].Q == 0.99, - f"parent.Q={buf6._nodes[pk_parent].Q}", - ) - check( - "high_Q_parent_outscores_peer", - score_parent > score_peer, - f"score_parent={score_parent:.4f}, score_peer={score_peer:.4f}", - ) - - # ------------------------------------------------------------------ - # select() group structure - # ------------------------------------------------------------------ - print("\n-- select() group structure --") - - buf7 = PUCTBuffer() - for i in range(10): - buf7.add(f"s{i}", float(i) / 10) - - result = buf7.select(batch_size=16, num_groups=4) - check("select_total_length", len(result) == 16, f"len={len(result)}") - - # Each group of 4 should share the same state - groups_of_4 = [result[i*4:(i+1)*4] for i in range(4)] - for gi, group in enumerate(groups_of_4): - states_in_group = [s for s, _ in group] - check( - f"group_{gi}_same_state", - len(set(states_in_group)) == 1, - f"group {gi} has mixed states: {states_in_group}", - ) - - # Each group should have a DIFFERENT initial state from the others - group_states = [group[0][0] for group in groups_of_4] - check( - "groups_have_distinct_states", - len(set(group_states)) == 4, - f"group states: {group_states}", - ) - - # ------------------------------------------------------------------ - # select() raises on batch_size not divisible by num_groups - # ------------------------------------------------------------------ - print("\n-- select() error handling --") - - buf8 = PUCTBuffer() - buf8.add("x", 1.0) - try: - buf8.select(batch_size=7, num_groups=3) - check("indivisible_batch_raises", False, "should have raised ValueError") - except ValueError: - check("indivisible_batch_raises", True) - - # select() on empty buffer raises - buf_empty = PUCTBuffer() - try: - buf_empty.select(batch_size=4, num_groups=2) - check("empty_buffer_select_raises", False, "should have raised ValueError") - except ValueError: - check("empty_buffer_select_raises", True) - - # ------------------------------------------------------------------ - # Context (ancestry path) - # ------------------------------------------------------------------ - print("\n-- context / ancestry path --") - - buf9 = PUCTBuffer() - buf9.add("root", 0.1) - buf9.add("child", 0.5, parent_state="root") - buf9.add("grand", 0.9, parent_state="child") - - # Force select to pick "grand" by making it best by far - buf9._nodes[_make_key("grand")].reward = 10.0 - buf9._propagate_Q(_make_key("child")) - buf9._propagate_Q(_make_key("root")) - - result9 = buf9.select(batch_size=1, num_groups=1) - state9, context9 = result9[0] - check("context_is_list", isinstance(context9, list)) - check( - "context_starts_at_root", - context9[0][0] == "root", - f"context[0]={context9[0]}", - ) - check( - "context_ends_at_selected", - context9[-1][0] == state9, - f"context[-1]={context9[-1]}, state={state9!r}", - ) - check( - "context_length_equals_depth", - len(context9) == 3, - f"len={len(context9)}, expected 3", - ) - - # ------------------------------------------------------------------ - # Visit count increments on select - # ------------------------------------------------------------------ - print("\n-- visit count tracking --") - - buf10 = PUCTBuffer() - buf10.add("a", 0.5) - buf10.add("b", 0.6) - n_before_a = buf10._nodes[_make_key("a")].n - buf10.select(batch_size=4, num_groups=2) - T_after = buf10._T - check("T_incremented_by_num_groups", T_after == 2, f"T={T_after}") - total_n = sum(n.n for n in buf10._nodes.values()) - check("total_n_equals_T", total_n == T_after, f"sum(n)={total_n}, T={T_after}") - - # ------------------------------------------------------------------ - # numpy array states - # ------------------------------------------------------------------ - print("\n-- numpy array states --") - - buf11 = PUCTBuffer() - arr_a = np.array([0.1, 0.5, 0.4]) - arr_b = np.array([0.3, 0.3, 0.4]) - buf11.add(arr_a, 0.7) - buf11.add(arr_b, 0.9) - check("numpy_states_len", len(buf11) == 2, f"len={len(buf11)}") - best_s, best_r = buf11.best() - check("numpy_best_reward", best_r == 0.9, f"best_r={best_r}") - check("numpy_best_state", np.array_equal(best_s, arr_b), f"best_s={best_s}") - - # ------------------------------------------------------------------ - print() - if failures: - for f in failures: - print(f) - print(f"\n{len(failures)} test(s) FAILED") - import sys; sys.exit(1) - else: - print("All tests passed.") - - -if __name__ == "__main__": - _run_tests() diff --git a/shim/run_discover.py b/shim/run_discover.py deleted file mode 100644 index 6b2d7bc80a..0000000000 --- a/shim/run_discover.py +++ /dev/null @@ -1,349 +0,0 @@ -"""Run script for TTT-Discover GRPO training on the Erdős Minimum Overlap Problem. - -This follows the sliding_puzzle pattern: custom IterableDataset that generates -prompts dynamically from a PUCT buffer, wired into the standard GRPO loop. - -Usage: - # Start the Gym resource server first (separate process/node): - cd ~/Gym && ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" - - # Then run training: - cd ~/RL && uv run python examples/run_discover.py [--config examples/configs/grpo_erdos_discover.yaml] - -Reference: "Learning to Discover at Test Time" (arXiv:2601.16175) -""" - -import itertools -import argparse -import itertools -import logging -import os -import sys -from typing import Optional - -import aiohttp -import asyncio -import numpy as np -import ray -import torch -from torch.utils.data import IterableDataset - -from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup -from nemo_rl.algorithms.utils import get_tokenizer, set_seed -from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType -from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.environments.erdos_discovery_environment import ( - ErdosDiscoveryEnvironment, -) -from nemo_rl.models.generation import configure_generation_config -from nemo_rl.utils.config import load_config, register_omegaconf_resolvers - -logger = logging.getLogger(__name__) - -# ═══════════════════════════════════════════════════════════════════ -# Problem description (same as in the Gym resource server) -# ═══════════════════════════════════════════════════════════════════ - -PROBLEM_DESCRIPTION = """\ -Erdos Minimum Overlap Problem -============================== - -Goal: Find a step function f (Python list or NumPy array) giving the -tightest possible upper bound on the Erdos minimum overlap constant c. - -Background: - For integer n, partition {1,...,2n} into equal sets A, B. - M_k = #{(a,b) : a in A, b in B, a-b=k}. - c = lim_{n->inf} min_{A,B} max_k M_k / n. - -Known bounds: 0.379005 < c < 0.380927 (Haugland 2016) -Current best upper bound: 0.380876 (2026) - -Upper Bound via Step Functions: - f : [0,1] -> [0,1] with mean(f) = 0.5 gives: - bound = 2*n*max(autocorr(f)) / sum(f)^2 - where autocorr is computed via FFT. - Smaller bound -> higher reward (reward = 1/bound). - -Constraints: 1 <= len(f) <= 1000, 0 <= f[i] <= 1, mean(f) ~ 0.5 (tol 1e-3). - -Output: Python code defining variable `f` in a ```python block. -Allowed: numpy, math, random, itertools, functools, collections. -Execution limit: 600 seconds. Target: bound < 0.380876.\ -""" - - -# ═══════════════════════════════════════════════════════════════════ -# Datum generation -# ═══════════════════════════════════════════════════════════════════ - - -def generate_discover_datum( - tokenizer, - state_info: dict, - idx: int, - task_name: str = "erdos_discovery", -) -> DatumSpec: - """Create a DatumSpec from a PUCT-selected state. - - Args: - tokenizer: HuggingFace tokenizer. - state_info: Dict from /select_state with keys: - state, context, reward, system_prompt, user_prompt. - idx: Datum index. - task_name: Task name for env routing. - - Returns: - DatumSpec ready for the GRPO training loop. - """ - system_prompt = state_info.get("system_prompt", PROBLEM_DESCRIPTION) - user_prompt = state_info["user_prompt"] - - messages: LLMMessageLogType = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - # Tokenize the prompt - prompt_text = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) - prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long) - - # Attach token_ids to messages for NeMo RL's message_log format - for msg in messages: - msg_text = tokenizer.apply_chat_template( - [msg], tokenize=False, add_generation_prompt=False - ) - msg_ids = tokenizer.encode(msg_text, add_special_tokens=False) - msg["token_ids"] = torch.tensor(msg_ids, dtype=torch.long) - - return DatumSpec( - message_log=messages, - length=len(prompt_ids), - extra_env_info={ - "parent_state": state_info.get("state"), - "context": state_info.get("context"), - "reward": state_info.get("reward", 0.0), - }, - loss_multiplier=1.0, - idx=idx, - task_name=task_name, - ) - - -# ═══════════════════════════════════════════════════════════════════ -# Dynamic dataset backed by PUCT buffer -# ═══════════════════════════════════════════════════════════════════ - - -class DiscoverDataset(IterableDataset): - """Iterable dataset that fetches prompts from the PUCT buffer each step. - - Each iteration fetches `num_groups_per_step` states from the Gym resource - server's /select_state endpoint and yields them as DatumSpecs. - - The dataset loops indefinitely — the training loop controls termination - via max_num_steps in the GRPO config. - """ - - def __init__( - self, - tokenizer, - resource_server_url: str, - num_groups_per_step: int = 8, - task_name: str = "erdos_discovery", - length: int = 1000, # Nominal length for dataloader - ): - self.tokenizer = tokenizer - self.resource_server_url = resource_server_url - self.num_groups_per_step = num_groups_per_step - self.task_name = task_name - self.length = length - self._idx_counter = itertools.count() - - def _fetch_states_sync(self) -> list[dict]: - """Synchronously fetch states from the PUCT buffer.""" - import requests - - try: - resp = requests.post( - f"{self.resource_server_url}/select_state", - json={ - "batch_size": self.num_groups_per_step, - "num_groups": self.num_groups_per_step, - }, - timeout=30, - ) - resp.raise_for_status() - data = resp.json() - return data.get("states", []) - except Exception as e: - logger.error("Failed to fetch states from PUCT buffer: %s", e) - # Return fallback: single default prompt - return [ - { - "state": [0.5] * 50, - "context": [], - "reward": 0.5, - "system_prompt": PROBLEM_DESCRIPTION, - "user_prompt": ( - "Starting construction (bound=2.000000, 50 pieces):\n" - "[0.5000, 0.5000, ..., 0.5000]\n\n" - "Improve on this construction. Write Python code that " - "defines a better step function `f`. Think carefully." - ), - } - ] - - def __iter__(self): - for _ in itertools.count(): - states = self._fetch_states_sync() - for state_info in states: - idx = next(self._idx_counter) - yield generate_discover_datum( - self.tokenizer, - state_info, - idx=idx, - task_name=self.task_name, - ) - - def __len__(self): - return self.length - - -# ═══════════════════════════════════════════════════════════════════ -# Setup -# ═══════════════════════════════════════════════════════════════════ - - -def setup_discover_data(config: MasterConfig, tokenizer): - """Create dataset, environment, and wire them together. - - Returns: - (train_dataset, val_dataset, task_to_env, val_task_to_env) - """ - env_config = config.get("env", {}).get("erdos_discovery", {}) - resource_server_url = env_config.get( - "resource_server_url", "http://localhost:8080" - ) - num_groups_per_step = env_config.get("num_groups_per_step", 8) - task_name = "erdos_discovery" - - # Create the dynamic dataset - train_dataset = DiscoverDataset( - tokenizer=tokenizer, - resource_server_url=resource_server_url, - num_groups_per_step=num_groups_per_step, - task_name=task_name, - length=config["grpo"]["max_num_steps"] * num_groups_per_step, - ) - - # Validation dataset: same thing (could be a fixed set, but for discovery - # we just re-sample from the buffer) - val_dataset = DiscoverDataset( - tokenizer=tokenizer, - resource_server_url=resource_server_url, - num_groups_per_step=num_groups_per_step, - task_name=task_name, - length=num_groups_per_step, - ) - - # Create the environment as a Ray actor - env = ErdosDiscoveryEnvironment.options( - num_gpus=0, - max_restarts=-1, - max_task_retries=-1, - ).remote(config=env_config) - - task_to_env = {task_name: env} - val_task_to_env = {task_name: env} - - return train_dataset, val_dataset, task_to_env, val_task_to_env - - -# ═══════════════════════════════════════════════════════════════════ -# Main -# ═══════════════════════════════════════════════════════════════════ - - -def main(): - import os - from omegaconf import OmegaConf - from nemo_rl.utils.config import load_config, register_omegaconf_resolvers - - register_omegaconf_resolvers() - - # Parse --config argument - config_path = None - for i, arg in enumerate(sys.argv[1:], 1): - if arg.startswith("--config="): - config_path = arg.split("=", 1)[1] - elif arg == "--config" and i < len(sys.argv) - 1: - config_path = sys.argv[i + 1] - elif not arg.startswith("--") and config_path is None: - config_path = arg - - if config_path is None: - config_path = os.path.join( - os.path.dirname(__file__), "configs", "grpo_erdos_discover_debug.yaml" - ) - - print(f"Loading config from: {config_path}") - config = load_config(config_path) - - # Resolve OmegaConf interpolations (e.g. ${policy.model_name}) - oc = OmegaConf.create(config) - config = OmegaConf.to_container(oc, resolve=True) - - # Initialize Ray - init_ray() - set_seed(config.get("seed", 42)) - - # Tokenizer - tokenizer = get_tokenizer(config["policy"]["tokenizer"]) - - # Generation config - config["policy"]["generation"] = configure_generation_config( - config["policy"]["generation"], tokenizer - ) - - # Setup data + environment - train_dataset, val_dataset, task_to_env, val_task_to_env = ( - setup_discover_data(config, tokenizer) - ) - - # Setup policy, generation, cluster, dataloader, etc. - ( - policy, - policy_generation, - clusters, - dataloader, - val_dataloader, - loss_fn, - nemo_logger, - checkpointer, - grpo_state, - master_config, - ) = setup(config, tokenizer, train_dataset, val_dataset) - - # Run GRPO training - grpo_train( - policy, - policy_generation, - dataloader, - val_dataloader, - tokenizer, - loss_fn, - task_to_env, - val_task_to_env, - nemo_logger, - checkpointer, - grpo_state, - master_config, - ) - - -if __name__ == "__main__": - main() diff --git a/test_gptoss_vllm.sh b/test_gptoss_vllm.sh deleted file mode 100755 index 64010a9f07..0000000000 --- a/test_gptoss_vllm.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -set -euo pipefail -cd /home/mormio/RL - -CONTAINER="/home/shared/containers/nemo-rl-super-v3.sqsh" -MODEL="/home/shared/models/gpt-oss-120b-bf16" -MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" - -# Use uv run which activates the right venv with vLLM -COMMAND=" -cd /opt/nemo-rl -uv run python -c \" -from vllm import LLM -print('Attempting to load gpt-oss-120b...') -try: - llm = LLM( - model='$MODEL', - tensor_parallel_size=8, - trust_remote_code=True, - max_model_len=1024, - gpu_memory_utilization=0.5, - enforce_eager=True, - ) - print('SUCCESS: gpt-oss-120b loaded!') - out = llm.generate(['Hello world'], max_tokens=10) - print('Generated:', out[0].outputs[0].text) -except Exception as e: - print(f'FAILED: {type(e).__name__}: {e}') -\" -" - -COMMAND="$COMMAND" \ -CONTAINER="$CONTAINER" \ -MOUNTS="$MOUNTS" \ -GPUS_PER_NODE=8 \ -sbatch \ - --nodes=1 --partition=batch --exclusive \ - --job-name=test-gptoss --time=00:30:00 \ - --output=logs/test-gptoss-%j.out \ - --error=logs/test-gptoss-%j.err \ - --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ - ray.sub From 47d4b0b5acb7f0adb0b3d1d18ba0f865b3f103f4 Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 7 Apr 2026 19:51:20 +0000 Subject: [PATCH 47/48] cleanup: remove unused puct_buffer.py and ray.sub.bak --- nemo_rl/utils/puct_buffer.py | 561 ----------------------------------- ray.sub.bak | 487 ------------------------------ 2 files changed, 1048 deletions(-) delete mode 100644 nemo_rl/utils/puct_buffer.py delete mode 100644 ray.sub.bak diff --git a/nemo_rl/utils/puct_buffer.py b/nemo_rl/utils/puct_buffer.py deleted file mode 100644 index 53f808d783..0000000000 --- a/nemo_rl/utils/puct_buffer.py +++ /dev/null @@ -1,561 +0,0 @@ -""" -PUCT buffer for TTT-Discover state reuse. - -Reference: "Learning to Discover at Test Time" (arXiv:2601.04116) - -The buffer maintains a tree of (state, reward) nodes. At each training step, -PUCT scoring selects which states to warm-start rollouts from, balancing: - - Exploitation: states whose children have achieved high rewards (high Q) - - Exploration: states that haven't been visited much yet (low n) - -Pure data structure — no ML framework dependencies. -""" - -import math -import dataclasses -from typing import Any, Optional - -import numpy as np - - -# --------------------------------------------------------------------------- -# Internal node -# --------------------------------------------------------------------------- - -@dataclasses.dataclass -class _Node: - state: Any - reward: float # reward of THIS state (from its own evaluation) - parent_key: Any # key of parent node, or None for roots - children_keys: list # keys of direct children - n: int # visit count (number of times selected for expansion) - Q: float # max reward among all descendants (or own reward if leaf) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_key(state: Any) -> Any: - """Convert state to a hashable key. - - Supports: str, int, float, tuple, list, np.ndarray, and arbitrary objects - (fallback: id-based, so two different objects with equal content are - treated as distinct — acceptable for LLM response strings). - """ - if isinstance(state, (str, int, float, bool)): - return state - if isinstance(state, np.ndarray): - return (state.dtype, state.shape, state.tobytes()) - if isinstance(state, (list, tuple)): - return tuple(_make_key(x) for x in state) - # Fallback: identity-based key — wrap id so it doesn't collide with ints - return ("__id__", id(state)) - - -# --------------------------------------------------------------------------- -# PUCTBuffer -# --------------------------------------------------------------------------- - -class PUCTBuffer: - """ - Tree-structured buffer with PUCT selection. - - PUCT score for node s: - score(s) = Q(s) + c · P(s) · sqrt(1 + T) / (1 + n(s)) - - Where: - Q(s) = max reward among all descendants of s (own reward if leaf) - P(s) = rank-based prior: rank states by reward, normalize by total rank - n(s) = visit count of s - T = total visit count across all nodes - c = exploration constant (default 1.0) - """ - - def __init__(self, c: float = 1.0) -> None: - self.c = c - self._nodes: dict[Any, _Node] = {} # key → _Node - self._T: int = 0 # total expansions so far - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def add(self, state: Any, reward: float, parent_state: Any = None) -> None: - """Insert a new node into the buffer. - - If the state is already present, this is a no-op (deduplication). - If parent_state is given and present in the buffer, the new node is - linked as a child and Q values are propagated upward. - - Args: - state: The state to insert (any type with a consistent identity). - reward: Scalar reward associated with this state. - parent_state: Parent state, or None for a root node. - """ - key = _make_key(state) - if key in self._nodes: - return # already present — deduplicate - - parent_key = _make_key(parent_state) if parent_state is not None else None - node = _Node( - state=state, - reward=float(reward), - parent_key=parent_key, - children_keys=[], - n=0, - Q=float(reward), # leaf: Q = own reward - ) - self._nodes[key] = node - - if parent_key is not None and parent_key in self._nodes: - self._nodes[parent_key].children_keys.append(key) - self._propagate_Q(parent_key) - - def select( - self, batch_size: int, num_groups: int = 8 - ) -> list[tuple[Any, list]]: - """Select states to warm-start rollouts from. - - Scores each node with PUCT, picks the top `num_groups` distinct states, - and returns `batch_size` (state, context) pairs grouped so that each - group of `batch_size // num_groups` entries shares the same state. - - Context is the ancestry path from root to the selected node: - [(ancestor_state, ancestor_reward), ..., (selected_state, selected_reward)] - The env uses this to build the prompt (previous attempts / warm start). - - Visit counts are incremented for the selected nodes, and T is updated. - - Args: - batch_size: Total number of (state, context) pairs to return. - Must be divisible by num_groups. - num_groups: Number of distinct initial states to select. - - Returns: - List of (state, context) tuples, length == batch_size. - """ - if not self._nodes: - raise ValueError("Buffer is empty — call add() before select()") - if batch_size % num_groups != 0: - raise ValueError( - f"batch_size ({batch_size}) must be divisible by num_groups ({num_groups})" - ) - rollouts_per_group = batch_size // num_groups - - priors = self._rank_priors() - scores = { - key: self._puct_score(node, priors[key]) - for key, node in self._nodes.items() - } - - # Top num_groups keys by PUCT score (at most len(nodes) if buffer is small) - k = min(num_groups, len(self._nodes)) - top_keys = sorted(scores, key=lambda x: scores[x], reverse=True)[:k] - - result: list[tuple[Any, list]] = [] - for key in top_keys: - node = self._nodes[key] - context = self._ancestry(key) - pair = (node.state, context) - result.extend([pair] * rollouts_per_group) - # Increment visit count for this selection - node.n += 1 - self._T += 1 - - return result - - def update( - self, parent_state: Any, child_state: Any, reward: float - ) -> None: - """Add a child node and update Q values up the tree. - - Convenience wrapper around add() that makes the parent/child - relationship explicit. - - Args: - parent_state: The state that was selected and rolled out from. - child_state: The resulting new state produced by the rollout. - reward: Reward of the new child state. - """ - self.add(child_state, reward, parent_state=parent_state) - - def best(self) -> tuple[Any, float]: - """Return the (state, reward) with the highest reward ever seen. - - Returns: - (state, reward) tuple. - """ - if not self._nodes: - raise ValueError("Buffer is empty") - best_key = max(self._nodes, key=lambda k: self._nodes[k].reward) - node = self._nodes[best_key] - return node.state, node.reward - - def __len__(self) -> int: - return len(self._nodes) - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - def _puct_score(self, node: _Node, prior: float) -> float: - return node.Q + self.c * prior * math.sqrt(1 + self._T) / (1 + node.n) - - def _rank_priors(self) -> dict[Any, float]: - """Rank-based prior: rank by node reward, normalize by sum of ranks. - - Rank 1 = lowest reward, rank N = highest. Ties get the same rank - (average of tied ranks), consistent with scipy.stats.rankdata. - """ - keys = list(self._nodes.keys()) - rewards = np.array([self._nodes[k].reward for k in keys], dtype=float) - - # argsort twice gives rank (0-indexed); add 1 to make 1-indexed - order = np.argsort(rewards) - ranks = np.empty_like(order, dtype=float) - ranks[order] = np.arange(1, len(rewards) + 1, dtype=float) - - # Handle ties: assign average rank to tied rewards. - # Use ranks[tied].mean() — not tied.mean()+1, which would use array - # indices instead of the already-assigned rank values. - # (simple O(N²) loop is fine for buffer sizes we care about) - for i, r in enumerate(rewards): - tied = np.where(rewards == r)[0] - if len(tied) > 1: - ranks[tied] = ranks[tied].mean() - - total = ranks.sum() - return {k: float(ranks[i] / total) for i, k in enumerate(keys)} - - def _propagate_Q(self, key: Any) -> None: - """Propagate max-Q upward from `key` to the root.""" - node = self._nodes[key] - if node.children_keys: - child_rewards = [ - self._nodes[ck].Q - for ck in node.children_keys - if ck in self._nodes - ] - new_Q = max(node.reward, max(child_rewards)) if child_rewards else node.reward - else: - new_Q = node.reward - - if new_Q == node.Q: - return # no change — stop propagation - - node.Q = new_Q - if node.parent_key is not None and node.parent_key in self._nodes: - self._propagate_Q(node.parent_key) - - def _ancestry(self, key: Any) -> list[tuple[Any, float]]: - """Return the path from root to `key` as [(state, reward), ...].""" - path = [] - cur = key - while cur is not None: - node = self._nodes[cur] - path.append((node.state, node.reward)) - cur = node.parent_key - path.reverse() - return path - - -# --------------------------------------------------------------------------- -# Unit tests -# --------------------------------------------------------------------------- - -def _run_tests() -> None: - import sys - - failures: list[str] = [] - - def check(name: str, cond: bool, msg: str = "") -> None: - if not cond: - failures.append(f"FAIL [{name}]: {msg}") - else: - print(f" PASS [{name}]") - - print("=== puct_buffer unit tests ===\n") - - # ------------------------------------------------------------------ - # Basic add / best - # ------------------------------------------------------------------ - print("-- add / best --") - - buf = PUCTBuffer(c=1.0) - buf.add("s0", 0.5) - buf.add("s1", 0.8) - buf.add("s2", 0.3) - - state, reward = buf.best() - check("best_returns_max_reward_state", reward == 0.8, f"reward={reward}") - check("best_returns_correct_state", state == "s1", f"state={state!r}") - check("len_after_adds", len(buf) == 3, f"len={len(buf)}") - - # Duplicate add is a no-op - buf.add("s0", 99.0) - check("duplicate_add_noop", len(buf) == 3, "duplicate changed buffer size") - check("duplicate_reward_unchanged", buf._nodes[_make_key("s0")].reward == 0.5) - - # ------------------------------------------------------------------ - # Q uses MAX not mean - # ------------------------------------------------------------------ - print("\n-- Q = MAX not mean --") - - buf2 = PUCTBuffer() - buf2.add("root", 0.0) - buf2.add("child_low", 0.1, parent_state="root") - buf2.add("child_high", 0.9, parent_state="root") - - root_node = buf2._nodes[_make_key("root")] - check( - "Q_is_max_not_mean", - root_node.Q == 0.9, - f"root.Q={root_node.Q}, expected 0.9 (max), mean would be 0.5", - ) - - # Add another child with even higher reward — Q should update - buf2.add("child_best", 0.95, parent_state="root") - check( - "Q_updates_when_better_child_added", - root_node.Q == 0.95, - f"root.Q={root_node.Q}, expected 0.95", - ) - - # ------------------------------------------------------------------ - # Q propagates through grandchildren (MAX of all descendants) - # ------------------------------------------------------------------ - print("\n-- Q propagation --") - - buf3 = PUCTBuffer() - buf3.add("r", 0.0) - buf3.add("c1", 0.3, parent_state="r") - buf3.add("gc", 0.99, parent_state="c1") # grandchild - - r_node = buf3._nodes[_make_key("r")] - c1_node = buf3._nodes[_make_key("c1")] - check("grandchild_Q_propagates_to_child", c1_node.Q == 0.99, f"c1.Q={c1_node.Q}") - check("grandchild_Q_propagates_to_root", r_node.Q == 0.99, f"r.Q={r_node.Q}") - - # Parent with high own reward should NOT lose Q when children underperform - buf3b = PUCTBuffer() - buf3b.add("great_parent", 0.9) - buf3b.add("weak_child", 0.2, parent_state="great_parent") - gp_node = buf3b._nodes[_make_key("great_parent")] - check( - "parent_Q_not_lowered_by_weak_child", - gp_node.Q == 0.9, - f"great_parent.Q={gp_node.Q}, expected 0.9 (own reward dominates)", - ) - - # ------------------------------------------------------------------ - # Rank priors: ties get correct average rank (not index-based) - # ------------------------------------------------------------------ - print("\n-- rank prior tie handling --") - - buf_ties = PUCTBuffer() - # rewards: s0=0.1 (rank 1), s1=0.5 (tied), s2=0.3 (rank 2), s3=0.5 (tied) - # After tie-averaging: s0→1, s2→2, s1&s3→(3+4)/2=3.5 - buf_ties.add("s0", 0.1) - buf_ties.add("s1", 0.5) - buf_ties.add("s2", 0.3) - buf_ties.add("s3", 0.5) - priors_ties = buf_ties._rank_priors() - p1 = priors_ties[_make_key("s1")] - p3 = priors_ties[_make_key("s3")] - p2 = priors_ties[_make_key("s2")] - check("tied_states_equal_prior", abs(p1 - p3) < 1e-9, f"p1={p1:.6f} p3={p3:.6f}") - check("tied_states_outrank_lower", p1 > p2, f"tied={p1:.4f} vs s2={p2:.4f}") - - # ------------------------------------------------------------------ - # update() convenience wrapper - # ------------------------------------------------------------------ - print("\n-- update() --") - - buf4 = PUCTBuffer() - buf4.add("p", 0.5) - buf4.update("p", "child_via_update", 0.7) - check("update_adds_child", len(buf4) == 2, f"len={len(buf4)}") - check("update_links_child", "child_via_update" in [ - buf4._nodes[ck].state for ck in buf4._nodes[_make_key("p")].children_keys - ]) - - # ------------------------------------------------------------------ - # Exploration: unvisited high-reward states get selected - # ------------------------------------------------------------------ - print("\n-- exploration: unvisited high-reward states --") - - buf5 = PUCTBuffer(c=1.0) - # Old state, visited many times - buf5.add("visited", 0.6) - buf5._nodes[_make_key("visited")].n = 100 - # New high-reward state, never visited - buf5.add("fresh_high", 0.9) - - selected = buf5.select(batch_size=2, num_groups=2) - selected_states = [s for s, _ in selected] - check( - "unvisited_high_reward_selected", - "fresh_high" in selected_states, - f"selected states: {selected_states}", - ) - - # ------------------------------------------------------------------ - # Exploitation: Q(parent) rises after adding a high-reward child, making - # the parent score higher than a sibling with no children. - # We verify PUCT scores directly — not via select() — because select() - # would correctly pick the child itself (even better warm-start). - # ------------------------------------------------------------------ - print("\n-- exploitation: high-Q parent outscores peer --") - - buf6 = PUCTBuffer(c=0.01) # low exploration → scores dominated by Q - buf6.add("peer_no_children", 0.5) - buf6.add("parent_explored", 0.5) - # Give parent_explored a great child: Q should propagate to 0.99 - buf6.add("great_child_2", 0.99, parent_state="parent_explored") - - priors6 = buf6._rank_priors() - pk_peer = _make_key("peer_no_children") - pk_parent = _make_key("parent_explored") - score_peer = buf6._puct_score(buf6._nodes[pk_peer], priors6[pk_peer]) - score_parent = buf6._puct_score(buf6._nodes[pk_parent], priors6[pk_parent]) - - check( - "parent_Q_raised_by_great_child", - buf6._nodes[pk_parent].Q == 0.99, - f"parent.Q={buf6._nodes[pk_parent].Q}", - ) - check( - "high_Q_parent_outscores_peer", - score_parent > score_peer, - f"score_parent={score_parent:.4f}, score_peer={score_peer:.4f}", - ) - - # ------------------------------------------------------------------ - # select() group structure - # ------------------------------------------------------------------ - print("\n-- select() group structure --") - - buf7 = PUCTBuffer() - for i in range(10): - buf7.add(f"s{i}", float(i) / 10) - - result = buf7.select(batch_size=16, num_groups=4) - check("select_total_length", len(result) == 16, f"len={len(result)}") - - # Each group of 4 should share the same state - groups_of_4 = [result[i*4:(i+1)*4] for i in range(4)] - for gi, group in enumerate(groups_of_4): - states_in_group = [s for s, _ in group] - check( - f"group_{gi}_same_state", - len(set(states_in_group)) == 1, - f"group {gi} has mixed states: {states_in_group}", - ) - - # Each group should have a DIFFERENT initial state from the others - group_states = [group[0][0] for group in groups_of_4] - check( - "groups_have_distinct_states", - len(set(group_states)) == 4, - f"group states: {group_states}", - ) - - # ------------------------------------------------------------------ - # select() raises on batch_size not divisible by num_groups - # ------------------------------------------------------------------ - print("\n-- select() error handling --") - - buf8 = PUCTBuffer() - buf8.add("x", 1.0) - try: - buf8.select(batch_size=7, num_groups=3) - check("indivisible_batch_raises", False, "should have raised ValueError") - except ValueError: - check("indivisible_batch_raises", True) - - # select() on empty buffer raises - buf_empty = PUCTBuffer() - try: - buf_empty.select(batch_size=4, num_groups=2) - check("empty_buffer_select_raises", False, "should have raised ValueError") - except ValueError: - check("empty_buffer_select_raises", True) - - # ------------------------------------------------------------------ - # Context (ancestry path) - # ------------------------------------------------------------------ - print("\n-- context / ancestry path --") - - buf9 = PUCTBuffer() - buf9.add("root", 0.1) - buf9.add("child", 0.5, parent_state="root") - buf9.add("grand", 0.9, parent_state="child") - - # Force select to pick "grand" by making it best by far - buf9._nodes[_make_key("grand")].reward = 10.0 - buf9._propagate_Q(_make_key("child")) - buf9._propagate_Q(_make_key("root")) - - result9 = buf9.select(batch_size=1, num_groups=1) - state9, context9 = result9[0] - check("context_is_list", isinstance(context9, list)) - check( - "context_starts_at_root", - context9[0][0] == "root", - f"context[0]={context9[0]}", - ) - check( - "context_ends_at_selected", - context9[-1][0] == state9, - f"context[-1]={context9[-1]}, state={state9!r}", - ) - check( - "context_length_equals_depth", - len(context9) == 3, - f"len={len(context9)}, expected 3", - ) - - # ------------------------------------------------------------------ - # Visit count increments on select - # ------------------------------------------------------------------ - print("\n-- visit count tracking --") - - buf10 = PUCTBuffer() - buf10.add("a", 0.5) - buf10.add("b", 0.6) - n_before_a = buf10._nodes[_make_key("a")].n - buf10.select(batch_size=4, num_groups=2) - T_after = buf10._T - check("T_incremented_by_num_groups", T_after == 2, f"T={T_after}") - total_n = sum(n.n for n in buf10._nodes.values()) - check("total_n_equals_T", total_n == T_after, f"sum(n)={total_n}, T={T_after}") - - # ------------------------------------------------------------------ - # numpy array states - # ------------------------------------------------------------------ - print("\n-- numpy array states --") - - buf11 = PUCTBuffer() - arr_a = np.array([0.1, 0.5, 0.4]) - arr_b = np.array([0.3, 0.3, 0.4]) - buf11.add(arr_a, 0.7) - buf11.add(arr_b, 0.9) - check("numpy_states_len", len(buf11) == 2, f"len={len(buf11)}") - best_s, best_r = buf11.best() - check("numpy_best_reward", best_r == 0.9, f"best_r={best_r}") - check("numpy_best_state", np.array_equal(best_s, arr_b), f"best_s={best_s}") - - # ------------------------------------------------------------------ - print() - if failures: - for f in failures: - print(f) - print(f"\n{len(failures)} test(s) FAILED") - import sys; sys.exit(1) - else: - print("All tests passed.") - - -if __name__ == "__main__": - _run_tests() diff --git a/ray.sub.bak b/ray.sub.bak deleted file mode 100644 index e6e3e07af7..0000000000 --- a/ray.sub.bak +++ /dev/null @@ -1,487 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=2 -#SBATCH --exclusive -#SBATCH --account=ACCOUNT -#SBATCH --job-name=JOB_NAME -#SBATCH --partition=PARTITION -#SBATCH --time=1:0:0 -#SBATCH --dependency=singleton - -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -set -eoux pipefail - -######################################################## -# Function to detect if SLURM cluster uses GRES -######################################################## -maybe_gres_arg() { - # Check if any nodes in the partition have GRES configured - # Assumes a homogeneous allocation (not a heterogeneous job) - if sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep -q "gpu:"; then - # Do a quick assert here that gpus:8 == gpus:$GPUS_PER_NODE. It is probably a user error if someone isn't using GPUS_PER_NODE=8 on our clusters if it supports --gres=gpu:8 or gpu:a100:8 - if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | awk -F: '{print $NF}') ]]; then - echo "Error: GPUS_PER_NODE=$GPUS_PER_NODE but GRES detected is $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:") meaning GPUS_PER_NODE is not set to fully claim the GPUs on the nodes." >&2 - exit 1 - fi - echo "--gres=gpu:${GPUS_PER_NODE}" - return - fi - - # No GRES support detected - echo "" -} - -######################################################## -# User defined variables -######################################################## -CONTAINER=$CONTAINER -MOUNTS=$MOUNTS -COMMAND=${COMMAND:-} # This is a script relative to the SLURM_SUBMIT_DIR. If left empty, it will leave the cluster idle after it's brought up. -######################################################## -# Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports -NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} -OBJECT_MANAGER_PORT=${OBJECT_MANAGER_PORT:-53003} -RUNTIME_ENV_AGENT_PORT=${RUNTIME_ENV_AGENT_PORT:-53005} -DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} -METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} - -# Ports for the head node -PORT=${PORT:-54514} -RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} -#REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? -DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger -DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} -RAY_DEBUGGER_ARGS= -if [ "${RAY_DEBUG:-}" = "legacy" ]; then - RAY_DEBUGGER_ARGS="--ray-debugger-external" -fi - -# After ray>=2.47, this feature is enabled by default which creates uv venvs for any py_executable starting with `uv run`. -# There is severe contention and performance issues with this enabled considering our dependencies are so large and occasionally -# need to be compiled, so NeMo RL has an implementation in nemo_rl/utils/venv.py that does it once per node as opposed to once per task. -export RAY_ENABLE_UV_RUN_RUNTIME_ENV=0 - -# Setting ulimit is recommended by ray best practices page -# @ https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html -# It's session based and won't affect the system outside the script -# Ensure that the soft limit isn't above the hard limit -if [[ $(ulimit -Hn) == "unlimited" ]] || [[ 65535 -lt $(ulimit -Hn) ]]; then - ulimit -Sn 65535 -elif [[ $(ulimit -Hn) != "unlimited" ]] && [[ $(ulimit -Hn) -lt 65535 ]]; then - echo "[WARNING]: Cannot increase ulimit on file descriptors to 65535 according ray recommendation: https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html. Speak to cluster admins to increase, otherwise ray may crash unexpectedly." -fi - -# On our clusters, the largest port range on an idle worker appeared between 52369-64607 -# (not including the other ports set by this script). So this range is chosen to be -# somewhere in the middle -MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} -MAX_WORKER_PORT=${MAX_WORKER_PORT:-54513} -######################################################## -# Number seconds to sync logs from /tmp/ray/session_*/logs to $LOG_DIR/ray/ -RAY_LOG_SYNC_FREQUENCY=${RAY_LOG_SYNC_FREQUENCY:-} -######################################################## - -# Unset UV_CACHE_DIR to avoid local cache directory interferring with the container cache -unset UV_CACHE_DIR - -if [[ -n "${UV_CACHE_DIR_OVERRIDE:-}" ]]; then - mkdir -p "$UV_CACHE_DIR_OVERRIDE" - if [[ -n $MOUNTS ]]; then - MOUNTS+=",$UV_CACHE_DIR_OVERRIDE:/root/.cache/uv" - else - MOUNTS="$UV_CACHE_DIR_OVERRIDE:/root/.cache/uv" - fi -fi - -# Create logs directory -BASE_LOG_DIR=${BASE_LOG_DIR:-$SLURM_SUBMIT_DIR} -LOG_DIR="$BASE_LOG_DIR/$SLURM_JOB_ID-logs" -mkdir -p $LOG_DIR - -# Number of GPUs per worker node -GPUS_PER_NODE=${GPUS_PER_NODE:-8} - -# Detect GRES support and set GRES_ARG -GRES_ARG=$(maybe_gres_arg) -if [[ -n "$GRES_ARG" ]]; then - echo "[INFO] GRES support detected. Using: $GRES_ARG" -else - echo "[INFO] No GRES support detected. Running without --gres flag." -fi - -COMMON_SRUN_ARGS="$GRES_ARG" -COMMON_SRUN_ARGS+=" --no-container-mount-home" -COMMON_SRUN_ARGS+=" --mpi=pmix" -COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS" -COMMON_SRUN_ARGS+=" --container-image=$CONTAINER" -COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" -# TODO: delete these (just for debugging) -COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION" -COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT" -# Number of CPUs per worker node -CPUS_PER_WORKER=${CPUS_PER_WORKER:-$((GPUS_PER_NODE * 16))} - -num_retries=3 - -# Track backgrounded srun client PIDs for head and workers -declare -A SRUN_PIDS - -# Verify all backgrounded srun client processes are still alive; exit fast if any died -check_srun_processes() { - for name in "${!SRUN_PIDS[@]}"; do - pid="${SRUN_PIDS[$name]}" - # Check if the process is still running - if ! kill -0 "$pid" 2>/dev/null; then - echo "[ERROR] Background srun '$name' died (pid=$pid). Could be a failure in startup or an issue with the node preventing the srun to start. Attempting to exit." >&2 - # Signal sidecars inside containers to terminate ASAP - touch "$LOG_DIR/ENDED" - exit 1 - fi - done -} - -# Getting the node names and IP addresses in the SLURM allocation -nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") -nodes_array=($nodes) -ip_addresses_array=() - -for node in $nodes; do - # Try multiple methods to get IP address - ENHANCED VERSION v2.0 - echo "[DEBUG] Resolving hostname: $node using enhanced resolution methods" - ip_address="" - - # Method 1: Try host command - echo "[DEBUG] Method 1: host command" - ip_address=$(host $node 2>/dev/null | awk '/has address/ { print $4 }' | head -1 || true) - echo "[DEBUG] host result: '$ip_address'" - - # Method 2: If host fails, try getent - if [[ -z "$ip_address" ]]; then - echo "[DEBUG] Method 2: getent hosts" - ip_address=$(getent hosts $node 2>/dev/null | awk '{ print $1 }' | head -1 || true) - echo "[DEBUG] getent result: '$ip_address'" - fi - - # Method 3: If getent fails, try nslookup - if [[ -z "$ip_address" ]]; then - echo "[DEBUG] Method 3: nslookup" - ip_address=$(nslookup $node 2>/dev/null | awk '/^Address: / { print $2 }' | head -1 || true) - echo "[DEBUG] nslookup result: '$ip_address'" - fi - - # Method 4: If all DNS methods fail, try ping to extract IP - if [[ -z "$ip_address" ]]; then - echo "[DEBUG] Method 4: ping" - ip_address=$(ping -c 1 $node 2>/dev/null | grep "PING" | sed 's/.*(\([^)]*\)).*/\1/' || true) - echo "[DEBUG] ping result: '$ip_address'" - fi - - # If still no IP, use the hostname itself (might work if it's already an IP or resolvable) - if [[ -z "$ip_address" ]]; then - echo "[WARNING] Could not resolve IP for $node, using hostname as fallback" - ip_address=$node - fi - - echo "[INFO] Node: $node -> IP: $ip_address" - # Add the IP address to the array - ip_addresses_array+=("$ip_address") -done - -head_node=${nodes_array[0]} -head_node_ip=${ip_addresses_array[0]} - -ip_head=$head_node_ip:$PORT - -# First we start the head of the ray cluster on one of the physical nodes -# Give the head node actual resources to make it schedulable - -head_cmd=$(cat < /dev/null 2>&1; then - for session_dir in /tmp/ray/session_[0-9]*/; do - if [[ -d "\$session_dir/logs" ]]; then - session_name=\$(basename "\$session_dir") - mkdir -p "$LOG_DIR/ray/\$session_name" - if command -v rsync > /dev/null 2>&1; then - rsync -ahP "\$session_dir/logs/" "$LOG_DIR/ray/\$session_name/logs/" 2>/dev/null || true - else - cp -r "\$session_dir/logs" "$LOG_DIR/ray/\$session_name/" - fi - fi - done - fi - if [[ -f "$LOG_DIR/ENDED" ]]; then - echo "Log sync sidecar terminating..." - break - fi - done -} -log-sync-sidecar & - -# Patch nsight.py before starting Ray head -sed -i 's/context\.py_executable = " "\.join(self\.nsight_cmd) + " python"/context.py_executable = " ".join(self.nsight_cmd) + f" {context.py_executable}"/g' /opt/nemo_rl_venv/lib64/python*/site-packages/ray/_private/runtime_env/nsight.py - -cat < /dev/null 2>&1; then - for session_dir in /tmp/ray/session_[0-9]*/; do - if [[ -d "\$session_dir/logs" ]]; then - session_name=\$(basename "\$session_dir") - mkdir -p "$LOG_DIR/ray/$node_i/\$session_name" - if command -v rsync > /dev/null 2>&1; then - rsync -ahP "\$session_dir/logs/" $LOG_DIR/ray/$node_i/\$session_name/logs/ 2>/dev/null || true - else - cp -r "\$session_dir/logs" $LOG_DIR/ray/$node_i/\$session_name/ - fi - fi - done - fi - if [[ -f "$LOG_DIR/ENDED" ]]; then - echo "Log sync sidecar terminating..." - break - fi - done -} -log-sync-sidecar & - -# Patch nsight.py before starting Ray worker -sed -i 's/context\.py_executable = " "\.join(self\.nsight_cmd) + " python"/context.py_executable = " ".join(self.nsight_cmd) + f" {context.py_executable}"/g' /opt/nemo_rl_venv/lib64/python*/site-packages/ray/_private/runtime_env/nsight.py - -cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh -# No args launches on the head node (node 0) -# Args 1-N launch on worker nodes (nodes 1 through N-1) -# Optional: set COMMAND='...' to run non-interactively instead of opening an interactive shell -WORKER_NUM=\${1:-} -if [[ -z "\$WORKER_NUM" ]]; then - # Empty means we are on the head node - if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" - else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash - fi -else - # Worker numbers 1 through N-1 correspond to ray-worker-1 through ray-worker-(N-1) - # and use nodes_array[1] through nodes_array[N-1] - if [[ \$WORKER_NUM -lt 1 || \$WORKER_NUM -ge $SLURM_JOB_NUM_NODES ]]; then - echo "Error: WORKER_NUM must be between 1 and $((SLURM_JOB_NUM_NODES-1))" - exit 1 - fi - nodes_array=($nodes) - if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" - else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash - fi -fi -EOF - chmod +x $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh - echo " COMMAND='echo hello' bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # run a non-interactive command on head node" - echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # to attach to head node (i.e., 'worker 0')" - echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 1 # to attach to worker 1" - echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 2 # to attach to worker 2, etc." - sleep infinity -fi From a0de9218c39a199ad669c26deff630e661bbb0eb Mon Sep 17 00:00:00 2001 From: Morgane Moss Date: Tue, 7 Apr 2026 19:55:53 +0000 Subject: [PATCH 48/48] restore Gym submodule to match main (avoid merge conflict) --- 3rdparty/Gym-workspace/Gym | 1 + 1 file changed, 1 insertion(+) create mode 160000 3rdparty/Gym-workspace/Gym diff --git a/3rdparty/Gym-workspace/Gym b/3rdparty/Gym-workspace/Gym new file mode 160000 index 0000000000..1a4912e231 --- /dev/null +++ b/3rdparty/Gym-workspace/Gym @@ -0,0 +1 @@ +Subproject commit 1a4912e231bb2795b062f7de97496caaf382c7f6