diff --git a/README.md b/README.md index abd0980..50f92a7 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ python LLM_Collaboration_with_MARL/train_grpo.py \ # Multi-turn override example python LLM_Collaboration_with_MARL/train_magrpo.py \ --config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \ - --override dataset.train_split='test[:20]' dataset.eval_split='test[20:30]' \ + --override dataset.train_split='test[16:]' dataset.eval_split='test[:16]' \ magrpo.num_turns=2 magrpo.turn_gradient_weights=[1.5,0.5] ``` ### Legacy Command-Line Args @@ -84,13 +84,12 @@ python LLM_Collaboration_with_MARL/train_magrpo.py \ ### External Modes -Multi-turn training supports external transition modes for 2nd+ turns, set via `magrpo.external_mode`: +Multi-turn training supports external transition modes for 2nd+ turns, set via `external.mode`: -- `expert_edits` **(default)**: Uses an expert LLM to suggest edits. - - Requires `magrpo.expert_model` in config (e.g., `deepseek-coder`, Claude, etc.). - - Requires corrsponding API keys in env vars. +- `level_feedback` **(default)**: Detailed diagnostics (impl found, syntax with line/col, per-test pass/fail errors, aux usage). + - Requires `external.expert_model` in config when using `expert_edits` (e.g., `deepseek-coder`, Claude, etc.). This parameter is ignored for other modes (`level_feedback`, `level_passed`, `passed`, `plain`). + - Requires corresponding API keys in env vars when using `expert_edits`. - `level_passed`: Binary passed signals (impl found, syntax, tests summary, aux usage). -- `level_feedback`: Detailed diagnostics (impl found, syntax with line/col, per-test pass/fail errors, aux usage). - `passed`: A binary signal — "All levels passed" or "Not all levels passed". - `plain`: No signals or diagnostics. @@ -98,28 +97,27 @@ Multi-turn training supports external transition modes for 2nd+ turns, set via ` # HumanEval with detailed feedback signals python LLM_Collaboration_with_MARL/train_magrpo.py \ --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \ - --override magrpo.external_mode='level_feedback' + --override external.mode='level_feedback' ``` ### Sandbox Tests -The external modes obtain `entry_point` and tests via an internal resolver registered by the training script. **By default, the sandbox tests are the same as the dataset’s eval tests.** -Note: `magrpo.sandbox_slice` only affects analysis-based modes (`level_feedback`, `level_passed`, `passed`), and it has no effect on `expert_edits`. +The external modes obtain `entry_point` and tests via an internal resolver registered by the training script. **By default, the sandbox executes only the first assert (`sandbox_slice=1`).** Use all eval tests by setting `external.sandbox_slice` to `0`, `None`, or `'all'`. A negative value uses the last N asserts. Note: `external.sandbox_slice` only affects analysis-based modes (`level_feedback`, `level_passed`, `passed`), and it has no effect on `expert_edits`. ```bash -# Add a magrpo.sandbox_slice to override +# Add an external.sandbox_slice override python LLM_Collaboration_with_MARL/train_magrpo.py \ --config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \ - --override magrpo.external_mode='level_feedback' magrpo.sandbox_slice=-2 + --override external.mode='level_feedback' external.sandbox_slice=-2 ``` ### Handoff Strategy -In MAGRPO, since agents generate a few responses per turn, we need to hand off one for efficiency, else the number of generations per turn will increase exponentially.
External handoff controls which previous response is used as context for the later turns. **By default, the "best" completion per agent from the prior turn is used.** Random handoff requires the training loop to supply a candidate pool of previous-turn completions per agent to the external transition. If only a single completion per agent is available, random falls back to the best completion. +In MAGRPO/GRPO multi-turn training, we hand off one prior completion per agent to keep compute bounded. The trainer selects that completion according to the `handoff` mode: **`random` (default)** or `best`. Selection happens in the CoMLRL trainer; the external modes simply format the next-turn prompts from the completions they receive. Configure it via `magrpo.handoff` or `grpo.handoff` in your config or with `--override`. ```bash python LLM_Collaboration_with_MARL/train_magrpo.py \ --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \ - --override magrpo.external_mode='plain' magrpo.external_handoff='random' + --override external.mode='plain' magrpo.handoff='best' ``` diff --git a/configs/grpo_che_config.yaml b/configs/grpo_che_config.yaml index b78ee0f..8c5c4a6 100644 --- a/configs/grpo_che_config.yaml +++ b/configs/grpo_che_config.yaml @@ -1,6 +1,4 @@ -# Configuration for CoopHumanEval single-agent training with GRPO - -# Model configuration +# model model: name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" @@ -13,21 +11,28 @@ model: trust_remote_code: true torch_dtype: "auto" -# Dataset configuration +# dataset dataset: - name: "CoMLRL/CoopHumaneval" - type: "coophumaneval" # Used to select formatters and reward function - train_split: "test[:50]" - eval_split: "test[50:66]" + name: "CoMLRL/CoopHumanEval" + type: "coophumaneval" + train_split: "test[16:]" + eval_split: "test[:16]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/tchen19/output" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# GRPO training configuration +# grpo grpo: - num_train_epochs: 20 # Same as multi-agent CHE + num_train_epochs: 16 per_device_train_batch_size: 1 learning_rate: 1.0e-5 logging_steps: 50 @@ -36,13 +41,13 @@ grpo: max_new_tokens: 256 temperature: 0.8 top_p: 0.95 - # Early termination threshold for single-agent (GRPO) + handoff: random early_termination_threshold: 2.1 -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "grpo_coophumaneval" # Will be appended with model name in script - dir: "../../../projects/bevi/sliu30" + name: "grpo_coophumaneval" + dir: "../../../work/hdd/bepg/sliu30/output_st_grpo" tags: ["grpo", "coophumaneval", "single-agent"] diff --git a/configs/grpo_he_config.yaml b/configs/grpo_he_config.yaml index 47d991e..f29865a 100644 --- a/configs/grpo_he_config.yaml +++ b/configs/grpo_he_config.yaml @@ -1,7 +1,4 @@ -# Configuration for HumanEval single-agent training with GRPO -# Based on train_he_single_agent.py parameters - -# Model configuration +# model model: name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" @@ -14,36 +11,43 @@ model: trust_remote_code: true torch_dtype: "auto" -# Dataset configuration +# dataset dataset: name: "openai/openai_humaneval" - type: "humaneval" # Used to select formatters and reward function - train_split: "test[33:133]" + type: "humaneval" + train_split: "test[33:163]" eval_split: "test[:32]" -# Output configuration +# output output: - base_dir:
"../../../projects/bepg/tchen19/output" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# GRPO training configuration +# grpo grpo: - num_train_epochs: 10 + num_train_epochs: 8 per_device_train_batch_size: 1 learning_rate: 1.0e-5 logging_steps: 50 save_steps: 200 - num_generations: 4 # Number of completions to generate per prompt + num_generations: 4 max_new_tokens: 256 temperature: 0.8 top_p: 0.95 - # Early termination threshold for single-agent (GRPO) + handoff: random early_termination_threshold: 2.1 -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "grpo_humaneval" # Will be appended with model name in script - dir: "../../../projects/bepg/sliu30" + name: "grpo_humaneval" + dir: "../../../work/hdd/bepg/sliu30/output_st_grpo" tags: ["grpo", "humaneval", "single-agent"] diff --git a/configs/magrpo_che_config.yaml b/configs/magrpo_che_config.yaml index 09d91d9..526b5a7 100644 --- a/configs/magrpo_che_config.yaml +++ b/configs/magrpo_che_config.yaml @@ -1,7 +1,4 @@ -# Configuration for CoopHumanEval training with MAGRPO -# Exact parameters from train_che.py - -# Model configuration +# model model: name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" @@ -14,21 +11,28 @@ model: trust_remote_code: true torch_dtype: "auto" -# Dataset configuration +# dataset dataset: - name: "LovelyBuggies/CoopHumaneval" - type: "coophumaneval" # Used to select formatters and reward function - train_split: "test[:50]" - eval_split: "test[50:66]" + name: "CoMLRL/CoopHumanEval" + type: "coophumaneval" + train_split: "test[16:]" + eval_split: "test[:16]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/sliu30/output" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# MAGRPO training configuration +# magrpo magrpo: - num_train_epochs: 20 # Exact value from train_che.py + num_train_epochs: 16 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -38,13 +42,13 @@ magrpo: temperature: 0.8 top_p: 0.95 num_agents: 2 - # Early termination threshold for multi-agent (MAGRPO) + handoff: random early_termination_threshold: 4.0 -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "magrpo_coophumaneval" # Will be appended with model name in script - dir: "../../../projects/bevi/sliu30" + name: "magrpo_coophumaneval" + dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo" tags: ["magrpo", "coophumaneval", "multi-agent"] diff --git a/configs/magrpo_he_config.yaml b/configs/magrpo_he_config.yaml index 0d450b6..4a48715 100644 --- a/configs/magrpo_he_config.yaml +++ b/configs/magrpo_he_config.yaml @@ -1,9 +1,6 @@ -# Configuration for HumanEval training with MAGRPO -# This file defines all parameters for training experiments - -# Model configuration +# model model: - name: "Qwen/Qwen2.5-Coder-3B" # Options: "Qwen/Qwen2.5-Coder-3B", "bigcode/starcoder2-3b", etc. 
+ name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" temperature: 0.7 top_p: 0.9 @@ -14,21 +11,28 @@ model: trust_remote_code: true torch_dtype: "auto" -# Dataset configuration +# dataset dataset: name: "openai/openai_humaneval" - type: "humaneval" # Used to select formatters and reward function - train_split: "test[33:133]" + type: "humaneval" + train_split: "test[33:163]" eval_split: "test[:32]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/sliu30/output" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# MAGRPO training configuration +# magrpo magrpo: - num_train_epochs: 10 + num_train_epochs: 8 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -36,13 +40,13 @@ magrpo: num_generations: 4 max_new_tokens: 256 num_agents: 2 - # Early termination threshold for multi-agent (MAGRPO) + handoff: random early_termination_threshold: 4.0 -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "magrpo_humaneval" # Will be appended with model name in script - dir: "../../../projects/bepg/sliu30" + name: "magrpo_humaneval" + dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo" tags: ["magrpo", "humaneval", "multi-agent"] diff --git a/configs/mt_grpo_che_config.yaml b/configs/mt_grpo_che_config.yaml index 00a7745..49ae34e 100644 --- a/configs/mt_grpo_che_config.yaml +++ b/configs/mt_grpo_che_config.yaml @@ -1,7 +1,4 @@ -# Configuration for Multi-Turn CoopHumanEval training with GRPO (single-agent) -# Based on mt_magrpo_che_config.yaml parameters but adapted for single-agent - -# Model configuration +# model model: name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" @@ -14,22 +11,29 @@ model: trust_remote_code: true torch_dtype: "bfloat16" -# Dataset configuration +# dataset dataset: - name: "LovelyBuggies/CoopHumaneval" - type: "coophumaneval" # Used to select formatters and reward function - train_split: "test[:50]" - eval_split: "test[50:66]" + name: "CoMLRL/CoopHumanEval" + type: "coophumaneval" + train_split: "test[16:]" + eval_split: "test[:16]" -# Output configuration +# output output: - base_dir: "../../../projects/bevi/sliu30/output_mt" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_mt_grpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# GRPO training configuration (multi-turn enabled via num_turns) +# grpo grpo: num_turns: 2 - num_train_epochs: 10 # Reduced from 20 for multi-turn + num_train_epochs: 8 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -38,17 +42,15 @@ grpo: max_new_tokens: 256 temperature: 0.8 top_p: 0.95 - # Multi-turn specific parameters + handoff: random turn_gradient_weights: [1.2, 0.8] early_termination_weight: 2.0 early_termination_threshold: 2.1 - external_mode: "expert_edits" # Options: expert_edits (default), level_passed, level_feedback, passed, plain - expert_model: "deepseek-coder" # Used by expert_edits mode only -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "mt_grpo_coophumaneval" # Will be appended with model name in script - dir: "../../../projects/bevi/sliu30" + name: "mt_grpo_coophumaneval" + dir: "../../../work/hdd/bepg/sliu30/output_mt_grpo" tags: ["mt_grpo", "coophumaneval", "single-agent", "multi-turn"] diff --git 
a/configs/mt_grpo_he_config.yaml b/configs/mt_grpo_he_config.yaml index ef5da6d..aa0d15e 100644 --- a/configs/mt_grpo_he_config.yaml +++ b/configs/mt_grpo_he_config.yaml @@ -1,8 +1,6 @@ -# Configuration for Multi-Turn HumanEval training with GRPO (single-agent) - -# Model configuration +# model model: - name: "Qwen/Qwen2.5-Coder-3B" # Options: "Qwen/Qwen2.5-Coder-3B", "bigcode/starcoder2-3b", etc. + name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" temperature: 0.7 top_p: 0.9 @@ -13,22 +11,29 @@ model: trust_remote_code: true torch_dtype: "bfloat16" -# Dataset configuration +# dataset dataset: name: "openai/openai_humaneval" - type: "humaneval" # Used to select formatters and reward function - train_split: "test[33:133]" + type: "humaneval" + train_split: "test[33:163]" eval_split: "test[:32]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/sliu30/output_mt" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_mt_grpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# GRPO training configuration (multi-turn enabled via num_turns) +# grpo grpo: - num_turns: 2 # Set > 1 for multi-turn training - num_train_epochs: 7 + num_turns: 2 + num_train_epochs: 6 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -37,17 +42,15 @@ grpo: max_new_tokens: 256 temperature: 0.8 top_p: 0.95 - # Multi-turn specific parameters - turn_gradient_weights: [1.2, 0.8] # Weights for different turns - early_termination_weight: 2.0 # Weight for early termination reward + handoff: random + turn_gradient_weights: [1.2, 0.8] + early_termination_weight: 2.0 early_termination_threshold: 2.1 - external_mode: "expert_edits" # Options: expert_edits (default), level_passed, level_feedback, passed, plain - expert_model: "deepseek-coder" # Used by expert_edits mode only -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "mt_grpo_humaneval" # Will be appended with model name in script - dir: "../../../projects/bepg/sliu30" + name: "mt_grpo_humaneval" + dir: "../../../work/hdd/bepg/sliu30/output_mt_grpo" tags: ["mt_grpo", "humaneval", "single-agent", "multi-turn"] diff --git a/configs/mt_magrpo_che_config.yaml b/configs/mt_magrpo_che_config.yaml index 65966e0..eb79c60 100644 --- a/configs/mt_magrpo_che_config.yaml +++ b/configs/mt_magrpo_che_config.yaml @@ -1,7 +1,4 @@ -# Configuration for Multi-Turn CoopHumanEval training with MAGRPO -# Based on train_che.py parameters adapted for multi-turn - -# Model configuration +# model model: name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" @@ -14,22 +11,29 @@ model: trust_remote_code: true torch_dtype: "bfloat16" -# Dataset configuration +# dataset dataset: - name: "LovelyBuggies/CoopHumaneval" - type: "coophumaneval" # Used to select formatters and reward function - train_split: "test[:50]" - eval_split: "test[50:66]" + name: "CoMLRL/CoopHumanEval" + type: "coophumaneval" + train_split: "test[16:]" + eval_split: "test[:16]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/sliu30/output_mt" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_mt_magrpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# MAGRPO training configuration (multi-turn enabled via num_turns) +# magrpo magrpo: num_turns: 2 - num_train_epochs: 10 # Reduced from 20 for multi-turn + 
num_train_epochs: 8 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -39,17 +43,15 @@ magrpo: temperature: 0.8 top_p: 0.95 num_agents: 2 - # Multi-turn specific parameters - turn_gradient_weights: [1.2, 0.8] # Weights for different turns + handoff: random + turn_gradient_weights: [1.2, 0.8] early_termination_weight: 2.0 early_termination_threshold: 4.0 - external_mode: "expert_edits" # Options: expert_edits (default), level_passed, level_feedback - expert_model: "deepseek-coder" # Used by expert_edits mode only -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "mt_magrpo_coophumaneval" # Will be appended with model name in script - dir: "../../../projects/bevi/sliu30" + name: "mt_magrpo_coophumaneval" + dir: "../../../work/hdd/bepg/sliu30/output_mt_magrpo" tags: ["mt_magrpo", "coophumaneval", "multi-agent", "multi-turn"] diff --git a/configs/mt_magrpo_he_config.yaml b/configs/mt_magrpo_he_config.yaml index c6e545f..ea0ea8f 100644 --- a/configs/mt_magrpo_he_config.yaml +++ b/configs/mt_magrpo_he_config.yaml @@ -1,9 +1,6 @@ -# Configuration for Multi-Turn HumanEval training with MAGRPO -# This file defines all parameters for multi-turn training experiments - -# Model configuration +# model model: - name: "Qwen/Qwen2.5-Coder-3B" # Options: "Qwen/Qwen2.5-Coder-3B", "bigcode/starcoder2-3b", etc. + name: "Qwen/Qwen2.5-Coder-3B" type: "qwen" temperature: 0.7 top_p: 0.9 @@ -14,22 +11,29 @@ model: trust_remote_code: true torch_dtype: "bfloat16" -# Dataset configuration +# dataset dataset: name: "openai/openai_humaneval" - type: "humaneval" # Used to select formatters and reward function - train_split: "test[33:133]" + type: "humaneval" + train_split: "test[33:163]" eval_split: "test[:32]" -# Output configuration +# output output: - base_dir: "../../../projects/bepg/sliu30/output_mt" - save_final_model: true + base_dir: "../../../work/hdd/bepg/sliu30/output_mt_magrpo" + save_final_model: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true -# MAGRPO training configuration (multi-turn enabled via num_turns) +# magrpo magrpo: - num_turns: 2 # Set > 1 for multi-turn training - num_train_epochs: 7 + num_turns: 2 + num_train_epochs: 6 per_device_train_batch_size: 1 learning_rate: 2.0e-5 logging_steps: 50 @@ -37,17 +41,15 @@ magrpo: num_generations: 4 max_new_tokens: 256 num_agents: 2 - # Multi-turn specific parameters - turn_gradient_weights: [1.2, 0.8] # Weights for different turns - early_termination_weight: 2.0 # Weight for early termination reward + handoff: random + turn_gradient_weights: [1.2, 0.8] + early_termination_weight: 2.0 early_termination_threshold: 4.0 - external_mode: "expert_edits" # Options: expert_edits (default), level_passed, level_feedback - expert_model: "deepseek-coder" # Used by expert_edits mode only -# Wandb configuration +# wandb wandb: project: "mlrl" entity: "nu-llpr" - name: "mt_magrpo_humaneval" # Will be appended with model name in script - dir: "../../../projects/bepg/sliu30" + name: "mt_magrpo_humaneval" + dir: "../../../work/hdd/bepg/sliu30/output_mt_magrpo" tags: ["mt_magrpo", "humaneval", "multi-agent", "multi-turn"] diff --git a/external/__init__.py b/external/__init__.py index 56a3101..2b994ee 100644 --- a/external/__init__.py +++ b/external/__init__.py @@ -1,5 +1,4 @@ from typing import Any, Callable, Dict, List, Tuple, Union, Optional -import random # Mode implementations live alongside this file from . 
import expert_edits @@ -52,7 +51,7 @@ def get_external_transition( Args: prompt: Original problem prompt. - agent_completions: Best completions from previous turn (one per agent). + agent_completions: Selected completions from previous turn (one per agent). num_agents: Number of agents (1 or 2). mode: External transition mode name (default: "expert_edits"). **kwargs: Mode-specific parameters (e.g., expert_model, retries). @@ -77,47 +76,13 @@ def get_external_transition( # Pull common flags controlling prompt composition original_prompt_flag = kwargs.get("original_prompt", False) previous_response_flag = kwargs.get("previous_response", True) - handoff_strategy = (kwargs.get("handoff_strategy") or "best").lower() - - # Optional pool of candidate completions from the previous turn (per agent) - # Expected shape when num_agents==2: (List[str] for aux, List[str] for main) - # When num_agents==1: List[str] for main or a tuple where main candidates are in the last position - candidate_pool = kwargs.get("agent_candidate_completions") - - def select_handoff(prev_best_aux: str, prev_best_main: str) -> Tuple[str, str]: - if handoff_strategy != "random" or not candidate_pool: - return prev_best_aux, prev_best_main - try: - if int(num_agents) == 1: - # Single-agent: only choose for main - if isinstance(candidate_pool, (list, tuple)): - # If tuple/list of lists, pick the last as main pool; if flat list, use it directly - if len(candidate_pool) >= 1 and isinstance(candidate_pool[0], str): - main_pool = candidate_pool # flat list - else: - main_pool = candidate_pool[-1] - else: - main_pool = None - chosen_main = random.choice(main_pool) if main_pool else prev_best_main - return "", chosen_main - # Two agents - aux_pool, main_pool = candidate_pool - chosen_aux = random.choice(aux_pool) if aux_pool else prev_best_aux - chosen_main = random.choice(main_pool) if main_pool else prev_best_main - return chosen_aux, chosen_main - except Exception: - # Fallback safely to best if pool malformed - return prev_best_aux, prev_best_main if mode == "expert_edits": if int(num_agents) == 1: - main_best = agent_completions[0] - _aux_best = "" - _aux_comp, main_comp = select_handoff("", main_best) + main_comp = agent_completions[0] aux_comp = "" # isolate aux in single-agent mode else: - aux_best, main_best = agent_completions[0], agent_completions[1] - aux_comp, main_comp = select_handoff(aux_best, main_best) + aux_comp, main_comp = agent_completions[0], agent_completions[1] original_prompt, aux_edits, main_edits = expert_edits.add_expert_edits( prompt=prompt, aux_completion=aux_comp, @@ -156,12 +121,10 @@ def select_handoff(prev_best_aux: str, prev_best_main: str) -> Tuple[str, str]: if mode == "level_feedback": if int(num_agents) == 1: - main_best = agent_completions[0] - aux_comp, main_comp = select_handoff("", main_best) + main_comp = agent_completions[0] aux_comp = "" else: - aux_best, main_best = agent_completions[0], agent_completions[1] - aux_comp, main_comp = select_handoff(aux_best, main_best) + aux_comp, main_comp = agent_completions[0], agent_completions[1] ctx = get_context(prompt) or {} entry_point = ctx.get("entry_point", "") test_code = ctx.get("tests_sandbox") or ctx.get("tests_eval", "") @@ -187,12 +150,10 @@ def select_handoff(prev_best_aux: str, prev_best_main: str) -> Tuple[str, str]: if mode == "level_passed": if int(num_agents) == 1: - main_best = agent_completions[0] - aux_comp, main_comp = select_handoff("", main_best) + main_comp = agent_completions[0] aux_comp = "" else: - aux_best, main_best = 
agent_completions[0], agent_completions[1] - aux_comp, main_comp = select_handoff(aux_best, main_best) + aux_comp, main_comp = agent_completions[0], agent_completions[1] ctx = get_context(prompt) or {} entry_point = ctx.get("entry_point", "") test_code = ctx.get("tests_sandbox") or ctx.get("tests_eval", "") @@ -218,12 +179,10 @@ def select_handoff(prev_best_aux: str, prev_best_main: str) -> Tuple[str, str]: if mode == "passed": if int(num_agents) == 1: - main_best = agent_completions[0] - aux_comp, main_comp = select_handoff("", main_best) + main_comp = agent_completions[0] aux_comp = "" else: - aux_best, main_best = agent_completions[0], agent_completions[1] - aux_comp, main_comp = select_handoff(aux_best, main_best) + aux_comp, main_comp = agent_completions[0], agent_completions[1] ctx = get_context(prompt) or {} entry_point = ctx.get("entry_point", "") test_code = ctx.get("tests_sandbox") or ctx.get("tests_eval", "") @@ -249,12 +208,10 @@ def select_handoff(prev_best_aux: str, prev_best_main: str) -> Tuple[str, str]: if mode == "plain": if int(num_agents) == 1: - main_best = agent_completions[0] - aux_comp, main_comp = select_handoff("", main_best) + main_comp = agent_completions[0] aux_comp = "" else: - aux_best, main_best = agent_completions[0], agent_completions[1] - aux_comp, main_comp = select_handoff(aux_best, main_best) + aux_comp, main_comp = agent_completions[0], agent_completions[1] ctx = get_context(prompt) or {} entry_point = ctx.get("entry_point", "") test_code = ctx.get("tests_sandbox") or ctx.get("tests_eval", "") diff --git a/test/README.md b/test/README.md deleted file mode 100644 index 712a4f8..0000000 --- a/test/README.md +++ /dev/null @@ -1,17 +0,0 @@ -## Run External Output Test (dump_external_prompts.py) - -- Recommended Method (Automatically sets environment and aggregates output into two files) - - `bash LLM_Collaboration_with_MARL/test/run_external.sh` - -- Directly Run Script (if located in repository root directory) - - Optional: `conda activate comlrl` - - `export PYTHONPATH="${PYTHONPATH}:$(pwd)/LLM_Collaboration_with_MARL"` - - `python3 LLM_Collaboration_with_MARL/test/dump_external_prompts.py` - -- Notes - - For `expert_edits` mode, the script now defaults to calling the real expert model (`deepseek-coder`). - - Set `DEEPSEEK_API_KEY` in your environment; otherwise requests will fail and may fall back to stub where applicable. - - To force offline stub, pass `--offline` to the script or set `OFFLINE=1`. - - Output files located in `LLM_Collaboration_with_MARL/test/`: - - `prompts_sa.txt` (single agent) - - `prompts_ma.txt` (multi agent) diff --git a/test/dump_external_prompts.py b/test/dump_external_prompts.py deleted file mode 100644 index 92dace9..0000000 --- a/test/dump_external_prompts.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -""" -Generate and save prompts from external modes for num_agents=1 (sa) and 2 (ma). - -Outputs two files under test/ only: - - prompts_sa.txt (single agent) - - prompts_ma.txt (multi agent) - -Avoids network by stubbing expert_edits.add_expert_edits. 
-""" - -import os -import sys -import argparse -from typing import Dict - - -def project_root() -> str: - # This file lives at /LLM_Collaboration_with_MARL/test - # We want sys.path to include /LLM_Collaboration_with_MARL - return os.path.abspath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") - ) - - -def add_repo_to_path(): - root = project_root() - # root here is /LLM_Collaboration_with_MARL - if root not in sys.path: - sys.path.insert(0, root) - - -def build_context(prompt: str) -> Dict[str, str]: - entry_point = "add" - # Minimal tests that the framework can parse (uses candidate -> replaced by entry_point) - tests = ( - "def check(candidate):\n" - " assert candidate(1, 2) == 3\n" - " assert candidate(-1, 1) == 0\n" - " assert candidate(0, 0) == 0\n" - ) - return {"entry_point": entry_point, "tests_eval": tests, "tests_sandbox": tests} - - -def main(): - parser = argparse.ArgumentParser(description="Dump external prompts for 1/2 agents") - parser.add_argument( - "--real-expert", - action="store_true", - help="Force call real expert model for expert_edits (requires API key)", - ) - parser.add_argument( - "--offline", - action="store_true", - help="Force offline stub for expert_edits (no network)", - ) - parser.add_argument( - "--expert-model", - type=str, - default=os.environ.get("EXPERT_MODEL", "deepseek-coder"), - help="Expert model name for expert_edits", - ) - args = parser.parse_args() - - add_repo_to_path() - - # Default to real expert calls unless offline is requested - def _is_truthy(v: str) -> bool: - return str(v).lower() in ("1", "true", "yes", "y") - - use_real_expert = True - if _is_truthy(os.environ.get("OFFLINE", "0")) or args.offline: - use_real_expert = False - if _is_truthy(os.environ.get("REAL_EXPERT", "0")): - use_real_expert = True - if args.real_expert: - use_real_expert = True - - # Ensure 'anthropic' is importable (not used unless that model is chosen) - try: - import anthropic # type: ignore # noqa: F401 - except Exception: - import types - - if "anthropic" not in sys.modules: - m = types.ModuleType("anthropic") - - class _Anthropic: - def __init__(self, *a, **k): - pass - - m.Anthropic = _Anthropic - sys.modules["anthropic"] = m - - # Stub openai only when not using real expert edits - if not use_real_expert: - import types - - if "openai" not in sys.modules: - m2 = types.ModuleType("openai") - - class _OpenAI: - def __init__(self, *a, **k): - pass - - m2.OpenAI = _OpenAI - sys.modules["openai"] = m2 - - # Import our local package (ensure we don't shadow by this test module name later) - import importlib - - external = importlib.import_module("external") - - # Register a simple context resolver - def resolver(p: str): - return build_context(p) - - external.set_context_resolver(resolver) - - # Monkey patch expert_edits.add_expert_edits to avoid network calls - if not use_real_expert: - from external import expert_edits as ee - - def _stub_add_expert_edits( - prompt: str, aux_completion: str, main_completion: str, **kwargs - ): - # Return deterministic edits; leave aux empty for single-agent compatibility - return ( - prompt, - "# AUX EDIT: (stub) helper not needed for this task", - "# MAIN EDIT: (stub) handle edge cases and negative inputs", - ) - - ee.add_expert_edits = _stub_add_expert_edits # type: ignore - - # Inputs - original_prompt = "Write a function add(x: int, y: int) -> int that returns the sum of two integers." 
- - # Completions for num_agents=1 (main only) - main_only_code = ( - "def add(x, y):\n" " # simple implementation\n" " return x + y\n" - ) - - # Completions for num_agents=2 (aux + main) - aux_code = "def aux(x, y):\n" " # helper returns sum\n" " return x + y\n" - main_code = "def add(x, y):\n" " # use helper\n" " return aux(x, y)\n" - - modes = ["expert_edits", "level_feedback", "level_passed", "passed", "plain"] - - out_dir = os.path.dirname(os.path.abspath(__file__)) - - sa_sections = [] - ma_sections = [] - - for mode in modes: - # num_agents = 1 - try: - prompts_1 = external.get_external_transition( - prompt=original_prompt, - agent_completions=[main_only_code], - num_agents=1, - mode=mode, - expert_model=args.expert_model if mode == "expert_edits" else None, - ) - # get_external_transition returns [main_prompt] for single-agent - if isinstance(prompts_1, (list, tuple)): - main_prompt_1 = prompts_1[-1] - else: - main_prompt_1 = str(prompts_1) - - sa_sections.append( - "\n".join( - [ - f"=== MODE: {mode} | num_agents=1 ===", - "", - "MAIN PROMPT:", - main_prompt_1, - "", - ] - ) - ) - except Exception as e: - msg = f"[agents=1] Error in mode '{mode}': {e}" - print(msg) - sa_sections.append(f"=== MODE: {mode} | num_agents=1 ===\nERROR: {e}\n") - - # num_agents = 2 - try: - prompts_2 = external.get_external_transition( - prompt=original_prompt, - agent_completions=(aux_code, main_code), - num_agents=2, - mode=mode, - expert_model=args.expert_model if mode == "expert_edits" else None, - ) - # get_external_transition returns (aux_prompt, main_prompt) for two agents - if isinstance(prompts_2, (list, tuple)) and len(prompts_2) == 2: - aux_prompt_2, main_prompt_2 = prompts_2 - else: - # Fallback to consistent formatting - aux_prompt_2 = "" - main_prompt_2 = ( - prompts_2[0] - if isinstance(prompts_2, (list, tuple)) - else str(prompts_2) - ) - - ma_sections.append( - "\n".join( - [ - f"=== MODE: {mode} | num_agents=2 ===", - "", - "AUX PROMPT:", - aux_prompt_2, - "", - "MAIN PROMPT:", - main_prompt_2, - "", - ] - ) - ) - except Exception as e: - msg = f"[agents=2] Error in mode '{mode}': {e}" - print(msg) - ma_sections.append(f"=== MODE: {mode} | num_agents=2 ===\nERROR: {e}\n") - - # Write combined outputs - sa_path = os.path.join(out_dir, "prompts_sa.txt") - with open(sa_path, "w", encoding="utf-8") as f: - f.write("\n".join(sa_sections).rstrip() + "\n") - print(f"Wrote {sa_path}") - - ma_path = os.path.join(out_dir, "prompts_ma.txt") - with open(ma_path, "w", encoding="utf-8") as f: - f.write("\n".join(ma_sections).rstrip() + "\n") - print(f"Wrote {ma_path}") - - -if __name__ == "__main__": - main() diff --git a/test/prompts_ma.txt b/test/prompts_ma.txt deleted file mode 100644 index 9ac58f4..0000000 --- a/test/prompts_ma.txt +++ /dev/null @@ -1,124 +0,0 @@ -=== MODE: expert_edits | num_agents=2 === - -AUX PROMPT: -Your previous aux(...) implementation: -def aux(x, y): - # helper returns sum - return x + y - -Here is edited snippet from an expert model: -Perfect! No changes needed! - -Revise your aux(...) accordingly. Output ONLY the function code with no extra text. - -MAIN PROMPT: -Your previous main implementation: -def add(x, y): - # use helper - return aux(x, y) - -Here is edited snippet from an expert model: -Perfect! No changes needed! - -Revise your add(...) accordingly. Output ONLY the function code with no extra text. - -=== MODE: level_feedback | num_agents=2 === - -AUX PROMPT: -Your previous aux(...) 
implementation: -def aux(x, y): - # helper returns sum - return x + y - -Static and execution diagnostics: -- Aux definition: FOUND (Aux function properly defined with return statement) -- Main definition: FOUND (Main function (add) properly defined with return statement) -- Syntax: OK (Combined code syntax OK) -- Tests: 3/3 passed - -Revise your aux(...) accordingly. Output ONLY the function code with no extra text. - -MAIN PROMPT: -Your previous main implementation: -def add(x, y): - # use helper - return aux(x, y) - -Static and execution diagnostics: -- Main definition: FOUND (Main function (add) properly defined with return statement) -- Aux definition: FOUND (Aux function properly defined with return statement) -- Syntax: OK (Combined code syntax OK) -- Tests: 3/3 passed -- Aux usage: main calls aux and uses its result -- Warning: main appears to be a thin wrapper over aux - -Revise your add(...) accordingly. Output ONLY the function code with no extra text. - -=== MODE: level_passed | num_agents=2 === - -AUX PROMPT: -Your previous aux(...) implementation: -def aux(x, y): - # helper returns sum - return x + y - -Signals: -- Implementation: OK -- Syntax: Syntax correct -- Tests: Passed all tests - -Revise your aux(...) accordingly. Output ONLY the function code. - -MAIN PROMPT: -Your previous main implementation: -def add(x, y): - # use helper - return aux(x, y) - -Signals: -- Implementation: OK -- Syntax: Syntax correct -- Tests: Passed all tests -- Aux usage: Aux call present and used - -Revise your add(...) accordingly. Output ONLY the function code. - -=== MODE: passed | num_agents=2 === - -AUX PROMPT: -Your previous aux(...) implementation: -def aux(x, y): - # helper returns sum - return x + y - -Signal: All levels passed - -Revise your aux(...) if needed. Output ONLY the function code. - -MAIN PROMPT: -Your previous main implementation: -def add(x, y): - # use helper - return aux(x, y) - -Signal: All levels passed - -Revise your add(...) if needed. Output ONLY the function code. - -=== MODE: plain | num_agents=2 === - -AUX PROMPT: -Your previous aux(...) implementation: -def aux(x, y): - # helper returns sum - return x + y - -Revise your aux(...) accordingly. Output ONLY the function code. - -MAIN PROMPT: -Your previous main implementation: -def add(x, y): - # use helper - return aux(x, y) - -Revise your add(...) accordingly. Output ONLY the function code. diff --git a/test/prompts_sa.txt b/test/prompts_sa.txt deleted file mode 100644 index bcbdee8..0000000 --- a/test/prompts_sa.txt +++ /dev/null @@ -1,64 +0,0 @@ -=== MODE: expert_edits | num_agents=1 === - -MAIN PROMPT: -Your previous implementation: -def add(x, y): - # simple implementation - return x + y - -Here is an edited snippet from an expert model: -Perfect! No changes needed! - -Revise your add(...) accordingly. Output ONLY the function code with no extra text. - -=== MODE: level_feedback | num_agents=1 === - -MAIN PROMPT: -Your previous implementation: -def add(x, y): - # simple implementation - return x + y - -Static and execution diagnostics: -- Main definition: FOUND (Main function (add) properly defined with return statement) -- Syntax: OK (Combined code syntax OK) -- Tests: 3/3 passed - -Revise your add(...) accordingly. Output ONLY the function code with no extra text. 
- -=== MODE: level_passed | num_agents=1 === - -MAIN PROMPT: -Your previous implementation: -def add(x, y): - # simple implementation - return x + y - -Signals: -- Implementation: OK -- Syntax: Syntax correct -- Tests: Passed all tests - -Revise your add(...) accordingly. Output ONLY the function code. - -=== MODE: passed | num_agents=1 === - -MAIN PROMPT: -Your previous implementation: -def add(x, y): - # simple implementation - return x + y - -Signal: All levels passed - -Revise your add(...) if needed. Output ONLY the function code. - -=== MODE: plain | num_agents=1 === - -MAIN PROMPT: -Your previous implementation: -def add(x, y): - # simple implementation - return x + y - -Revise your add(...) accordingly. Output ONLY the function code. diff --git a/test/run_external.sh b/test/run_external.sh deleted file mode 100755 index 39a7b63..0000000 --- a/test/run_external.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# Run the external prompt generation test and list outputs. -# Environment setup mirrors your sbatch reference (no srun involved). - -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -REPO_DIR="$( cd "$SCRIPT_DIR/.." && pwd )" # LLM_Collaboration_with_MARL - -# Optional: conda activation (safe if conda exists) -if command -v conda >/dev/null 2>&1; then - # Load conda shell hook - source "$(conda info --base)/etc/profile.d/conda.sh" || true - # Load user bashrc for aliases/env if present - [ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" - # Activate target env if available - conda env list | grep -qE '^comlrl\s' && conda activate comlrl || true -fi - -export PYTHONPATH="${PYTHONPATH:-}:$REPO_DIR" - -PYTHON_BIN=${PYTHON:-python3} -echo "Using Python: $PYTHON_BIN" -echo "Repo dir: $REPO_DIR" -echo "Script dir: $SCRIPT_DIR" - -cd "$REPO_DIR" - -# If REAL_EXPERT=1, request real expert edits with selected model (default deepseek-coder) -EXPERT_ARGS=() -if [[ "${REAL_EXPERT:-}" == "1" ]]; then - : "${EXPERT_MODEL:=deepseek-coder}" - if [[ -z "${DEEPSEEK_API_KEY:-}" ]]; then - echo "[warn] REAL_EXPERT=1 but DEEPSEEK_API_KEY is not set; request may fail." >&2 - fi - EXPERT_ARGS+=("--real-expert" "--expert-model" "$EXPERT_MODEL") -fi - -"$PYTHON_BIN" "$SCRIPT_DIR/dump_external_prompts.py" "${EXPERT_ARGS[@]}" - -echo -echo "Generated files:" -ls -1 "$SCRIPT_DIR"/prompts_sa.txt "$SCRIPT_DIR"/prompts_ma.txt 2>/dev/null || echo "No files generated." 
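Since the prompt-dump test harness above is deleted, here is a minimal usage sketch of the external API that remains after this change. It is adapted from the removed `dump_external_prompts.py`; the toy task, resolver contents, and flag values are illustrative, and it assumes `set_context_resolver` and `get_external_transition` are exposed at the package level, as that script used them:

```python
# Sketch only: assumes the repo root is on PYTHONPATH so `external` is importable.
from external import set_context_resolver, get_external_transition


def resolver(prompt: str):
    # Toy context; the train scripts register a resolver built from dataset items.
    tests = (
        "def check(candidate):\n"
        "    assert candidate(1, 2) == 3\n"
    )
    return {"entry_point": "add", "tests_eval": tests, "tests_sandbox": tests}


set_context_resolver(resolver)

# Previous-turn completions, one per agent (aux first, then main when num_agents=2).
aux_code = "def aux(x, y):\n    return x + y\n"
main_code = "def add(x, y):\n    return aux(x, y)\n"

aux_prompt, main_prompt = get_external_transition(
    prompt="Write a function add(x, y) that returns the sum of two integers.",
    agent_completions=(aux_code, main_code),
    num_agents=2,
    mode="level_feedback",   # default mode in the updated configs
    original_prompt=True,    # mirrors external.original_prompt
    previous_response=True,  # mirrors external.previous_response
)
print(main_prompt)
```

For two agents the call returns the next-turn prompts as `(aux_prompt, main_prompt)`; for a single agent it returns a one-element list containing the main prompt, as the deleted script handled it.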
diff --git a/train_grpo.py b/train_grpo.py index 6dda67c..7950994 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -350,6 +350,9 @@ def main(): temperature = grpo_config.get("temperature", model_config.temperature) top_p = grpo_config.get("top_p", model_config.top_p) + # External configuration (mode, sandbox, expert model, context flags) + external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} + # Register external context resolver using dataset items (for external modes) def _normalize_prompt(p: str) -> str: return " ".join((p or "").split()).strip() @@ -357,12 +360,22 @@ def _normalize_prompt(p: str) -> str: context_map: Dict[str, Any] = {} # Optionally restrict sandbox tests to the first N eval asserts - # Set grpo.sandbox_slice to an integer N (>0) to keep only the first N asserts - sandbox_slice = grpo_config.get("sandbox_slice", None) - try: - sandbox_slice = int(sandbox_slice) if sandbox_slice is not None else None - except (TypeError, ValueError): - sandbox_slice = None + # Default: keep only the first assert (sandbox_slice=1) + # Set external.sandbox_slice to an integer N (>0) to keep the first N asserts, + # or to 0 / None / 'all' to keep all eval asserts. + _sandbox_val = external_cfg.get("sandbox_slice", 1) + if isinstance(_sandbox_val, str): + _sv = _sandbox_val.strip().lower() + if _sv == "all": + sandbox_slice = 0 + elif _sv.lstrip("-").isdigit(): + sandbox_slice = int(_sv) + else: + sandbox_slice = None + elif isinstance(_sandbox_val, int): + sandbox_slice = _sandbox_val + else: + sandbox_slice = None if _sandbox_val is None else 0 import re as _re @@ -444,6 +457,7 @@ def _resolver(prompt: str): ), early_termination_weight=grpo_config.get("early_termination_weight", 2.0), early_termination_threshold=grpo_config.get("early_termination_threshold", 2.1), + handoff=grpo_config.get("handoff", "random"), ) formatter = get_formatter(dataset_type) @@ -459,8 +473,10 @@ def _resolver(prompt: str): else: wandb_name = wandb_section.get("name", f"grpo_{dataset_type}") + # external_cfg already loaded above # Compute tags and add self-evolved when using analysis-based external modes - external_mode = grpo_config.get("external_mode", "expert_edits") + external_mode = external_cfg.get("mode", "level_feedback") + handoff_mode = grpo_config.get("handoff", "random") default_tags = ["grpo", dataset_type or "code", f"turns_{num_turns}"] tags_from_cfg = wandb_section.get("tags", default_tags) tags = list(tags_from_cfg) if isinstance(tags_from_cfg, list) else default_tags @@ -469,14 +485,16 @@ def _resolver(prompt: str): tags.append("self-evolved") # If sandbox_slice is active (non-zero), append _slice to run name - try: - if sandbox_slice is not None and int(sandbox_slice) != 0: - if not str(wandb_name).endswith("_slice"): - wandb_name = f"{wandb_name}_slice" - if "slice" not in tags: - tags.append("slice") - except Exception: - pass + if isinstance(sandbox_slice, int) and sandbox_slice != 0: + if not str(wandb_name).endswith("_slice"): + wandb_name = f"{wandb_name}_slice" + if "slice" not in tags: + tags.append("slice") + + # Collect full config sections for W&B searchability + dataset_section = config.get_section("dataset") if hasattr(config, "get_section") else {} + model_section = config.get_section("model") if hasattr(config, "get_section") else {} + output_section = config.get_section("output") if hasattr(config, "get_section") else {} wandb_config = { "project": wandb_section.get("project", "mlrl"), @@ -484,6 +502,14 @@ def _resolver(prompt: str): "name": 
f"{wandb_name}_{model_short_name}", "dir": wandb_section.get("dir", "../../../projects/bepg/sliu30"), "tags": tags, + # Provide full sections for the trainer to log cleanly + "config_sections": { + "dataset": dataset_section, + "model": model_section, + "output": output_section, + "external": external_cfg, + "trainer": grpo_config, + }, } reward_processor = None @@ -513,18 +539,16 @@ def _resolver(prompt: str): and dataset_type and dataset_type.lower() in ["humaneval", "coophumaneval"] ): - expert_model = grpo_config.get("expert_model", "deepseek-coder") + expert_model = external_cfg.get("expert_model", "deepseek-coder") def external_transition_wrapper( - prompt, agent_completions, num_agents, **et_kwargs + prompt, agent_completions, num_agents ): # Single-agent: pass prior main completion; aux is empty internally main_best = agent_completions[0] if agent_completions else "" - original_prompt_flag = grpo_config.get("external_original_prompt", False) - previous_response_flag = grpo_config.get("external_previous_response", True) - handoff_strategy = grpo_config.get("external_handoff", "best") - + original_prompt_flag = external_cfg.get("original_prompt", True) + previous_response_flag = external_cfg.get("previous_response", True) prompts = get_external_transition( prompt=prompt, agent_completions=[main_best], @@ -533,8 +557,6 @@ def external_transition_wrapper( mode=external_mode, original_prompt=original_prompt_flag, previous_response=previous_response_flag, - handoff_strategy=handoff_strategy, - **et_kwargs, ) # Ensure list of one string is returned @@ -547,7 +569,7 @@ def external_transition_wrapper( trainer = MAGRPOTrainer(**trainer_kwargs) trainer.train() - save_final = config.get("output.save_final_model", True) + save_final = config.get("output.save_final_model", False) if save_final: save_path = config.get( "output.save_path", os.path.join(output_dir, "final_model") diff --git a/train_magrpo.py b/train_magrpo.py index 988ca77..1ef166e 100644 --- a/train_magrpo.py +++ b/train_magrpo.py @@ -307,12 +307,15 @@ def main(): config_save_path = os.path.join(output_dir, "config.yaml") config.save(config_save_path) + train_dataset = None + eval_dataset = None try: train_dataset = load_dataset(dataset_name, split=train_split) eval_dataset = load_dataset(dataset_name, split=eval_split) except Exception as e: print(f"Error loading dataset: {e}") + return print(f"\nUsing model: {model_name}") print(f"Model type: {model_config.type}") @@ -340,6 +343,9 @@ def main(): temperature = magrpo_config.get("temperature", model_config.temperature) top_p = magrpo_config.get("top_p", model_config.top_p) + # External configuration (mode, sandbox, expert model, context flags) + external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} + # Register external context resolver using dataset items def _normalize_prompt(p: str) -> str: return " ".join((p or "").split()).strip() @@ -347,12 +353,22 @@ def _normalize_prompt(p: str) -> str: context_map = {} # Optionally restrict sandbox tests to the first N eval asserts - # Set magrpo.sandbox_slice to an integer N (>0) to keep only the first N asserts - sandbox_slice = magrpo_config.get("sandbox_slice", None) - try: - sandbox_slice = int(sandbox_slice) if sandbox_slice is not None else None - except (TypeError, ValueError): - sandbox_slice = None + # Default: keep only the first assert (sandbox_slice=1) + # Set external.sandbox_slice to an integer N (>0) to keep the first N asserts, + # or to 0 / None / 'all' to keep all eval asserts. 
+ _sandbox_val = external_cfg.get("sandbox_slice", 1) + if isinstance(_sandbox_val, str): + _sv = _sandbox_val.strip().lower() + if _sv == "all": + sandbox_slice = 0 + elif _sv.lstrip("-").isdigit(): + sandbox_slice = int(_sv) + else: + sandbox_slice = None + elif isinstance(_sandbox_val, int): + sandbox_slice = _sandbox_val + else: + sandbox_slice = None if _sandbox_val is None else 0 import re @@ -448,6 +464,7 @@ def _resolver(prompt: str): early_termination_threshold=magrpo_config.get( "early_termination_threshold", 4.0 ), + handoff=magrpo_config.get("handoff", "random"), ) # Get appropriate formatters and functions based on dataset type, agent count, and training mode @@ -468,8 +485,10 @@ def _resolver(prompt: str): else: wandb_name = wandb_section.get("name", f"magrpo_{dataset_type}") + # external_cfg already loaded above # Compute tags and add self-evolved when using analysis-based external modes - external_mode = magrpo_config.get("external_mode", "expert_edits") + external_mode = external_cfg.get("mode", "level_feedback") + handoff_mode = magrpo_config.get("handoff", "random") default_tags = ["magrpo", dataset_type or "code", f"turns_{num_turns}"] tags_from_cfg = wandb_section.get("tags", default_tags) # Ensure list @@ -478,12 +497,25 @@ def _resolver(prompt: str): if "self-evolved" not in tags: tags.append("self-evolved") + # Collect full config sections for W&B searchability + dataset_section = config.get_section("dataset") if hasattr(config, "get_section") else {} + model_section = config.get_section("model") if hasattr(config, "get_section") else {} + output_section = config.get_section("output") if hasattr(config, "get_section") else {} + wandb_config = { "project": wandb_section.get("project", "mlrl"), "entity": wandb_section.get("entity", "nu-llpr"), "name": f"{wandb_name}_{model_short_name}", "dir": wandb_section.get("dir", "../../../projects/bepg/sliu30"), "tags": tags, + # Provide full sections for the trainer to log cleanly + "config_sections": { + "dataset": dataset_section, + "model": model_section, + "output": output_section, + "external": external_cfg, + "trainer": magrpo_config, + }, } # Get num_agents from magrpo config (where it belongs for MAGRPO training) @@ -523,20 +555,16 @@ def _resolver(prompt: str): and dataset_type and dataset_type.lower() in ["humaneval", "coophumaneval"] ): - expert_model = magrpo_config.get("expert_model", "deepseek-coder") + expert_model = external_cfg.get("expert_model", "deepseek-coder") # external_mode already loaded above def external_transition_wrapper( - prompt, agent_completions, num_agents, **et_kwargs + prompt, agent_completions, num_agents ): # Returns full next-turn prompts per agent (strings) - # Allow overrides via config and forwarded kwargs - original_prompt_flag = magrpo_config.get("external_original_prompt", False) - previous_response_flag = magrpo_config.get( - "external_previous_response", True - ) - handoff_strategy = magrpo_config.get("external_handoff", "best") - + # External prompt composition flags + original_prompt_flag = external_cfg.get("original_prompt", True) + previous_response_flag = external_cfg.get("previous_response", True) return get_external_transition( prompt=prompt, agent_completions=agent_completions, @@ -545,15 +573,13 @@ def external_transition_wrapper( mode=external_mode, original_prompt=original_prompt_flag, previous_response=previous_response_flag, - handoff_strategy=handoff_strategy, - **et_kwargs, ) trainer_kwargs["external_transition"] = external_transition_wrapper trainer = 
MAGRPOTrainer(**trainer_kwargs) trainer.train() - save_final = config.get("output.save_final_model", True) + save_final = config.get("output.save_final_model", False) if save_final: save_path = config.get( "output.save_path", os.path.join(output_dir, "final_model")
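Taken together, the configuration surface this diff introduces looks roughly as follows. This is a sketch assembled from the configs above; the values mirror those configs and are not the only valid choices:

```yaml
# Read by the train scripts via config.get_section("external")
external:
  mode: "level_feedback"   # level_feedback | level_passed | passed | plain | expert_edits
  sandbox_slice: 1         # first N asserts; 0 / null / 'all' keeps all; negative N keeps the last N
  original_prompt: true    # include the original task prompt when composing 2nd+ turn prompts
  previous_response: true  # include the prior completion when composing 2nd+ turn prompts
  # expert_model: "deepseek-coder"   # read only when mode is expert_edits

# Handoff now lives under the trainer section (grpo: or magrpo:)
magrpo:
  handoff: random          # random (default) | best
```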