diff --git a/3rdparty/Gym-workspace/Gym b/3rdparty/Gym-workspace/Gym index 23cdeb3807..1a4912e231 160000 --- a/3rdparty/Gym-workspace/Gym +++ b/3rdparty/Gym-workspace/Gym @@ -1 +1 @@ -Subproject commit 23cdeb38077d7b72a5fbae0927a2e1a74bfc15f7 +Subproject commit 1a4912e231bb2795b062f7de97496caaf382c7f6 diff --git a/README.md b/README.md index f90091e5af..e6431e6136 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,8 @@ For detailed information on backend selection, configuration, and examples, see - ✅ **Environment Support and Isolation** - Support for multi-environment training and dependency isolation between components. - ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state). - ✅ **Learning Algorithms** - GRPO/GSPO/DAPO, SFT(with LoRA), DPO, and On-policy distillation. +- ✅ **Advantage Estimators** - Group-relative (GRPO), multi-reward (GDPO), Reinforce++, and [Entropic Adaptive-β](nemo_rl/algorithms/entropic_advantage_estimator.py) (LOO entropic weighting from [TTT-Discover](https://arxiv.org/abs/2601.16175)). +- ✅ **PUCT Buffer** - [Tree-structured state selection](nemo_rl/utils/puct_buffer.py) for iterative optimization environments (exploration/exploitation via Upper Confidence bounds). - ✅ **Multi-Turn RL** - Multi-turn generation and training for RL with tool use, games, etc. - ✅ **Advanced Parallelism with DTensor** - PyTorch FSDP2, TP, CP, and SP for efficient training (through NeMo AutoModel). - ✅ **Larger Model Support with Longer Sequences** - Performant parallelisms with Megatron Core (TP/PP/CP/SP/EP/FSDP) (through NeMo Megatron Bridge). 
diff --git a/examples/configs/grpo_erdos_discover.yaml b/examples/configs/grpo_erdos_discover.yaml new file mode 100644 index 0000000000..1d51b4cde3 --- /dev/null +++ b/examples/configs/grpo_erdos_discover.yaml @@ -0,0 +1,98 @@ +# TTT-Discover Erdős — Nemotron-3-Super-120B, 16k seq, 8 nodes, CP=2 +defaults: "grpo_superv3.yaml" + +grpo: + num_prompts_per_step: 8 + num_generations_per_prompt: 63 + max_num_steps: 50 + max_rollout_turns: 1 + remove_constant_reward_groups: true + val_period: 0 + val_at_start: false + val_at_end: false + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931471805599453 + +loss_fn: + kl_penalty_coef: 0.1 + ratio_clip: 0.2 + token_level_loss: false + +policy: + model_name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" + tokenizer: + name: "/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" + chat_template_kwargs: null + max_total_sequence_length: 16384 + train_global_batch_size: 504 + train_micro_batch_size: 1 + logprob_batch_size: 1 + + generation: + colocated: + enabled: false + resources: + num_nodes: 2 + gpus_per_node: 8 + max_new_tokens: 15360 + vllm_cfg: + async_engine: false + tensor_parallel_size: 8 + gpu_memory_utilization: 0.85 + max_model_len: 16384 + + megatron_cfg: + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + context_parallel_size: 2 + expert_model_parallel_size: 8 + sequence_parallel: true + activation_checkpointing: true + empty_unused_memory_level: 2 + optimizer_cpu_offload: true + optimizer: + optimizer_cpu_offload: true + optimizer_offload_fraction: 1.0 + + dynamic_batching: + enabled: false + + lora_cfg: + enabled: false + +optimizer: + lr: 4.0e-5 + +data: + shuffle: false + max_input_seq_length: 16384 + +env: + erdos_discovery: + num_initial_states: 8 # matches num_prompts_per_step + puct_seed_batch_size: 8 # matches num_prompts_per_step + sandbox_timeout: 1000 + should_use_nemo_gym: false + +cluster: + gpus_per_node: 8 + num_nodes: 8 + +logger: + log_dir: 
"results/erdos-120b-16k" + wandb_enabled: true + wandb: + project: "ttt-discover-erdos" + name: "nemotron-120b-16k-8node-puct" + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false + +checkpointing: + enabled: false + checkpoint_dir: "results/erdos-120b-16k" + save_period: 999999 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false diff --git a/examples/run_discover.py b/examples/run_discover.py new file mode 100644 index 0000000000..34c39c011a --- /dev/null +++ b/examples/run_discover.py @@ -0,0 +1,361 @@ +"""Run script for TTT-Discover GRPO training on the Erdős Minimum Overlap Problem. + +Matches the reference implementation at: + https://github.com/test-time-training/discover/blob/main/examples/erdos_min_overlap/env.py + +Usage (inside NeMo RL container): + python examples/run_discover.py --config examples/configs/grpo_erdos_discover.yaml +""" + +import itertools +import logging +import os +import sys +from typing import Optional + +import numpy as np +import ray +import torch +from torch.utils.data import IterableDataset + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer, set_seed +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.erdos_discovery_environment import ( + ErdosDiscoveryEnvironment, + build_erdos_question, + create_initial_state, +) +from nemo_rl.models.generation import configure_generation_config + +logger = logging.getLogger(__name__) + + +# ═══════════════════════════════════════════════════════════════════ +# Datum generation +# ═══════════════════════════════════════════════════════════════════ + + +def generate_discover_datum( + tokenizer, + state: dict, + idx: int, + task_name: str = "erdos_discovery", +) -> DatumSpec: + """Create a DatumSpec from a state dict. 
+ + The prompt is built using the reference TTT-Discover get_question() format. + """ + user_prompt = build_erdos_question(state) + + messages: LLMMessageLogType = [ + {"role": "user", "content": user_prompt}, + ] + + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + + for msg in messages: + msg_text = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + msg_ids = tokenizer.encode(msg_text, add_special_tokens=False) + msg["token_ids"] = torch.tensor(msg_ids, dtype=torch.long) + + extra = { + "construction": state.get("construction"), + "c5_bound": state.get("c5_bound"), + "n_points": state.get("n_points"), + "code": state.get("code", ""), + "parent_c5": state.get("parent_c5"), + "observation": state.get("observation", ""), + } + if state.get("erdos_ref_state") is not None: + extra["erdos_ref_state"] = state["erdos_ref_state"] + + return DatumSpec( + message_log=messages, + length=len(prompt_ids), + extra_env_info=extra, + loss_multiplier=1.0, + idx=idx, + task_name=task_name, + ) + + +# ═══════════════════════════════════════════════════════════════════ +# Datasets: PUCT (train) vs random (val) +# ═══════════════════════════════════════════════════════════════════ + + +class PUCTDiscoverDataset(IterableDataset): + """Training dataset: pulls PUCT-selected states from the Ray env actor. + + Matches ttt-discover-ref: ``ErdosRefPUCTSampler.sample_states(num_prompts)`` + (distinct parent states per step). The env updates the sampler in + ``ErdosDiscoveryEnvironment._sync_step`` via ``update_states`` / + ``record_failed_rollout`` and ``flush``. Requires ``data.num_workers == 0``. 
+ """ + + def __init__( + self, + tokenizer, + env_actor, + num_prompts_per_step: int, + task_name: str = "erdos_discovery", + length: int = 1000, + ): + self.tokenizer = tokenizer + self._env = env_actor + self.num_prompts_per_step = num_prompts_per_step + self.task_name = task_name + self.length = length + self._idx_counter = itertools.count() + + def __iter__(self): + for _ in itertools.count(): + states = ray.get( + self._env.puct_sample_states.remote( + self.num_prompts_per_step, + ) + ) + for state in states: + idx = next(self._idx_counter) + yield generate_discover_datum( + self.tokenizer, + state, + idx=idx, + task_name=self.task_name, + ) + + def __len__(self): + return self.length + + +class RandomDiscoverDataset(IterableDataset): + """Random warm-starts (e.g. validation) — does not touch the PUCT sampler.""" + + def __init__( + self, + tokenizer, + num_states_per_step: int = 8, + task_name: str = "erdos_discovery", + length: int = 1000, + seed: int = 42, + ): + self.tokenizer = tokenizer + self.num_states_per_step = num_states_per_step + self.task_name = task_name + self.length = length + self._idx_counter = itertools.count() + self._rng = np.random.default_rng(seed) + + def __iter__(self): + for _ in itertools.count(): + for _ in range(self.num_states_per_step): + state = create_initial_state(self._rng) + idx = next(self._idx_counter) + yield generate_discover_datum( + self.tokenizer, + state, + idx=idx, + task_name=self.task_name, + ) + + def __len__(self): + return self.length + + +# ═══════════════════════════════════════════════════════════════════ +# Setup +# ═══════════════════════════════════════════════════════════════════ + + +def setup_discover_data(config: MasterConfig, tokenizer): + """Create dataset, environment, and wire them together.""" + base_env = config.get("env", {}).get("erdos_discovery", {}) + env_config = dict(base_env) if isinstance(base_env, dict) else {} + num_states = int(config.get("grpo", {}).get("num_prompts_per_step", 8)) 
+ task_name = "erdos_discovery" + + # Ref parity: cold-start seed count must match prompts per step (PUCT batch_size). + seed_bs = int(env_config.get("puct_seed_batch_size", num_states)) + if seed_bs != num_states: + logger.warning( + "Overriding puct_seed_batch_size %s -> %s to match grpo.num_prompts_per_step " + "(ttt-discover-ref PUCT batch_size).", + seed_bs, + num_states, + ) + env_config["puct_seed_batch_size"] = num_states + + data_cfg = config.get("data", {}) + if int(data_cfg.get("num_workers", 0)) != 0: + logger.warning( + "Setting data.num_workers=0 for Erdős PUCT (dataset calls Ray on the driver)." + ) + config["data"] = {**data_cfg, "num_workers": 0} + + env = ErdosDiscoveryEnvironment.options( + num_gpus=0, + max_restarts=-1, + max_task_retries=-1, + ).remote(config=env_config) + + train_dataset = PUCTDiscoverDataset( + tokenizer=tokenizer, + env_actor=env, + num_prompts_per_step=num_states, + task_name=task_name, + length=config.get("grpo", {}).get("max_num_steps", 50) * num_states, + ) + + val_dataset = RandomDiscoverDataset( + tokenizer=tokenizer, + num_states_per_step=num_states, + task_name=task_name, + length=num_states, + seed=config.get("seed", 42) + 1, + ) + + task_to_env = {task_name: env} + val_task_to_env = {task_name: env} + + return train_dataset, val_dataset, task_to_env, val_task_to_env + + +# ═══════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════ + + +def main(): + from omegaconf import OmegaConf + from nemo_rl.utils.config import load_config + + # Register custom resolvers + if not OmegaConf.has_resolver("mul"): + OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + if not OmegaConf.has_resolver("div"): + OmegaConf.register_new_resolver("div", lambda a, b: a // b) + + try: + from nemo_rl.utils.config import register_omegaconf_resolvers + register_omegaconf_resolvers() + except ImportError: + pass + + # Parse --config argument + 
config_path = None + for i, arg in enumerate(sys.argv[1:], 1): + if arg.startswith("--config="): + config_path = arg.split("=", 1)[1] + elif arg == "--config" and i < len(sys.argv) - 1: + config_path = sys.argv[i + 1] + elif not arg.startswith("--") and config_path is None: + config_path = arg + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "grpo_erdos_discover.yaml" + ) + + print(f"Loading config from: {config_path}") + config = load_config(config_path) + + # Resolve OmegaConf interpolations + oc = OmegaConf.create(config) + config = OmegaConf.to_container(oc, resolve=True) + + # Initialize Ray + init_ray() + set_seed(config.get("seed", 42)) + + # Tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Generation config + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Setup data + environment + train_dataset, val_dataset, task_to_env, val_task_to_env = ( + setup_discover_data(config, tokenizer) + ) + + # Ensure checkpointing config exists (some container versions require it) + if "checkpointing" not in config: + config["checkpointing"] = { + "enabled": False, + "checkpoint_dir": "results/erdos", + "save_period": 999999, + "checkpoint_must_save_by": None, + "model_save_format": "safetensors", + "save_consolidated": False, + "metric_name": "total_reward/mean", + "higher_is_better": True, + "keep_top_k": 1000000, + } + elif config["checkpointing"].get("checkpoint_must_save_by") is None: + config["checkpointing"]["checkpoint_must_save_by"] = None + + # Setup returns vary across container versions + setup_result = setup(config, tokenizer, train_dataset, val_dataset) + setup_list = list(setup_result) + n = len(setup_list) + print(f" setup() returned {n} values") + for i, v in enumerate(setup_list): + print(f" [{i}] {type(v).__name__}: {str(v)[:80]}") + + if n == 11: + # super-v3 container + (policy, policy_generation, _nemo_gym, 
_clusters, + dataloader, val_dataloader, loss_fn, + nemo_logger, checkpointer, grpo_state, master_config) = setup_list + # Ensure checkpointing exists in master_config + if "checkpointing" not in master_config: + master_config["checkpointing"] = config.get("checkpointing", { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 999999, + }) + grpo_train( + policy, policy_generation, + dataloader, val_dataloader, + tokenizer, loss_fn, + task_to_env, val_task_to_env, + nemo_logger, checkpointer, + grpo_state, master_config, + ) + elif n == 10: + # v0.5.0 container: Policy, VllmGen, clusters, train_dl, val_dl, + # loss_fn, logger, checkpointer, grpo_state, master_config + (policy, policy_generation, _clusters, + dataloader, val_dataloader, + loss_fn, nemo_logger, checkpointer, + grpo_state, master_config) = setup_list + # Ensure checkpointing exists in master_config + if "checkpointing" not in master_config: + master_config["checkpointing"] = config.get("checkpointing", { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 999999, + }) + grpo_train( + policy, policy_generation, + dataloader, val_dataloader, + tokenizer, loss_fn, + task_to_env, val_task_to_env, + nemo_logger, checkpointer, + grpo_state, master_config, + ) + else: + raise RuntimeError(f"Unexpected setup() return count: {n}") + + +if __name__ == "__main__": + main() diff --git a/launch_scripts/launch_erdos_120b.sh b/launch_scripts/launch_erdos_120b.sh new file mode 100755 index 0000000000..55f5993aae --- /dev/null +++ b/launch_scripts/launch_erdos_120b.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# TTT-Discover Erdős — Nemotron-3-Super-120B, 8k seq len, wandb logging +set -euo pipefail +cd /home/mormio/RL + +CONTAINER="/home/shared/containers/nemo-rl-super-v3.sqsh" +MODEL_PATH="/home/shared/models/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" +EXP="results/erdos-120b-$(date +%Y%m%d_%H%M)" +mkdir -p "$EXP" + +WANDB_API_KEY=$(grep 'password' ~/.netrc | head -1 | awk '{print $2}') 
+MOUNTS="$PWD:$PWD,$MODEL_PATH:$MODEL_PATH,$HOME/.cache:$HOME/.cache" + +COMMAND=" +cd /opt/nemo-rl && \ +export NCCL_BUFFSIZE=33554432 && \ +export CUDA_DEVICE_ORDER=PCI_BUS_ID && \ +export NCCL_IB_AR_THRESHOLD=0 && \ +export NCCL_IB_PCI_RELAXED_ORDERING=1 && \ +export NCCL_IB_QPS_PER_CONNECTION=2 && \ +export NCCL_IB_SPLIT_DATA_ON_QPS=0 && \ +export NCCL_IGNORE_CPU_AFFINITY=1 && \ +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 && \ +export NCCL_SOCKET_IFNAME=bond0 && \ +export UCX_NET_DEVICES=bond0 && \ +export HF_HUB_ENABLE_HF_TRANSFER=0 && \ +export TORCH_CUDA_ARCH_LIST='9.0 10.0' && \ +export NRL_IGNORE_VERSION_MISMATCH=1 && \ +export ERDOS_LOG_DIR=/home/mormio/RL/results/erdos_outputs && \ +export ERDOS_PUCT_LOG_DIR=/home/mormio/RL/results/erdos_puct && \ +export WANDB_API_KEY=$WANDB_API_KEY && \ + +SRC=/home/mormio/RL +cp \$SRC/nemo_rl/algorithms/entropic_advantage_estimator.py /opt/nemo-rl/nemo_rl/algorithms/ +cp \$SRC/nemo_rl/environments/erdos_discovery_environment.py /opt/nemo-rl/nemo_rl/environments/ +cp \$SRC/nemo_rl/environments/erdos_ref_puct_sampler.py /opt/nemo-rl/nemo_rl/environments/ +cp \$SRC/examples/run_discover.py /opt/nemo-rl/examples/ +cp \$SRC/examples/configs/grpo_erdos_discover.yaml /opt/nemo-rl/examples/configs/ + +python -c \" +path = '/opt/nemo-rl/nemo_rl/algorithms/grpo.py' +with open(path) as f: + content = f.read() +if 'entropic_adaptive_beta' not in content: + old = ' else:\\n raise ValueError(f\\\"Invalid adv_estimator name: {adv_estimator_name}\\\")\\n\\n return adv_estimator' + new = ''' elif adv_estimator_name == \\\"entropic_adaptive_beta\\\": + from nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(\\\" Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)\\\") + else: + raise ValueError(f\\\"Invalid 
adv_estimator name: {adv_estimator_name}\\\") + + return adv_estimator''' + content = content.replace(old, new) + with open(path, 'w') as f: + f.write(content) + print('Patched grpo.py') +\" && \ + +python -c \" +path = '/opt/nemo-rl/nemo_rl/environments/utils.py' +with open(path) as f: + content = f.read() +if 'erdos_discovery' not in content: + content = content.replace( + '\\\"nemo_gym\\\": {', + '\\\"erdos_discovery\\\": {\\n \\\"actor_class_fqn\\\": \\\"nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment\\\",\\n },\\n \\\"nemo_gym\\\": {' + ) + with open(path, 'w') as f: + f.write(content) + print('Patched utils.py') +\" && \ + +python examples/run_discover.py \ + --config examples/configs/grpo_erdos_discover.yaml +" + +echo "Submitting Erdős TTT-Discover 120B (8k seq, wandb)..." +echo " Container: $CONTAINER" +echo " Model: $MODEL_PATH" +echo " Nodes: 8 (2 inference + 6 training)" +echo " Seq len: 16384" +echo " Exp: $EXP" + +COMMAND="$COMMAND" \ +CONTAINER="$CONTAINER" \ +MOUNTS="$MOUNTS" \ +GPUS_PER_NODE=8 \ +sbatch \ + --nodes=8 --partition=batch --exclusive \ + --job-name=erdos-120b --time=12:00:00 \ + --output="$EXP/slurm-%j.out" \ + --error="$EXP/slurm-%j.err" \ + --exclude=d2dfac12-001,d2dfac12-002,d2dfac12-004,d2dfac12-007,d2dfac12-008,d2dfac12-019,d2dfac12-027,d2dfac12-028,d2dfac12-029 \ + ray.sub + +echo "Logs: $EXP/" +echo "W&B: https://wandb.ai/nous_research/ttt-discover-erdos" diff --git a/nemo_rl/algorithms/entropic_advantage_estimator.py b/nemo_rl/algorithms/entropic_advantage_estimator.py new file mode 100644 index 0000000000..af78ae5637 --- /dev/null +++ b/nemo_rl/algorithms/entropic_advantage_estimator.py @@ -0,0 +1,179 @@ +"""Entropic Adaptive-Beta Advantage Estimator for TTT-Discover. + +Implements the Leave-One-Out (LOO) entropic advantage from +"Learning to Discover at Test Time" (arXiv:2601.16175). + +Instead of standard group-relative advantages (Adv = R - mean(R)), +this estimator: + 1. 
Solves for β such that KL(softmax_β(R) || uniform) = γ (default γ = ln(2)) + 2. Computes LOO advantages: w_i = exp(β·r_i) / Z_{-i} - 1 + where Z_{-i} is the normalizer excluding the i-th sample. + +Properties: + - Shift-invariant, approximately scale-invariant + - Monotone in reward + - Approximately mean-zero + - Adaptive scaling via β solves the reward-scale sensitivity of standard GRPO +""" + +import math +from typing import Optional + +import torch + + +def _solve_beta( + rewards: torch.Tensor, + gamma: float = math.log(2), + max_iter: int = 50, + tol: float = 1e-6, +) -> float: + """Solve for β such that KL(softmax_β(R) || uniform) = γ via bisection. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL divergence. Default ln(2) as in the paper. + max_iter: Maximum bisection iterations. + tol: Convergence tolerance on β. + + Returns: + Scalar β value. + """ + K = rewards.shape[0] + if K <= 1: + return 0.0 + + log_K = math.log(K) + r = rewards.double() + r_max = r.max() + + def kl_at_beta(b: float) -> float: + logits = b * (r - r_max) + log_Z = torch.logsumexp(logits, dim=0) + logq = logits - log_Z + q = logq.exp() + kl = (q * (logq + log_K)).sum().item() + return kl + + # Bisect: KL is monotonically increasing in |β| for non-constant rewards + # Find upper bound for β + lo, hi = 0.0, 1.0 + while kl_at_beta(hi) < gamma and hi < 1e8: + hi *= 2.0 + + # Edge case: all rewards identical → β = 0, KL = 0 for any β + if hi >= 1e8: + return 0.0 + + for _ in range(max_iter): + mid = (lo + hi) / 2.0 + kl = kl_at_beta(mid) + if abs(kl - gamma) < tol: + return mid + if kl < gamma: + lo = mid + else: + hi = mid + + return (lo + hi) / 2.0 + + +def compute_entropic_advantages( + rewards: torch.Tensor, + gamma: float = math.log(2), + eps: float = 1e-8, +) -> torch.Tensor: + """Compute LOO entropic advantages for a group of rewards. + + Args: + rewards: [K] tensor of rewards for one group. + gamma: Target KL for adaptive β. 
+ eps: Small constant for numerical stability. + + Returns: + [K] tensor of advantages. + """ + K = rewards.shape[0] + if K <= 1: + return torch.zeros_like(rewards) + + beta = _solve_beta(rewards, gamma=gamma) + if beta == 0.0: + return torch.zeros_like(rewards) + + r = rewards.double() + r_max = r.max() + e = torch.exp(beta * (r - r_max)) + + if K == 1: + Z_loo = e + else: + # Leave-one-out normalizer: Z_{-i} = (sum(e) - e_i) / (K - 1) + Z_loo = (e.sum() - e) / (K - 1) + + w = e / (Z_loo + eps) + advantages = (w - 1.0).to(rewards.dtype) + return advantages + + +class EntropicAdaptiveBetaAdvantageEstimator: + """Advantage estimator using entropic adaptive-β LOO weighting. + + Follows the same interface as GRPOAdvantageEstimator: + compute_advantage(prompt_ids, rewards, mask, **kwargs) -> [B, S] tensor + + Config keys (under grpo.adv_estimator): + gamma: Target KL for β search. Default ln(2) ≈ 0.693. + eps: Numerical stability constant. Default 1e-8. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.gamma = estimator_config.get("gamma", math.log(2)) + self.eps = estimator_config.get("eps", 1e-8) + + def compute_advantage( + self, + prompt_ids: torch.Tensor, + rewards: torch.Tensor, + mask: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """Compute per-token advantages using entropic adaptive-β LOO. + + Args: + prompt_ids: [B] or [B, S] tensor identifying which prompt each + sample belongs to (same prompt = same group). + rewards: [B] scalar rewards per sample. + mask: [B, S] response token mask (1 = generation token). + + Returns: + [B, S] advantages tensor. Each generation token gets the + sample-level advantage; non-generation tokens get 0. 
+ """ + batch_size, seq_len = mask.shape + advantages = torch.zeros_like(mask, dtype=rewards.dtype) + + # Group by prompt (same as GRPO's per-prompt baseline) + if prompt_ids.dim() > 1: + # prompt_ids is [B, S] — use first token as group key + group_ids = prompt_ids[:, 0] + else: + group_ids = prompt_ids + + unique_prompts = group_ids.unique() + + for pid in unique_prompts: + group_mask = group_ids == pid + group_rewards = rewards[group_mask] + + group_adv = compute_entropic_advantages( + group_rewards, gamma=self.gamma, eps=self.eps + ) + + # Expand sample-level advantages to [group_size, seq_len] + # and mask to generation tokens only + group_indices = group_mask.nonzero(as_tuple=True)[0] + for i, idx in enumerate(group_indices): + advantages[idx] = group_adv[i] * mask[idx] + + return advantages diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e550429ce2..58f722c653 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -122,14 +122,17 @@ class AsyncGRPOConfig(TypedDict): class AdvEstimatorConfig(TypedDict): - """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++).""" + """Configuration for advantage estimator (GRPO, GDPO, Reinforce++, or Entropic).""" - name: str # "grpo", "gdpo", or "reinforce_plus_plus" + name: str # "grpo", "gdpo", "reinforce_plus_plus", or "entropic_adaptive_beta" # GRPO specific normalize_rewards: NotRequired[bool] use_leave_one_out_baseline: NotRequired[bool] # Reinforce++ specific minus_baseline: NotRequired[bool] + # Entropic Adaptive-Beta specific (TTT-Discover, arXiv:2601.16175) + gamma: NotRequired[float] # Target KL for beta search; default ln(2) + eps: NotRequired[float] # Numerical stability; default 1e-8 class GRPOConfig(TypedDict): @@ -1066,6 +1069,14 @@ def _create_advantage_estimator(master_config: MasterConfig): adv_estimator_config, loss_config ) print(" ✓ Using Reinforce++ advantage estimator") + elif adv_estimator_name == "entropic_adaptive_beta": + from 
nemo_rl.algorithms.entropic_advantage_estimator import ( + EntropicAdaptiveBetaAdvantageEstimator, + ) + adv_estimator = EntropicAdaptiveBetaAdvantageEstimator( + adv_estimator_config, loss_config + ) + print(" ✓ Using Entropic Adaptive-Beta advantage estimator (TTT-Discover)") else: raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}") diff --git a/nemo_rl/environments/ERDOS_DISCOVERY.md b/nemo_rl/environments/ERDOS_DISCOVERY.md new file mode 100644 index 0000000000..938d7b9d9e --- /dev/null +++ b/nemo_rl/environments/ERDOS_DISCOVERY.md @@ -0,0 +1,156 @@ +# TTT-Discover: Learning to Discover at Test Time + +Implementation of [TTT-Discover](https://arxiv.org/abs/2601.16175) using NeMo RL +for GRPO training and NeMo Gym for environment scoring. + +## Overview + +TTT-Discover is a framework for using LLMs to make mathematical discoveries +through iterative refinement. The model generates candidate solutions (as code), +which are scored and organized in a tree structure (PUCT buffer). Training uses +entropic adaptive-β advantages instead of standard GRPO group-relative baselines. + +The first application is the **Erdős Minimum Overlap Problem**: finding step +functions that minimize the upper bound on the Erdős constant `c` (known bounds: +`0.379005 < c < 0.380927`). + +## Architecture + +``` +┌─────────────────────────────────────────────────┐ +│ NeMo RL (training) │ +│ │ +│ GRPO Loop: │ +│ 1. Dataloader → prompts from PUCT buffer │ +│ 2. vLLM generates code completions │ +│ 3. Environment returns rewards │ +│ 4. Entropic adaptive-β advantage estimator │ +│ 5. Policy gradient update │ +│ 6. Weight sync to vLLM │ +│ │ +│ Components in nemo_rl/: │ +│ algorithms/entropic_advantage_estimator.py │ +│ environments/erdos_discovery_environment.py │ +│ utils/puct_buffer.py │ +└─────────────┬───────────────────────────────────┘ + │ HTTP (POST /verify, /select_state, ...) 
+ │ +┌─────────────▼───────────────────────────────────┐ +│ NeMo Gym (environment) │ +│ │ +│ Erdős Resource Server: │ +│ - Sandboxed Python code execution │ +│ - FFT-based bound computation │ +│ - Constraint validation (len, range, mean) │ +│ - reward = 1 / bound │ +│ - PUCT buffer state management │ +│ - Prompt formatting from tree context │ +│ │ +│ Location: Gym/resources_servers/erdos_discovery │ +└─────────────────────────────────────────────────┘ +``` + +## Key Components + +### Entropic Adaptive-β Advantage Estimator + +Location: `nemo_rl/algorithms/entropic_advantage_estimator.py` + +Replaces standard GRPO group-relative advantages with Leave-One-Out (LOO) +entropic weighting: + +1. **Solve for β**: Find β such that `KL(softmax_β(R) || uniform) = ln(2)` + via bisection search. +2. **LOO advantages**: `w_i = exp(β·r_i) / Z_{-i} - 1` where `Z_{-i}` is the + leave-one-out normalizer (excludes sample `i`). + +Properties: shift-invariant, approximately scale-invariant, monotone, ~zero-mean. + +Config: +```yaml +grpo: + adv_estimator: + name: entropic_adaptive_beta + gamma: 0.6931 # ln(2), target KL divergence +``` + +### PUCT Buffer + +Location: `nemo_rl/utils/puct_buffer.py` + +Tree-structured state selection using Predictor + Upper Confidence bounds for +Trees (PUCT). Balances exploitation (high-reward states) with exploration +(under-visited branches). + +Score: `Q(s) + c · P(s) · √(1+T) / (1+n(s))` + +- `Q(s)`: best reward reachable from state `s` +- `P(s)`: rank-based prior +- `n(s)`: visit count +- `T`: total visits +- `c`: exploration constant (default 1.0) + +This is a **general utility** — usable by any iterative optimization environment, +not just Erdős. + +### Erdős Discovery Environment + +Location: `nemo_rl/environments/erdos_discovery_environment.py` + +Ray remote actor implementing `EnvironmentInterface`. Calls the NeMo Gym Erdős +resource server for sandboxed code execution and reward computation. 
+ +Config: +```yaml +env: + erdos_discovery: + resource_server_url: http://localhost:8080 + num_initial_states: 16 + sandbox_timeout: 600 +``` + +### Erdős Gym Resource Server + +Location: `Gym/resources_servers/erdos_discovery/` + +Standalone FastAPI server handling: +- `/verify`: execute code, validate `f`, compute `reward = 1/bound` +- `/seed_session`: initialize PUCT buffer with random states +- `/select_state`: PUCT-select states for next training batch +- `/update_buffer`: add discoveries to the tree + +## Hyperparameters (from the paper) + +| Parameter | Value | Description | +|-----------|-------|-------------| +| model | 120B MoE | gpt-oss-120b-bf16 with LoRA r=32 | +| group_size | 64 | Rollouts per initial state | +| groups_per_batch | 8 | PUCT-selected states per step | +| epochs | 50 | Training steps | +| lr | 4e-5 | Learning rate | +| context_window | 32768 | Max tokens | +| kl_penalty | 0.1 | KL penalty coefficient | +| puct_c | 1.0 | PUCT exploration constant | +| entropic γ | ln(2) | Target KL for adaptive β | +| sandbox_timeout | 600s | Code execution limit | + +## Running + +_Run script coming soon. Will follow the `research/template_project/` pattern._ + +```bash +# 1. Start the Gym resource server +cd ~/Gym +ng_run "+config_paths=[resources_servers/erdos_discovery/configs/erdos_discovery.yaml]" + +# 2. Run GRPO training with TTT-Discover +cd ~/RL +python research/ttt_discover/run_discover.py \ + --config research/ttt_discover/configs/erdos_120b.yaml +``` + +## References + +- Yu Sun et al., "Learning to Discover at Test Time" (arXiv:2601.16175), 2026. +- Reference implementation: https://github.com/test-time-training/discover +- Haugland (2016) for prior SOTA bound 0.380927. 
diff --git a/nemo_rl/environments/LESSONS_LEARNED.md b/nemo_rl/environments/LESSONS_LEARNED.md new file mode 100644 index 0000000000..503e518dd1 --- /dev/null +++ b/nemo_rl/environments/LESSONS_LEARNED.md @@ -0,0 +1,182 @@ +# TTT-Discover on NeMo RL — Lessons Learned + +## What This Is + +TTT-Discover (arXiv:2601.16175) running the Erdős Minimum Overlap Problem +on NeMo RL's GRPO framework. The LLM writes Python code defining a step +function, which is executed in a sandbox and scored via FFT autocorrelation. + +## What Worked + +### Final Working Setup +- **Container**: `nvcr.io#nvidia/nemo-rl:v0.5.0` +- **Cluster**: 2 nodes, 8x B200 per node (16 GPUs total) +- **Model**: Qwen/Qwen2.5-1.5B-Instruct with LoRA (r=16) for debug +- **Ray orchestration**: Dakota's patched `ray.sub` (see below) +- **~27s per training step** (5 debug steps completed successfully) + +### Launch Pattern +```bash +cd ~/RL + +CONTAINER="nvcr.io#nvidia/nemo-rl:v0.5.0" +MOUNTS="$PWD:$PWD,/home/shared/models:/home/shared/models" +COMMAND=" +export HF_HUB_ENABLE_HF_TRANSFER=0 +export TORCH_CUDA_ARCH_LIST='9.0 10.0' +export PYTHONPATH=/path/to/RL:\${PYTHONPATH:-} +cd /opt/nemo-rl +python examples/run_discover.py --config examples/configs/grpo_erdos_discover_debug.yaml +" + +COMMAND="$COMMAND" CONTAINER="$CONTAINER" MOUNTS="$MOUNTS" GPUS_PER_NODE=8 \ +sbatch --nodes=2 --partition=batch --exclusive --time=01:00:00 ray.sub +``` + +### Key Config (grpo_erdos_discover_debug.yaml) +```yaml +defaults: "grpo_math_1B.yaml" # Inherit all base settings + +grpo: + num_prompts_per_step: 4 # 4 PUCT-selected states + num_generations_per_prompt: 8 # 8 rollouts per state + max_num_steps: 5 # Debug: 5 steps only + max_rollout_turns: 1 # Single-turn code generation + adv_estimator: + name: entropic_adaptive_beta # NOT standard grpo + gamma: 0.6931471805599453 # ln(2) + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + max_total_sequence_length: 4096 + dtensor_cfg: + cpu_offload: false # IMPORTANT: must be false 
or logprobs crash + lora_cfg: + enabled: true + rank: 16 + +env: + erdos_discovery: + resource_server_url: "inline" # No Gym server needed for debug +``` + +## Cluster-Specific Fixes (d2dfac12 / Together AI B200) + +### ray.sub Patches +Use Dakota's patched `ray.sub` from `~/dakota-ref/ray.sub`. Key changes from upstream: +1. **MPI**: `--mpi=pmi2` (not `pmix`) +2. **Container writable**: `--container-writable` instead of `--no-container-mount-home` +3. **NCCL IB config** in ray.sub itself (not just COMMAND): + ```bash + export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 + export NCCL_SOCKET_IFNAME=bond0 + export UCX_NET_DEVICES=bond0 + ``` +4. **No `--account`** needed on this cluster +5. `srun --overlap` works via `enroot exec` (Dakota's ray.sub handles this) + +### NGC Container Auth +```bash +mkdir -p ~/.config/enroot +cat > ~/.config/enroot/.credentials << 'EOF' +machine nvcr.io login $oauthtoken password nvapi-YOUR_KEY_HERE +EOF +``` + +### Container Pull Time +First pull takes ~10-12 minutes per node. Cached after that, but cache is +per-node (different nodes need their own pull). Plan for this in your first run. + +## What Broke and How We Fixed It + +### 1. TransformerEngine Won't Build (bare metal) +**Problem**: `uv sync --extra automodel` fails because TransformerEngine needs +`cudnn.h` which doesn't exist on the head node. +**Fix**: Use the NeMo RL container (has TE pre-built). Don't try bare metal +with LoRA — LoRA requires DTensorPolicyWorkerV2 which needs the `automodel` +extra which needs TransformerEngine. + +### 2. Ray Version Mismatch +**Problem**: Container has Ray 2.49.2, but `uv run` picks up Ray 2.54.0 from +the mounted `.venv`. +**Fix**: Use `python` directly (container's Python), not `uv run`. Set your +code path via `PYTHONPATH` instead. + +### 3. 
Code Compatibility with v0.5.0 Container +**Problem**: Our branch's `nemo_rl` code is newer than v0.5.0 (has `decord` +imports, `register_omegaconf_resolvers`, etc. that the container doesn't have). +**Fix**: Don't overwrite the container's code at `/opt/nemo-rl`. Instead: + - Copy only NEW files (custom estimator, environment, run script) + - Monkey-patch `grpo.py` and `utils.py` at runtime to register new components + - Add `mul`/`div` OmegaConf resolvers manually if `register_omegaconf_resolvers` isn't available + +### 4. Ray Actor Event Loop +**Problem**: `asyncio.get_event_loop().run_until_complete()` in environment +`step()` crashes with "This event loop is already running" inside Ray actors. +**Fix**: Detect running loop and use synchronous path instead: +```python +try: + loop = asyncio.get_running_loop() +except RuntimeError: + loop = None +if loop and loop.is_running(): + return self._sync_step(...) # No async +else: + return asyncio.run(self._async_step(...)) +``` + +### 5. Empty Observations +**Problem**: `KeyError: 'content'` in rollouts.py — the rollout engine expects +observations from `env.step()` to have a `content` key. +**Fix**: Return `[{"role": "user", "content": ""} for _ in range(batch_size)]` +not `[{}] * batch_size`. + +### 6. CPU Offload + LogProbs +**Problem**: `RuntimeError: Expected all tensors to be on the same device` — +model weights on CPU (from `cpu_offload: true`) but input_ids on cuda. +**Fix**: Set `dtensor_cfg.cpu_offload: false` for small debug models. +For large models, double-check the offload/reload flow works with your config. + +### 7. `${mul:...}` OmegaConf Resolver +**Problem**: Base config uses `${mul:a,b}` interpolation but v0.5.0 container +doesn't register this resolver. 
+**Fix**: Register it manually in your run script: +```python +from omegaconf import OmegaConf +if not OmegaConf.has_resolver("mul"): + OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +``` + +## Architecture + +``` +NeMo RL Container (v0.5.0) +├── /opt/nemo-rl/ ← Container's code (base) +│ ├── nemo_rl/algorithms/ +│ │ ├── grpo.py ← monkey-patched at runtime +│ │ └── entropic_advantage_estimator.py ← our NEW file +│ ├── nemo_rl/environments/ +│ │ ├── utils.py ← monkey-patched at runtime +│ │ └── erdos_discovery_environment.py ← our NEW file +│ └── nemo_rl/utils/ +│ └── puct_buffer.py ← our NEW file +│ +├── /home/mormio/RL/ ← Mounted source (for cp at startup) +└── /home/shared/models/ ← Mounted model weights +``` + +## Scaling to 120B + +For gpt-oss-120b-bf16 (MoE, 128 experts): +- Use 8 nodes (2 training + 6 inference) or similar +- LoRA r=32, alpha=1.0 +- EP=8 for expert parallelism (see Dakota's 120B config) +- May need `cpu_offload: true` + careful memory management +- Set `max_total_sequence_length: 32768` for full context +- Consider `generation.colocated: false` with separate inference nodes + +## References +- Paper: "Learning to Discover at Test Time" (arXiv:2601.16175) +- Reference impl: https://github.com/test-time-training/discover +- Dakota's working configs: `~/dakota-ref/` +- NeMo RL docs: https://docs.nvidia.com/nemo/rl/latest/index.html diff --git a/nemo_rl/environments/erdos_discovery_environment.py b/nemo_rl/environments/erdos_discovery_environment.py new file mode 100644 index 0000000000..97c39701cc --- /dev/null +++ b/nemo_rl/environments/erdos_discovery_environment.py @@ -0,0 +1,787 @@ +"""Erdős Minimum Overlap Discovery Environment — matches reference TTT-Discover implementation. 
+
+Reference: https://github.com/test-time-training/discover/blob/main/examples/erdos_min_overlap/env.py
+Paper: "Learning to Discover at Test Time" (arXiv:2601.16175)
+
+Key differences from our v1:
+- Uses C5 = max(np.correlate(h, 1-h, mode="full") * dx) formulation (h over [0,2])
+- Code must define run(seed=42, budget_s=1000) returning (h_values, c5_bound, n_points)
+- Allows scipy, cvxpy in addition to numpy/math
+- State context shows parent code + improvement direction
+- reward = 1 / (1e-8 + c5_bound)
+"""
+
+import asyncio
+import logging
+import math
+import os
+import re
+import signal
+import time
+from typing import Any, Optional
+
+import numpy as np
+import ray
+import torch
+
+from nemo_rl.environments.interfaces import (
+    EnvironmentInterface,
+    EnvironmentReturn,
+)
+from nemo_rl.data.interfaces import LLMMessageLogType
+from nemo_rl.environments.erdos_ref_puct_sampler import (
+    ErdosRefPUCTSampler,
+    ErdosRefState,
+    erdos_ref_state_to_prompt_state,
+)
+
+logger = logging.getLogger(__name__)
+
+# ═══════════════════════════════════════════════════════════════════
+# Reference C5 verification (from TTT-Discover env.py)
+# ═══════════════════════════════════════════════════════════════════
+
+def verify_c5_solution(h_values, c5_achieved, n_points):
+    """Verify a C5 solution — exact copy from reference implementation.
+
+    Args:
+        h_values: candidate step-function samples over [0, 2]; coerced to a
+            1D float64 numpy array of length ``n_points``.
+        c5_achieved: the C5 bound the candidate claims to achieve.
+        n_points: expected number of samples in ``h_values``.
+
+    Returns:
+        The independently recomputed C5 value (float).
+
+    Raises:
+        ValueError: if ``h_values`` cannot be converted, has the wrong
+            shape/range, is non-finite, or if the recomputed C5 disagrees
+            with ``c5_achieved`` by more than 1e-4.
+    """
+    if not isinstance(h_values, np.ndarray):
+        try:
+            h_values = np.array(h_values, dtype=np.float64)
+        except (ValueError, TypeError) as e:
+            raise ValueError(f"Cannot convert h_values to numpy array: {e}")
+
+    if len(h_values.shape) != 1:
+        raise ValueError(f"h_values must be 1D array, got shape {h_values.shape}")
+
+    if h_values.shape[0] != n_points:
+        raise ValueError(f"Expected h shape ({n_points},), got {h_values.shape}")
+
+    if not np.all(np.isfinite(h_values)):
+        raise ValueError("h_values contain NaN or inf values")
+
+    if np.any(h_values < 0) or np.any(h_values > 1):
+        raise ValueError(f"h(x) is not in [0, 1]. Range: [{h_values.min()}, {h_values.max()}]")
+
+    n = n_points
+    # Mass constraint: sum(h) * dx == 1 with dx = 2/n, i.e. sum(h) == n/2.
+    target_sum = n / 2.0
+    current_sum = np.sum(h_values)
+
+    # Rescale to satisfy the mass constraint exactly; reject candidates that
+    # leave [0, 1] after rescaling.
+    if current_sum != target_sum:
+        h_values = h_values * (target_sum / current_sum)
+        if np.any(h_values < 0) or np.any(h_values > 1):
+            raise ValueError(
+                f"After normalization, h(x) is not in [0, 1]. "
+                f"Range: [{h_values.min()}, {h_values.max()}]"
+            )
+
+    # C5 = max over shifts k of the discretized overlap ∫ h(x)(1 - h(x+k)) dx,
+    # computed as a full cross-correlation times the grid spacing dx.
+    dx = 2.0 / n_points
+    j_values = 1.0 - h_values
+    correlation = np.correlate(h_values, j_values, mode="full") * dx
+    computed_c5 = np.max(correlation)
+
+    if not np.isfinite(computed_c5):
+        raise ValueError(f"Computed C5 is not finite: {computed_c5}")
+
+    # The reported bound must match our recomputation to within 1e-4.
+    if not np.isclose(computed_c5, c5_achieved, atol=1e-4):
+        raise ValueError(f"C5 mismatch: reported {c5_achieved:.6f}, computed {computed_c5:.6f}")
+
+    return computed_c5
+
+
+# ═══════════════════════════════════════════════════════════════════
+# Sandbox execution
+# ═══════════════════════════════════════════════════════════════════
+
+# Modules the generated code is allowed to import (checked by _safe_import).
+_ALLOWED_MODULES = frozenset({
+    "numpy", "np", "math", "cmath", "random", "scipy", "cvxpy",
+    "itertools", "functools", "collections", "fractions", "decimal",
+    "copy", "operator", "time",
+})
+
+
+def _execute_run_function(code: str, timeout: int = 1000, n_cpus: int = 2) -> dict:
+    """Execute code that defines run(), call it, verify the result.
+
+    Matches the reference SandboxRewardEvaluator flow. 
+ """ + import builtins + + _SAFE_BUILTIN_NAMES = [ + "abs", "all", "any", "bool", "dict", "divmod", "enumerate", + "filter", "float", "format", "int", "isinstance", "issubclass", + "iter", "len", "list", "map", "max", "min", "next", "object", + "print", "range", "repr", "reversed", "round", "set", "slice", + "sorted", "str", "sum", "tuple", "type", "zip", "True", "False", + "None", "complex", "frozenset", "bytes", "bytearray", "memoryview", + "property", "staticmethod", "classmethod", "super", "hash", "id", + "input", "ord", "chr", "hex", "oct", "bin", "pow", + "Exception", "ValueError", "TypeError", "KeyError", "IndexError", + "StopIteration", "RuntimeError", "NotImplementedError", + "OverflowError", "ZeroDivisionError", "AttributeError", + "ImportError", "FileNotFoundError", "OSError", "ArithmeticError", + ] + + safe_builtins = {k: getattr(builtins, k) for k in _SAFE_BUILTIN_NAMES + if hasattr(builtins, k)} + + def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): + base = name.split(".")[0] + if base not in _ALLOWED_MODULES: + raise ImportError(f"Module '{name}' not allowed") + return builtins.__import__(name, globals, locals, fromlist, level) + + safe_builtins["__import__"] = _safe_import + + import random as _random + namespace = { + "__builtins__": safe_builtins, + "np": np, + "numpy": np, + "math": math, + "random": _random, + } + + # Add evaluate_erdos_solution to namespace (reference injects this) + def _evaluate_erdos_solution(h_values, c5_bound, n_points): + verify_c5_solution(h_values, c5_bound, n_points) + return float(c5_bound) + + namespace["evaluate_erdos_solution"] = _evaluate_erdos_solution + + stdout_capture = [] + original_print = builtins.print + def capturing_print(*args, **kwargs): + import io + buf = io.StringIO() + kwargs["file"] = buf + original_print(*args, **kwargs) + stdout_capture.append(buf.getvalue()) + namespace["__builtins__"]["print"] = capturing_print + + try: + # Use multiprocessing for timeout (signal.alarm 
doesn't work in Ray actor threads) + import multiprocessing as mp + import pickle + + def _run_in_subprocess(code, ns_pickle, result_queue): + """Execute code in a subprocess with proper timeout support.""" + import signal as _signal + namespace = pickle.loads(ns_pickle) + + class _Timeout(Exception): + pass + def _handler(s, f): + raise _Timeout("timeout") + _signal.signal(_signal.SIGALRM, _handler) + _signal.alarm(timeout) + try: + exec(compile(code, "", "exec"), namespace) + if "run" not in namespace: + result_queue.put({"error": "No 'run' function defined"}) + return + out = namespace["run"](seed=42, budget_s=timeout) + result_queue.put({"result": out}) + except _Timeout: + result_queue.put({"error": f"Execution timed out after {timeout}s"}) + except Exception as e: + result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) + + # Run in subprocess with signal.alarm for clean timeout. + # The subprocess handles its own timeout and exits cleanly. + # We only use p.terminate() (SIGTERM, not SIGKILL) as a last resort. + import multiprocessing as _mp + + _EXEC_TIMEOUT = min(timeout, 1000) + + def _worker_fn(code_str, result_queue, exec_timeout): + """Run code in a subprocess. 
signal.alarm works here (main thread).""" + import signal as _sig + import sys as _sys + import os as _os + + class _AlarmTimeout(BaseException): + pass + + def _alarm_handler(signum, frame): + raise _AlarmTimeout() + + _sig.signal(_sig.SIGALRM, _alarm_handler) + _sig.alarm(exec_timeout) + + try: + import numpy, math, random + ns = { + "__builtins__": __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__).copy(), + "np": numpy, "numpy": numpy, "math": math, "random": random, + "evaluate_erdos_solution": lambda h, c, n: float(c), + } + exec(compile(code_str, "", "exec"), ns) + if "run" not in ns: + result_queue.put({"error": "No 'run' function defined"}) + return + out = ns["run"](seed=42, budget_s=exec_timeout - 10) + # Serialize result (numpy arrays can't cross process boundary directly) + h = out[0].tolist() if hasattr(out[0], "tolist") else list(out[0]) + result_queue.put({"result": (h, float(out[1]), int(out[2]))}) + except _AlarmTimeout: + result_queue.put({"error": f"Execution timed out after {exec_timeout}s"}) + except Exception as e: + result_queue.put({"error": f"{type(e).__name__}: {str(e)[:300]}"}) + finally: + _sig.alarm(0) # Cancel any pending alarm + + q = _mp.Queue() + p = _mp.Process(target=_worker_fn, args=(code, q, _EXEC_TIMEOUT)) + p.start() + # Wait for subprocess: alarm should fire inside it, so give extra grace + p.join(timeout=_EXEC_TIMEOUT + 30) + + if p.is_alive(): + # Subprocess didn't exit cleanly — send SIGTERM first (graceful) + p.terminate() + p.join(timeout=10) + if p.is_alive(): + p.kill() # Last resort + p.join(timeout=5) + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"Subprocess terminated after {_EXEC_TIMEOUT}s", + "stdout": "".join(stdout_capture), + } + + if q.empty(): + return { + "reward": 0.0, "raw_score": None, + "error_msg": "Subprocess exited without result", + "stdout": "".join(stdout_capture), + } + + try: + exec_result = q.get_nowait() + except Exception: + return { + "reward": 0.0, 
"raw_score": None, + "error_msg": "Failed to read subprocess result", + "stdout": "".join(stdout_capture), + } + + if "error" in exec_result: + return { + "reward": 0.0, "raw_score": None, + "error_msg": exec_result["error"], + "stdout": "".join(stdout_capture), + } + + raw = exec_result["result"] + h_values = np.asarray(raw[0], dtype=np.float64) + c5_bound = raw[1] + n_points = raw[2] + result = (h_values, c5_bound, n_points) + + + + # Verify + computed_c5 = verify_c5_solution(h_values, c5_bound, n_points) + + if computed_c5 <= 0 or not np.isfinite(computed_c5): + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"Invalid C5: {computed_c5}", + "stdout": "".join(stdout_capture), + } + + return { + "reward": float(1.0 / (1e-8 + computed_c5)), + "raw_score": float(computed_c5), + "error_msg": "", + "result_construction": h_values.tolist(), + "n_points": int(n_points), + "stdout": "".join(stdout_capture), + } + + except Exception as e: + return { + "reward": 0.0, "raw_score": None, + "error_msg": f"{type(e).__name__}: {str(e)[:300]}", + "stdout": "".join(stdout_capture), + } + + +def _extract_code(response: str) -> str: + """Extract Python code from LLM response.""" + code_re = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) + blocks = code_re.findall(response) + if blocks: + return blocks[-1].strip() + # If no code block, try the whole response + return response.strip() + + +# ═══════════════════════════════════════════════════════════════════ +# Initial state generation (from reference) +# ═══════════════════════════════════════════════════════════════════ + +def create_initial_state(rng=None): + """Create a random initial state — matches reference exactly.""" + if rng is None: + rng = np.random.default_rng() + n_points = int(rng.integers(40, 100)) + construction = np.ones(n_points) * 0.5 + perturbation = rng.uniform(-0.4, 0.4, n_points) + perturbation = perturbation - np.mean(perturbation) + construction = construction + perturbation + dx = 2.0 / 
n_points
+    correlation = np.correlate(construction, 1 - construction, mode="full") * dx
+    c5_bound = float(np.max(correlation))
+    return {
+        "construction": construction.tolist(),
+        "c5_bound": c5_bound,
+        "n_points": n_points,
+        "code": "",
+        "parent_c5": None,
+        "observation": "",
+    }
+
+
+def erdos_ref_state_from_create_initial(rng: np.random.Generator) -> ErdosRefState:
+    """Match ttt-discover-ref ErdosMinOverlapEnv.create_initial_state → State."""
+    # value is the NEGATED C5 bound: the sampler maximizes value, we minimize C5.
+    d = create_initial_state(rng)
+    return ErdosRefState(
+        timestep=-1,
+        construction=list(d["construction"]),
+        code="",
+        value=-float(d["c5_bound"]),
+        parent_values=[],
+        parents=[],
+        observation="",
+    )
+
+
+# ═══════════════════════════════════════════════════════════════════
+# Prompt construction (from reference State.to_prompt + ErdosMinOverlapEnv.get_question)
+# ═══════════════════════════════════════════════════════════════════
+
+# Target C5 bound quoted to the model in the prompt (current record ≤ 0.38092).
+TARGET_C5 = 0.3808
+
+
+def state_to_prompt(state: dict) -> str:
+    """Build the state context portion — matches reference State.to_prompt().
+
+    Includes the parent code (if any), the before/after C5 numbers, and up to
+    the last 500 chars of the previous program's stdout.
+    """
+    c5 = state.get("c5_bound", None)
+    parent_c5 = state.get("parent_c5", None)
+    code = state.get("code", "")
+    observation = state.get("observation", "")
+
+    value_ctx = "You are iteratively optimizing C₅ bound."
+
+    if code and code.strip():
+        value_ctx += f"\nHere is the last code we ran:\n```python\n{code}\n```"
+    else:
+        value_ctx += "\nNo previous code available."
+
+    if parent_c5 is not None and c5 is not None:
+        current_gap = c5 - TARGET_C5
+        value_ctx += (
+            f"\nHere is the C₅ bound before and after running the code above "
+            f"(lower is better): {parent_c5:.6f} -> {c5:.6f}"
+        )
+        value_ctx += (
+            f"\nTarget: {TARGET_C5}. Current gap: {current_gap:.6f}. "
+            f"Further improvements will also be generously rewarded."
+        )
+    elif c5 is not None:
+        current_gap = c5 - TARGET_C5
+        value_ctx += f"\nCurrent C₅ bound (lower is better): {c5:.6f}"
+        value_ctx += (
+            f"\nTarget: {TARGET_C5}. Current gap: {current_gap:.6f}. "
+            f"Further improvements will also be generously rewarded."
+        )
+    else:
+        value_ctx += f"\nTarget C₅ bound: {TARGET_C5}"
+
+    if observation and observation.strip():
+        stdout = observation.strip()
+        # Keep only the tail of long program output.
+        if len(stdout) > 500:
+            stdout = "\n\n\t\t ...(TRUNCATED)...\n" + stdout[-500:]
+        value_ctx += f"\n\n--- Previous Program Output ---\n{stdout}\n--- End Output ---"
+
+    return value_ctx
+
+
+def build_erdos_question(state: dict) -> str:
+    """Build the full Erdős question — matches reference ErdosMinOverlapEnv.get_question()."""
+    state_ctx = state_to_prompt(state)
+
+    construction = state.get("construction", [])
+    n = len(construction) if construction else 0
+
+    construction_section = ""
+    if construction and n > 0:
+        construction_section = (
+            f"\nYou may want to start your search from the current construction, "
+            f"which you can access through the `initial_h_values` global variable "
+            f"(n={n} samples).\n"
+            f"You are encouraged to explore solutions that use other starting points "
+            f"to prevent getting stuck in a local optimum.\n"
+        )
+
+    code = state.get("code", "")
+    if code and code.strip():
+        code_section = (
+            "Reason about how you could further improve this construction.\n"
+            "Ideally, try to do something different than the above algorithm. "
+            "Could be using different algorithmic ideas, adjusting your heuristics, "
+            "adjusting / sweeping your hyperparemeters, etc.\n"
+            "Unless you make a meaningful improvement, you will not be rewarded."
+        )
+    else:
+        code_section = "Write code to optimize this construction."
+
+    return f"""You are an expert in harmonic analysis, numerical optimization, and mathematical discovery.
+Your task is to find an improved upper bound for the Erdős minimum overlap problem constant C₅.
+
+## Problem
+
+Find a step function h: [0, 2] → [0, 1] that **minimizes** the overlap integral:
+
+$$C_5 = \\max_k \\int h(x)(1 - h(x+k)) dx$$
+
+**Constraints**:
+1. h(x) ∈ [0, 1] for all x
+2. ∫₀² h(x) dx = 1
+
+**Discretization**: Represent h as n_points samples over [0, 2].
+With dx = 2.0 / n_points:
+- 0 ≤ h[i] ≤ 1 for all i
+- sum(h) * dx = 1 (equivalently: sum(h) == n_points / 2 exactly)
+
+The evaluation computes: C₅ = max(np.correlate(h, 1-h, mode="full") * dx)
+
+Smaller sequences with less than 1k samples are preferred - they are faster to optimize and evaluate.
+
+**Lower C₅ values are better** - they provide tighter upper bounds on the Erdős constant.
+
+## Budget & Resources
+- **Time budget**: 1000s for your code to run
+- **CPUs**: 2 available
+
+## Rules
+- Define `run(seed=42, budget_s=1000, **kwargs)` that returns `(h_values, c5_bound, n_points)`
+- Use scipy, numpy, cvxpy[CBC,CVXOPT,GLOP,GLPK,GUROBI,MOSEK,PDLP,SCIP,XPRESS,ECOS], math
+- Make all helper functions top level, no closures or lambdas
+- No filesystem or network IO
+- `evaluate_erdos_solution()` and `initial_h_values` (an initial construction, if available) are pre-imported
+- Your function must complete within budget_s seconds and return the best solution found
+
+**Lower is better**. Current record: C₅ ≤ 0.38092. Our goal is to find a construction that shows C₅ ≤ 0.38080.
+
+{state_ctx}
+{construction_section}
+{code_section}
+"""
+
+
+# ═══════════════════════════════════════════════════════════════════
+# NeMo RL Environment
+# ═══════════════════════════════════════════════════════════════════
+
+ErdosMetadata = dict[str, Any]
+
+
+@ray.remote
+class ErdosDiscoveryEnvironment(EnvironmentInterface):
+    """Erdős Minimum Overlap environment for NeMo RL GRPO. 
+
+    Matches the reference TTT-Discover implementation:
+    - C5 formulation with np.correlate
+    - run() function entrypoint
+    - scipy/cvxpy allowed
+    - State context with parent code + improvement tracking
+    """
+
+    def __init__(self, config: dict = None):
+        """Read env config (timeouts, PUCT knobs) and build the PUCT sampler.
+
+        The sampler persists to ``puct_log_dir`` (or $ERDOS_PUCT_LOG_DIR,
+        default /tmp/erdos_puct) so runs can be resumed via puct_resume_step.
+        """
+        config = config or {}
+        self.sandbox_timeout = config.get("sandbox_timeout", 1000)
+        self.num_initial_states = int(config.get("num_initial_states", 8))
+        self.puct_c = float(config.get("puct_c", 1.0))
+        self.puct_seed_batch_size = int(
+            config.get("puct_seed_batch_size", self.num_initial_states)
+        )
+
+        # Tracking (lifetime-of-actor aggregates, reported in metrics)
+        self.best_reward = 0.0
+        self.best_c5 = float("inf")
+        self.total_verified = 0
+        self.total_valid = 0
+
+        log_dir = config.get("puct_log_dir") or os.environ.get(
+            "ERDOS_PUCT_LOG_DIR", "/tmp/erdos_puct"
+        )
+        os.makedirs(log_dir, exist_ok=True)
+        sampler_path = os.path.join(log_dir, "puct_sampler.json")
+        resume_step = config.get("puct_resume_step")
+        if resume_step is not None:
+            resume_step = int(resume_step)
+
+        self.sampler = ErdosRefPUCTSampler(
+            file_path=sampler_path,
+            init_state_fn=lambda: erdos_ref_state_from_create_initial(
+                np.random.default_rng()
+            ),
+            max_buffer_size=int(config.get("puct_max_buffer_size", 1000)),
+            batch_size=self.puct_seed_batch_size,
+            resume_step=resume_step,
+            puct_c=self.puct_c,
+            topk_children=int(config.get("puct_topk_children", 2)),
+            max_construction_len=int(config.get("puct_max_construction_len", 1000)),
+        )
+
+    def get_initial_states(self, n: int = None) -> list[dict]:
+        """Random states (e.g. validation). Training uses puct_sample_states()."""
+        # Fixed seed so validation states are reproducible across calls.
+        rng = np.random.default_rng(42)
+        k = n if n is not None else self.num_initial_states
+        return [create_initial_state(rng) for _ in range(k)]
+
+    def puct_sample_states(self, num_prompts: int) -> list[dict]:
+        """ttt-discover-ref PUCTSampler.sample_states — prompts + serial parent State."""
+        picked = self.sampler.sample_states(num_prompts)
+        out: list[dict] = []
+        for s in picked:
+            prompt_state = erdos_ref_state_to_prompt_state(s)
+            # Serialized parent travels through metadata so step() can credit it.
+            prompt_state["erdos_ref_state"] = s.to_dict()
+            out.append(prompt_state)
+        return out
+
+    def step(
+        self,
+        message_log_batch: list[LLMMessageLogType],
+        metadata: list[ErdosMetadata],
+    ) -> EnvironmentReturn[ErdosMetadata]:
+        """Evaluate a batch of LLM responses.
+
+        NOTE(review): `loop` is probed but unused — both branches fall
+        through to the synchronous path below.
+        """
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        # Always use sync path (Ray actors run in event loops)
+        return self._sync_step(message_log_batch, metadata)
+
+    def _sync_step(
+        self,
+        message_log_batch: list[LLMMessageLogType],
+        metadata: list[ErdosMetadata],
+    ) -> EnvironmentReturn[ErdosMetadata]:
+        """Synchronous step — executes code and computes C5 reward."""
+        batch_size = len(message_log_batch)
+        rewards = torch.zeros(batch_size)
+        # Single-turn environment: every rollout terminates after this step.
+        terminateds = torch.ones(batch_size)
+        observations = [{"role": "user", "content": ""} for _ in range(batch_size)]
+        answers = [None] * batch_size
+        updated_metadata = list(metadata)
+
+        import time as _time
+        _t0 = _time.time()
+        step_num = getattr(self, "_step_count", 0) + 1
+        self._step_count = step_num
+        print(f"[{_time.strftime('%H:%M:%S')}] 🧪 Starting reward computation for {batch_size} rollouts")
+
+        for i, message_log in enumerate(message_log_batch):
+            if i > 0 and i % 50 == 0:
+                elapsed = _time.time() - _t0
+                rate = i / elapsed if elapsed > 0 else 0
+                eta = (batch_size - i) / rate if rate > 0 else 0
+                print(
+                    f"[{_time.strftime('%H:%M:%S')}] Reward progress: {i}/{batch_size} "
+                    f"({elapsed:.0f}s elapsed, {rate:.1f} it/s, ~{eta:.0f}s remaining)"
+                
)
+
+            # Extract assistant response
+            response_text = ""
+            for msg in reversed(message_log):
+                if msg.get("role") == "assistant":
+                    response_text = msg.get("content", "")
+                    break
+
+            # Extract code and execute
+            code = _extract_code(response_text)
+
+            # Inject initial_h_values from state if available
+            state = metadata[i] if i < len(metadata) else {}
+            construction = state.get("construction", None)
+            parent_erdos: Optional[ErdosRefState] = None
+            raw_parent = state.get("erdos_ref_state")
+            if raw_parent is not None:
+                try:
+                    parent_erdos = ErdosRefState.from_dict(raw_parent)
+                except Exception as e:
+                    logger.warning("Invalid erdos_ref_state in metadata: %s", e)
+
+            # Prepend a preamble exposing the parent construction to user code.
+            preamble = "import numpy as np\n\n"
+            if construction:
+                preamble += f"initial_h_values = np.array({construction!r})\n\n"
+            else:
+                preamble += "initial_h_values = None\n\n"
+
+            full_code = preamble + code
+            result = _execute_run_function(full_code, timeout=self.sandbox_timeout)
+
+            reward = result.get("reward", 0.0)
+            rewards[i] = reward
+            self.total_verified += 1
+
+            c5 = result.get("raw_score", None)
+
+            if reward > 0:
+                self.total_valid += 1
+                if c5 is not None and c5 < self.best_c5:
+                    self.best_c5 = c5
+                    self.best_reward = reward
+                    print(f"🏆 NEW BEST C5: {c5:.6f} (reward={reward:.4f})")
+
+            answers[i] = f"C5={c5:.6f}" if c5 else f"reward={reward:.4f}"
+
+            # Feed the PUCT tree: successful rollouts become children of the
+            # sampled parent; failures are recorded against the parent.
+            if parent_erdos is not None:
+                if (
+                    reward > 0
+                    and c5 is not None
+                    and result.get("result_construction") is not None
+                ):
+                    child = ErdosRefState(
+                        timestep=step_num,
+                        construction=list(result["result_construction"]),
+                        code=code,
+                        value=-float(c5),
+                        observation=str(result.get("stdout", "") or ""),
+                    )
+                    try:
+                        self.sampler.update_states(
+                            [child], [parent_erdos], save=False
+                        )
+                    except Exception as e:
+                        logger.warning("PUCT update_states failed: %s", e)
+                else:
+                    self.sampler.record_failed_rollout(parent_erdos)
+
+            if i < len(updated_metadata):
+                updated_metadata[i] = {
+                    **updated_metadata[i],
+                    "reward": reward,
+                    "c5_bound": c5,
+                    
"error_msg": result.get("error_msg", ""), + "stdout": result.get("stdout", ""), + # Update state for PUCT if valid + "result_construction": result.get("result_construction"), + "n_points": result.get("n_points"), + } + + elapsed = _time.time() - _t0 + valid = sum(1 for r in rewards if r > 0) + max_r = float(rewards.max()) if len(rewards) > 0 else 0.0 + batch_best_c5 = 1.0 / max_r if max_r > 0 else float("inf") + global_best = self.best_c5 if self.best_c5 < float("inf") else float("inf") + print( + f"\n{'='*60}\n" + f"🎯 STEP REWARDS: {batch_size} rollouts in {elapsed:.1f}s\n" + f" valid: {valid}/{batch_size} ({100*valid/max(1,batch_size):.1f}%)\n" + f" avg_reward: {sum(float(r) for r in rewards)/max(1,batch_size):.6f}\n" + f" max_reward: {max_r:.6f}\n" + f" batch_best_C5: {batch_best_c5:.6f}\n" + f" GLOBAL_BEST_C5: {global_best:.6f}\n" + f"{'='*60}" + ) + + # Save outputs to JSONL for debugging + import json + + log_dir = os.environ.get("ERDOS_LOG_DIR", "/tmp/erdos_outputs") + os.makedirs(log_dir, exist_ok=True) + out_path = os.path.join(log_dir, f"step_{step_num:03d}.jsonl") + try: + with open(out_path, "w") as fout: + # Save a sample: first 10 + all valid ones + for idx in range(batch_size): + r = float(rewards[idx]) + save_this = (idx < 10) or (r > 0) + if not save_this: + continue + response = "" + for msg in reversed(message_log_batch[idx]): + if msg.get("role") == "assistant": + response = msg.get("content", "") + break + meta = updated_metadata[idx] if idx < len(updated_metadata) else {} + entry = { + "idx": idx, + "reward": r, + "c5_bound": meta.get("c5_bound"), + "error_msg": meta.get("error_msg", ""), + "response_len": len(response), + "response_preview": response[:500], + "code_preview": _extract_code(response)[:500] if response else "", + } + fout.write(json.dumps(entry) + "\n") + print(f" 📝 Saved outputs to {out_path}") + except Exception as e: + print(f" ⚠️ Failed to save outputs: {e}") + + try: + self.sampler.flush(step_num) + except Exception as e: + 
logger.warning("PUCT flush failed: %s", e) + + return EnvironmentReturn( + observations=observations, + metadata=updated_metadata, + next_stop_strings=[None] * batch_size, + rewards=rewards, + terminateds=terminateds, + answers=answers, + ) + + def global_post_process_and_metrics( + self, + metadata: list[ErdosMetadata], + ) -> tuple[list[ErdosMetadata], dict[str, float]]: + """Compute aggregate metrics after a step.""" + metrics = { + "best_c5": self.best_c5 if self.best_c5 < float("inf") else 0.0, + "best_reward": self.best_reward, + "total_verified": float(self.total_verified), + "total_valid": float(self.total_valid), + "valid_rate": ( + self.total_valid / max(1, self.total_verified) + ), + } + + # Count valid solutions in this batch + batch_rewards = [m.get("reward", 0.0) for m in metadata] + batch_valid = sum(1 for r in batch_rewards if r > 0) + batch_c5s = [m.get("c5_bound") for m in metadata if m.get("c5_bound") is not None] + + metrics["erdos/max_reward"] = float(max(batch_rewards)) if batch_rewards else 0.0 + metrics["erdos/avg_reward"] = float(sum(batch_rewards) / max(1, len(batch_rewards))) + metrics["erdos/valid_count"] = float(batch_valid) + metrics["erdos/valid_rate"] = float(batch_valid / max(1, len(batch_rewards))) + metrics["erdos/batch_size"] = float(len(metadata)) + + if batch_c5s: + metrics["erdos/best_c5"] = float(min(batch_c5s)) + metrics["erdos/mean_c5"] = float(sum(batch_c5s) / len(batch_c5s)) + metrics["erdos/worst_c5"] = float(max(batch_c5s)) + + metrics["erdos/global_best_c5"] = float(self.best_c5) if self.best_c5 < float("inf") else 0.0 + metrics["erdos/global_valid_total"] = float(self.total_valid) + + # Print summary to driver log + max_r = metrics["erdos/max_reward"] + avg_r = metrics["erdos/avg_reward"] + best = metrics.get("erdos/best_c5", "n/a") + print(f" 🎯 Erdős: avg_reward={avg_r:.4f} max_reward={max_r:.4f} " + f"valid={batch_valid}/{len(metadata)} " + f"best_c5={best}") + + try: + for k, v in 
self.sampler.get_sample_stats().items(): + metrics[f"erdos/{k}"] = float(v) + except Exception: + pass + + return metadata, metrics diff --git a/nemo_rl/environments/erdos_ref_puct_sampler.py b/nemo_rl/environments/erdos_ref_puct_sampler.py new file mode 100644 index 0000000000..c04b7152e4 --- /dev/null +++ b/nemo_rl/environments/erdos_ref_puct_sampler.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# PUCT sampler and Erdős State mirror ttt-discover-ref: +# ttt_discover/tinker_utils/sampler.py (PUCTSampler) +# ttt_discover/tinker_utils/state.py (State) +# examples/erdos_min_overlap/env.py (value = -C₅, minimize) + +from __future__ import annotations + +import json +import os +import threading +import time +import uuid +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +import numpy as np + + +def to_json_serializable(obj: Any) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, dict): + return {k: to_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [to_json_serializable(v) for v in obj] + return obj + + +@contextmanager +def _file_lock(lock_path: str, *, poll_s: float = 0.05, stale_s: float = 600.0): + while True: + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + try: + os.write(fd, f"{os.getpid()}\n{time.time()}\n".encode("utf-8")) + finally: + os.close(fd) + break + except FileExistsError: + try: + st = os.stat(lock_path) + if (time.time() - st.st_mtime) > stale_s: + os.remove(lock_path) + continue + except FileNotFoundError: + continue + time.sleep(poll_s) + try: + yield + finally: + try: + os.remove(lock_path) + except FileNotFoundError: + pass + + +def _atomic_write_json(path: str, obj: Any) -> None: + 
os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + tmp_path = f"{path}.tmp.{os.getpid()}" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(to_json_serializable(obj), f, indent=2, sort_keys=True) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + + +def _read_json_or_default(path: str, default: Any) -> Any: + if not os.path.exists(path): + return default + try: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError: + return default + + +def _sampler_file_for_step(base_path: str, step: int) -> str: + base_name = base_path.replace(".json", "") + return f"{base_name}_step_{step:06d}.json" + + +@dataclass +class ErdosRefState: + """Mirrors ttt_discover State for Erdős (value = -C₅, higher is better).""" + + timestep: int + construction: list + code: str = "" + value: Optional[float] = None + parent_values: list[float] = field(default_factory=list) + parents: list[dict] = field(default_factory=list) + observation: str = "" + id: str = field(default_factory=lambda: str(uuid.uuid4())) + + def to_dict(self) -> dict: + return { + "type": "ErdosRefState", + "id": self.id, + "timestep": self.timestep, + "value": self.value, + "parent_values": list(self.parent_values), + "parents": list(self.parents), + "observation": self.observation, + "construction": to_json_serializable(self.construction), + "code": self.code, + } + + @classmethod + def from_dict(cls, d: dict) -> ErdosRefState: + return cls( + timestep=int(d["timestep"]), + construction=list(d.get("construction") or []), + code=str(d.get("code") or ""), + value=d.get("value"), + parent_values=list(d.get("parent_values") or []), + parents=list(d.get("parents") or []), + observation=str(d.get("observation") or ""), + id=str(d.get("id") or uuid.uuid4()), + ) + + +def erdos_ref_state_to_prompt_state(s: ErdosRefState) -> dict[str, Any]: + """Map to keys consumed by build_erdos_question / state_to_prompt.""" + c5 = -float(s.value) if s.value is not 
def __init__(
    self,
    file_path: str,
    init_state_fn: Callable[[], ErdosRefState],
    max_buffer_size: int = 1000,
    batch_size: int = 1,
    resume_step: Optional[int] = None,
    puct_c: float = 1.0,
    topk_children: int = 2,
    max_construction_len: Optional[int] = 1000,
):
    """Initialize the PUCT sampler and seed (or resume) its state buffer.

    Args:
        file_path: Base path for per-step JSON snapshots (see
            ``_sampler_file_for_step``).
        init_state_fn: Factory producing a fresh initial ``ErdosRefState``.
        max_buffer_size: Cap on stored states enforced at save time.
        batch_size: Number of initial states seeded when starting cold.
        resume_step: If given, load the snapshot saved for that step instead
            of seeding fresh states.
        puct_c: Exploration constant ``c`` in the PUCT score.
        topk_children: Max children kept per parent when updating (0 keeps all).
        max_construction_len: Children with longer constructions are dropped.
    """
    self.file_path = file_path
    self._init_state_fn = init_state_fn
    self.max_buffer_size = max_buffer_size
    self.batch_size = batch_size
    self.topk_children = topk_children
    self.puct_c = float(puct_c)
    self.max_construction_len = max_construction_len

    # Buffer of all states plus the seed states; "last sampled" bookkeeping
    # feeds get_sample_stats().
    self._states: list[ErdosRefState] = []
    self._initial_states: list[ErdosRefState] = []
    self._last_sampled_states: list[ErdosRefState] = []
    self._last_sampled_indices: list[int] = []
    self._lock = threading.Lock()
    self._current_step = resume_step if resume_step is not None else 0

    # PUCT bookkeeping: per-state visit count n, running max child value m,
    # and total visit count T (drives the sqrt(1+T) exploration term).
    self._n: dict[str, int] = {}
    self._m: dict[str, float] = {}
    self._T: int = 0
    self._last_scale: float = 1.0
    self._last_puct_stats: list[tuple[int, float, float, float, float]] = []

    # Resume from a snapshot when requested; if that yields no states
    # (or no resume was requested), seed batch_size fresh initial states
    # and persist them immediately.
    if resume_step is not None:
        self._load(resume_step)
    if not self._states:
        for _ in range(batch_size):
            state = init_state_fn()
            self._initial_states.append(state)
            self._states.append(state)
        self._save(self._current_step)
[parent.value] + parent.parent_values if parent.value is not None else [] + ) + child.parents = [{"id": parent.id, "timestep": parent.timestep}] + parent.parents + + @staticmethod + def _filter_topk_per_parent( + states: list[ErdosRefState], + parent_states: list[ErdosRefState], + k: int, + ) -> tuple[list[ErdosRefState], list[ErdosRefState]]: + if not states: + return [], [] + if k == 0: + return states, parent_states + parent_to_children: dict[str, list[tuple[ErdosRefState, ErdosRefState]]] = {} + for child, parent in zip(states, parent_states): + pid = parent.id + parent_to_children.setdefault(pid, []).append((child, parent)) + topk_children, topk_parents = [], [] + for children_and_parents in parent_to_children.values(): + sorted_pairs = sorted( + children_and_parents, + key=lambda x: x[0].value if x[0].value is not None else float("-inf"), + reverse=True, + ) + for child, parent in sorted_pairs[:k]: + topk_children.append(child) + topk_parents.append(parent) + return topk_children, topk_parents + + def _load(self, step: int) -> None: + file_path = _sampler_file_for_step(self.file_path, step) + if not os.path.exists(file_path): + raise FileNotFoundError( + f"Cannot resume from step {step}: sampler file not found: {file_path}" + ) + with _file_lock(f"{file_path}.lock"): + store = _read_json_or_default(file_path, default=None) + if store is None: + raise ValueError(f"Failed to load sampler state from {file_path}") + self._states = [ErdosRefState.from_dict(s) for s in store.get("states", [])] + self._initial_states = [ + ErdosRefState.from_dict(s) for s in store.get("initial_states", []) + ] + self._n = store.get("puct_n", {}) or {} + self._m = store.get("puct_m", {}) or {} + self._T = int(store.get("puct_T", 0) or 0) + + def _save(self, step: int) -> None: + save_path = _sampler_file_for_step(self.file_path, step) + store = { + "step": step, + "states": [s.to_dict() for s in self._states], + "initial_states": [s.to_dict() for s in self._initial_states], + 
"puct_n": self._n, + "puct_m": self._m, + "puct_T": self._T, + } + with _file_lock(f"{save_path}.lock"): + _atomic_write_json(save_path, store) + + def _get_construction_key(self, state: ErdosRefState): + if state.construction: + return tuple(state.construction) + if state.code: + return state.code + return None + + def _compute_scale( + self, values: np.ndarray, mask: Optional[np.ndarray] = None + ) -> float: + if values.size == 0: + return 1.0 + v = values[mask] if mask is not None else values + return float(max(np.max(v) - np.min(v), 1e-6)) if v.size > 0 else 1.0 + + def _compute_prior(self, values: np.ndarray, scale: float) -> np.ndarray: + del scale # matches ref signature + if values.size == 0: + return np.array([]) + n = len(values) + ranks = np.argsort(np.argsort(-values)) + weights = (n - ranks).astype(np.float64) + return weights / weights.sum() + + def _get_lineage(self, state: ErdosRefState) -> set[str]: + lineage = {state.id} + for p in state.parents or []: + if p.get("id"): + lineage.add(str(p["id"])) + return lineage + + def _build_children_map(self) -> dict[str, set[str]]: + children: dict[str, set[str]] = {} + for s in self._states: + for p in s.parents or []: + pid = p.get("id") + if pid: + children.setdefault(str(pid), set()).add(s.id) + return children + + def _get_full_lineage( + self, state: ErdosRefState, children_map: dict[str, set[str]] + ) -> set[str]: + lineage = self._get_lineage(state) + queue = [state.id] + visited = {state.id} + while queue: + sid = queue.pop(0) + for child_id in children_map.get(sid, []): + if child_id not in visited: + visited.add(child_id) + lineage.add(child_id) + queue.append(child_id) + return lineage + + def sample_states(self, num_states: int) -> list[ErdosRefState]: + initial_ids = {s.id for s in self._initial_states} + candidates = list(self._states) + + if not candidates: + picked = [self._init_state_fn() for _ in range(num_states)] + self._last_sampled_states = picked + self._last_sampled_indices = [] + 
self._last_puct_stats = [(0, 0.0, 0.0, 0.0, 0.0) for _ in picked] + return picked + + vals = np.array( + [float(s.value if s.value is not None else float("-inf")) for s in candidates] + ) + non_initial_mask = np.array([s.id not in initial_ids for s in candidates]) + scale = self._compute_scale( + vals, non_initial_mask if non_initial_mask.any() else None + ) + self._last_scale = scale + p = self._compute_prior(vals, scale) + sqrt_t = np.sqrt(1.0 + self._T) + + scores = [] + for i, s in enumerate(candidates): + n = self._n.get(s.id, 0) + m = self._m.get(s.id, vals[i]) + q = m if n > 0 else vals[i] + bonus = self.puct_c * scale * p[i] * sqrt_t / (1.0 + n) + score = q + bonus + scores.append((score, vals[i], s, n, q, p[i], bonus)) + + scores.sort(key=lambda x: (x[0], x[1]), reverse=True) + + if num_states > 1: + children_map = self._build_children_map() + picked, top_scores = [], [] + blocked_ids: set[str] = set() + for entry in scores: + s = entry[2] + if s.id in blocked_ids: + continue + picked.append(s) + top_scores.append(entry) + blocked_ids.update(self._get_full_lineage(s, children_map)) + if len(picked) >= num_states: + break + else: + top_scores = scores[:num_states] + picked = [t[2] for t in top_scores] + + state_id_to_idx = {s.id: i for i, s in enumerate(self._states)} + self._last_sampled_states = picked + self._last_sampled_indices = [state_id_to_idx.get(s.id, -1) for s in picked] + self._last_puct_stats = [(t[3], t[4], t[5], t[6], t[0]) for t in top_scores] + + # Erdős: no construction_length_limits — no refresh (ref no-op for this env) + return picked + + def update_states( + self, + states: list[ErdosRefState], + parent_states: list[ErdosRefState], + save: bool = True, + step: Optional[int] = None, + ) -> None: + if not states: + return + assert len(states) == len(parent_states) + + parent_max: dict[str, float] = {} + parent_obj: dict[str, ErdosRefState] = {} + for child, parent in zip(states, parent_states): + if child.value is None: + continue + pid = 
parent.id + parent_obj[pid] = parent + parent_max[pid] = max(parent_max.get(pid, float("-inf")), float(child.value)) + + for pid, y in parent_max.items(): + self._m[pid] = max(self._m.get(pid, y), y) + parent = parent_obj[pid] + anc_ids = [pid] + [ + str(p["id"]) for p in (parent.parents or []) if p.get("id") + ] + for aid in anc_ids: + self._n[aid] = self._n.get(aid, 0) + 1 + self._T += 1 + + states, parent_states = self._filter_topk_per_parent( + states, parent_states, self.topk_children + ) + existing = {self._get_construction_key(s) for s in self._states} + existing.discard(None) + + new_states = [] + for child, parent in zip(states, parent_states): + if child.value is None: + continue + if ( + self.max_construction_len is not None + and child.construction + and len(child.construction) > self.max_construction_len + ): + continue + key = self._get_construction_key(child) + if key is not None and key in existing: + continue + self._set_parent_info(child, parent) + new_states.append(child) + if key is not None: + existing.add(key) + + if not new_states: + return + with self._lock: + self._states.extend(new_states) + if save: + self._finalize_and_save(step) + + def _finalize_and_save(self, step: Optional[int] = None) -> None: + if len(self._states) > self.max_buffer_size: + actual_values = [ + s.value if s.value is not None else float("-inf") for s in self._states + ] + by_actual = list(np.argsort(actual_values)[::-1]) + initial_ids = {s.id for s in self._initial_states} + initial_indices = {i for i, s in enumerate(self._states) if s.id in initial_ids} + keep = set(initial_indices) + for i in by_actual: + if len(keep) >= self.max_buffer_size: + break + keep.add(i) + self._states = [self._states[i] for i in sorted(keep)] + if step is not None: + self._current_step = step + self._save(self._current_step) + + def flush(self, step: Optional[int] = None) -> None: + with self._lock: + if self.topk_children > 0: + by_parent: dict[str, list[ErdosRefState]] = {} + 
def record_failed_rollout(self, parent: ErdosRefState) -> None:
    """Charge one PUCT visit to *parent* and every recorded ancestor.

    No value (``m``) update happens — the rollout produced nothing scoreable —
    but visit counts and the global counter ``T`` still advance so the
    lineage's exploration bonus decays.
    """
    ancestry = [parent.id]
    for entry in parent.parents or []:
        if entry.get("id"):
            ancestry.append(str(entry["id"]))
    for node_id in ancestry:
        self._n[node_id] = self._n.get(node_id, 0) + 1
    self._T += 1
stats.update(_stats(sampled_timesteps, "puct/sampled_timestep")) + stats.update(_stats(sampled_constr_lens, "puct/sampled_construction_len")) + return stats diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index df82c7d1af..a1bc6ace3f 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -50,6 +50,9 @@ class EnvRegistryEntry(TypedDict, total=False): "vlm": { "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", }, + "erdos_discovery": { + "actor_class_fqn": "nemo_rl.environments.erdos_discovery_environment.ErdosDiscoveryEnvironment", + }, "nemo_gym": { "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", }, diff --git a/pyproject.toml b/pyproject.toml index 17830c391c..22bc8a1ba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,11 @@ dependencies = [ "swanlab", "pyzmq", "decord2", - "nccl4py", # for non-colocated refit "cuda-bindings", # for non-colocated refit "pybase64", # for sglang refit "nvidia-cudnn-cu12==9.19.0.56", # for transformer-engine no build isolation + "sentencepiece>=0.2.1", ] [project.optional-dependencies] diff --git a/ray.sub b/ray.sub index e6e3e07af7..3a9b367599 100644 --- a/ray.sub +++ b/ray.sub @@ -32,7 +32,8 @@ maybe_gres_arg() { # Assumes a homogeneous allocation (not a heterogeneous job) if sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep -q "gpu:"; then # Do a quick assert here that gpus:8 == gpus:$GPUS_PER_NODE. 
It is probably a user error if someone isn't using GPUS_PER_NODE=8 on our clusters if it supports --gres=gpu:8 or gpu:a100:8 - if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | awk -F: '{print $NF}') ]]; then + # Note: cut -d'(' -f1 removes the socket spec like "(S:0-3)" that some clusters append + if [[ $GPUS_PER_NODE -ne $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:" | cut -d'(' -f1 | awk -F: '{print $NF}') ]]; then echo "Error: GPUS_PER_NODE=$GPUS_PER_NODE but GRES detected is $(sinfo -p $SLURM_JOB_PARTITION -h -o "%G" | grep "gpu:") meaning GPUS_PER_NODE is not set to fully claim the GPUs on the nodes." >&2 exit 1 fi @@ -47,9 +48,11 @@ maybe_gres_arg() { ######################################################## # User defined variables ######################################################## +export OMP_NUM_THREADS=16 CONTAINER=$CONTAINER MOUNTS=$MOUNTS COMMAND=${COMMAND:-} # This is a script relative to the SLURM_SUBMIT_DIR. If left empty, it will leave the cluster idle after it's brought up. +SETUP_COMMAND=${SETUP_COMMAND:-} # Setup commands to run on all nodes before starting Ray ######################################################## # Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} @@ -59,8 +62,9 @@ DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} # Ports for the head node -PORT=${PORT:-54514} -RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} +# GCS head port and client server port -- kept below ephemeral range (32768-60999). +PORT=${PORT:-9900} +RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-9901} #REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? 
DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} @@ -84,16 +88,36 @@ elif [[ $(ulimit -Hn) != "unlimited" ]] && [[ $(ulimit -Hn) -lt 65535 ]]; then echo "[WARNING]: Cannot increase ulimit on file descriptors to 65535 according ray recommendation: https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html. Speak to cluster admins to increase, otherwise ray may crash unexpectedly." fi -# On our clusters, the largest port range on an idle worker appeared between 52369-64607 -# (not including the other ports set by this script). So this range is chosen to be -# somewhere in the middle -MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} -MAX_WORKER_PORT=${MAX_WORKER_PORT:-54513} +# Worker port range must NOT overlap with the OS ephemeral range (32768-60999) +# to prevent TOCTOU collisions. Ray's Raylet uses a probe-and-release pattern +# (CheckPortFree) that can leave a window where the port is unguarded. +# If the range overlaps with the ephemeral range, the kernel can assign the same +# port as an ephemeral source port for outgoing TCP traffic during that window, +# causing EADDRINUSE when the worker tries to bind. 
+# +# Port layout (all below ephemeral range): +# 10002-11000 Ray worker gRPC (this setting) +# 11001-15000 NeMo RL HTTP servers / TCPStore (policy.generation.port_range_low/high) +# 15001-20000 NeMo Gym HTTP servers (Gym global config port_range_low/high) +MIN_WORKER_PORT=${MIN_WORKER_PORT:-10002} +MAX_WORKER_PORT=${MAX_WORKER_PORT:-11000} ######################################################## # Number seconds to sync logs from /tmp/ray/session_*/logs to $LOG_DIR/ray/ RAY_LOG_SYNC_FREQUENCY=${RAY_LOG_SYNC_FREQUENCY:-} ######################################################## +# NCCL IB config for multi-node training +export NCCL_BUFFSIZE=33554432 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export NCCL_IB_AR_THRESHOLD=0 +export NCCL_IB_PCI_RELAXED_ORDERING=1 +export NCCL_IB_QPS_PER_CONNECTION=2 +export NCCL_IB_SPLIT_DATA_ON_QPS=0 +export NCCL_IGNORE_CPU_AFFINITY=1 +export NCCL_IB_HCA=mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1,mlx5_10:1,mlx5_13:1,mlx5_14:1,mlx5_15:1 +export NCCL_SOCKET_IFNAME=bond0 +export UCX_NET_DEVICES=bond0 + # Unset UV_CACHE_DIR to avoid local cache directory interferring with the container cache unset UV_CACHE_DIR @@ -111,6 +135,14 @@ BASE_LOG_DIR=${BASE_LOG_DIR:-$SLURM_SUBMIT_DIR} LOG_DIR="$BASE_LOG_DIR/$SLURM_JOB_ID-logs" mkdir -p $LOG_DIR +# Write SETUP_COMMAND to a file to avoid heredoc escaping issues +SETUP_COMMAND_FILE="" +if [[ -n "$SETUP_COMMAND" ]]; then + SETUP_COMMAND_FILE="$LOG_DIR/setup_command.sh" + echo "$SETUP_COMMAND" > "$SETUP_COMMAND_FILE" + chmod +x "$SETUP_COMMAND_FILE" +fi + # Number of GPUs per worker node GPUS_PER_NODE=${GPUS_PER_NODE:-8} @@ -123,8 +155,9 @@ else fi COMMON_SRUN_ARGS="$GRES_ARG" -COMMON_SRUN_ARGS+=" --no-container-mount-home" -COMMON_SRUN_ARGS+=" --mpi=pmix" +# COMMON_SRUN_ARGS+=" --no-container-mount-home" # Disabled: need shared home for model conversion cache +COMMON_SRUN_ARGS+=" --container-writable" +COMMON_SRUN_ARGS+=" --mpi=pmi2" COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS" COMMON_SRUN_ARGS+=" 
--container-image=$CONTAINER" COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" @@ -132,7 +165,7 @@ COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION" COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT" # Number of CPUs per worker node -CPUS_PER_WORKER=${CPUS_PER_WORKER:-$((GPUS_PER_NODE * 16))} +CPUS_PER_WORKER=${CPUS_PER_WORKER:-112} num_retries=3 @@ -295,6 +328,10 @@ chmod +x /launch-head.sh count=0 while [[ \$count -lt $num_retries ]]; do + if [[ -n "$SETUP_COMMAND_FILE" ]] && [[ -f "$SETUP_COMMAND_FILE" ]]; then + echo "[INFO] Running setup command from $SETUP_COMMAND_FILE..." + bash "$SETUP_COMMAND_FILE" + fi bash /launch-head.sh count=\$((count+1)) echo "Head node failed \$count/$num_retries times, restarting in 5 seconds..." @@ -304,9 +341,10 @@ touch $LOG_DIR/ENDED exit 1 EOF ) -srun $COMMON_SRUN_ARGS --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" & +srun $COMMON_SRUN_ARGS --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" & SRUN_PIDS["ray-head"]=$! +sleep 120s NUM_ACTORS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES)) # Start Ray worker nodes @@ -394,6 +432,10 @@ EOFINNER count=0 while [[ \$count -lt $num_retries ]]; do + if [[ -n "$SETUP_COMMAND_FILE" ]] && [[ -f "$SETUP_COMMAND_FILE" ]]; then + echo "[INFO] Running setup command from $SETUP_COMMAND_FILE..." + bash "$SETUP_COMMAND_FILE" + fi bash /launch-worker.sh count=\$((count+1)) echo "Worker failed \$count/$num_retries times, restarting in 5 seconds..." 
@@ -403,7 +445,7 @@ touch $LOG_DIR/ENDED exit 1 EOF ) - srun $COMMON_SRUN_ARGS --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/ray-worker-$i.log bash -x -c "$worker_cmd" & + srun $COMMON_SRUN_ARGS --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/ray-worker-$i.log bash -x -c "$worker_cmd" & SRUN_PIDS["ray-worker-$i"]=$! sleep 3 done @@ -414,11 +456,60 @@ while check_srun_processes && ! srun --overlap --nodes=1 --ntasks=1 -w $head_nod sleep 2 done +######################################################## +# Run sandbox in parallel on the head node (overlap) +######################################################## +export SLURM_MASTER_NODE=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1) + +# Check if SANDBOX_CONTAINER and SANDBOX_COMMAND are defined +if [[ -n "${SANDBOX_CONTAINER:-}" ]] && [[ -n "${SANDBOX_COMMAND:-}" ]]; then + SANDBOX_PORTS_DIR="$LOG_DIR/sandbox" + mkdir -p "$SANDBOX_PORTS_DIR" + echo "[INFO] Starting sandbox on head node in parallel (ports_dir=$SANDBOX_PORTS_DIR)..." + srun --output "$LOG_DIR/sandbox.log" \ + --error "$LOG_DIR/sandbox.log" \ + --container-image="$SANDBOX_CONTAINER" \ + --container-mounts="$SANDBOX_PORTS_DIR:$SANDBOX_PORTS_DIR" \ + --no-container-mount-home \ + --mpi=pmi2 \ + -A "$SLURM_JOB_ACCOUNT" \ + -p "$SLURM_JOB_PARTITION" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + --overlap \ + --nodes="$SLURM_JOB_NUM_NODES" \ + --ntasks-per-node=1 \ + -w "$head_node" \ + --export=ALL,SANDBOX_PORTS_DIR=$SANDBOX_PORTS_DIR \ + bash -c "$SANDBOX_COMMAND" & + SRUN_PIDS["sandbox"]=$! 
+ echo "[INFO] Sandbox started in background (PID: ${SRUN_PIDS["sandbox"]})" +else + echo "[INFO] SANDBOX_CONTAINER or SANDBOX_COMMAND not defined, skipping sandbox startup" +fi + # At this stage the Ray cluster bringup has started on the physical nodes in the allocation # Before we launch a job on this cluster we need to make sure that the bringup is complete # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS + +# Helper function to get container PID using enroot +# Since we use -w to target specific nodes and only run one container per node, +# we can just grab the first (and only) container PID on that node +get_container_pid() { + local node=$1 + srun --overlap --nodes=1 -w "$node" bash -c "enroot list -f | awk 'NR>1 && \$2 ~ /^[0-9]+\$/ {print \$2; exit}'" +} + extract_worker_units() { - status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status) + # Get the container PID for ray-head + head_container_pid=$(get_container_pid "$head_node") + if [[ -z "$head_container_pid" ]]; then + echo 0 + return + fi + + # Execute ray status inside the container using enroot exec + status_output=$(srun --overlap --nodes=1 -w "$head_node" enroot exec "$head_container_pid" ray status) if echo "$status_output" | grep -q "worker_units"; then worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. ]' '{print $4}') echo $worker_units @@ -448,20 +539,30 @@ echo "All workers connected!" 
# This driver process is responsible for launching a job on the Ray cluster CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID | grep -oP 'WorkDir=\K[^ ]+' | head -1) if [[ -n "$COMMAND" ]]; then - srun --no-container-mount-home --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND" + # Get container PID and execute command inside it + head_container_pid=$(get_container_pid "$head_node") + srun --overlap --nodes=1 -w "$head_node" -o $LOG_DIR/ray-driver.log enroot exec "$head_container_pid" bash -c "cd $CONTAINER_CWD && $COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh # No args launches on the head node (node 0) # Args 1-N launch on worker nodes (nodes 1 through N-1) # Optional: set COMMAND='...' to run non-interactively instead of opening an interactive shell + +# Helper to get container PID +get_container_pid() { + local node=\$1 + srun --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID bash -c "enroot list -f | awk 'NR>1 && \\\$2 ~ /^[0-9]+\\\$/ {print \\\$2; exit}'" +} + WORKER_NUM=\${1:-} if [[ -z "\$WORKER_NUM" ]]; then # Empty means we are on the head node + HEAD_PID=\$(get_container_pid "$head_node") if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "$head_node" --jobid $SLURM_JOB_ID enroot exec "\$HEAD_PID" bash -c "cd $CONTAINER_CWD && \$COMMAND" else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w 
"$head_node" --jobid $SLURM_JOB_ID --pty bash + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty enroot exec "\$HEAD_PID" bash -c "cd $CONTAINER_CWD && exec bash" fi else # Worker numbers 1 through N-1 correspond to ray-worker-1 through ray-worker-(N-1) @@ -471,10 +572,12 @@ else exit 1 fi nodes_array=($nodes) + node="\${nodes_array[\$WORKER_NUM]}" + WORKER_PID=\$(get_container_pid "\$node") if [[ -n "\${COMMAND:-}" ]]; then - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID bash -c "\$COMMAND" + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID enroot exec "\$WORKER_PID" bash -c "cd $CONTAINER_CWD && \$COMMAND" else - srun --no-container-mount-home $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash + srun $GRES_ARG -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --nodes=1 -w "\$node" --jobid $SLURM_JOB_ID --pty enroot exec "\$WORKER_PID" bash -c "cd $CONTAINER_CWD && exec bash" fi fi EOF @@ -485,3 +588,4 @@ EOF echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh 2 # to attach to worker 2, etc."
sleep infinity fi + diff --git a/uv.lock b/uv.lock index 27ddd02b2a..e110fa526d 100644 --- a/uv.lock +++ b/uv.lock @@ -4482,6 +4482,7 @@ dependencies = [ { name = "pyzmq" }, { name = "ray", extra = ["default"] }, { name = "rich" }, + { name = "sentencepiece" }, { name = "setuptools" }, { name = "swanlab" }, { name = "sympy" }, @@ -4628,6 +4629,7 @@ requires-dist = [ { name = "pyzmq" }, { name = "ray", extras = ["default"], specifier = "==2.54.0" }, { name = "rich" }, + { name = "sentencepiece", specifier = ">=0.2.1" }, { name = "setuptools" }, { name = "sgl-kernel", marker = "extra == 'sglang'", git = "https://github.com/JustinTong0323/sglang.git?subdirectory=sgl-kernel&rev=70aa688742dd2b75bf9e8e980249303f39295b0d" }, { name = "sglang", marker = "extra == 'sglang'", git = "https://github.com/JustinTong0323/sglang.git?subdirectory=python&rev=70aa688742dd2b75bf9e8e980249303f39295b0d" },