36 changes: 16 additions & 20 deletions README.md
@@ -28,44 +28,40 @@ python LLM_Collaboration_with_MARL/train_magrpo.py \

## Settings

### Joint Action Modes
### Joint Action

`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: `align` (default), which pairs the g‑th generation of every agent to form G joint actions per node; and `cross`, which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): `align` → G^T; `cross` → (G^N)^T = G^{N·T}.
`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: 'align' (default), which pairs the g‑th generation of every agent to form G joint actions per node; and 'cross', which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): align → G^T; cross → G^{N·T}.

Aligned is faster in wall-clock time (fewer sibling evaluations per node), while cross is more sample‑efficient (better value estimation) without extra VRAM, because it reuses the same G generations per agent and only crosses them within the node. We never cross across different nodes/prompts; this preserves causal state consistency (actions are conditioned on the same prompts), keeps siblings comparable for the baseline/advantage, maintains correct credit assignment (log‑probs matched to rewards from the same state), and remains computationally tractable.
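For concreteness, here is a minimal sketch of how the two modes could combine per-agent generations within a single node; the variable names are illustrative, not the trainer's actual API:

```python
from itertools import product

# Each agent samples G generations at the current node (G = 3 here, N = 2 agents).
gens = [
    ["a1", "a2", "a3"],  # agent 0's G generations
    ["b1", "b2", "b3"],  # agent 1's G generations
]

# aligned: pair the g-th generation of every agent -> G joint actions per node.
aligned_joint = list(zip(*gens))      # [(a1, b1), (a2, b2), (a3, b3)]

# cross: Cartesian product within the node -> G^N joint actions per node.
crossed_joint = list(product(*gens))  # 3^2 = 9 joint actions
```

Both modes reuse the same G generations per agent, which is why crossing adds sibling evaluations but no extra generation cost or VRAM.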

### Advantage Calculation
### Advantage

`magrpo.normalize_advantage` is false by default. When true, compute z-scored advantages over sibling returns; when false, use a mean baseline without normalization.
Advantages are used to optimize the agents' policies; they are computed with a mean baseline and no standard‑deviation normalization, which keeps training unbiased (see [Dr. GRPO](https://arxiv.org/pdf/2503.20783)). We do not apply importance-sampling ratios either, since training is on-policy (the same policy is used for sampling and updates).

`magrpo.epsilon_clip` clamps the advantage to [-epsilon_clip, +epsilon_clip] after normalization (default: None). 0 or None skips clamping entirely.
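A minimal sketch of the advantage computation described above — mean-sibling baseline with the optional `epsilon_clip` clamp; names and shapes are illustrative, not the trainer's exact code:

```python
import torch

def node_advantages(returns: torch.Tensor, epsilon_clip=None) -> torch.Tensor:
    """returns: shape (num_siblings,), one return per sibling joint action at this node."""
    adv = returns - returns.mean()  # mean baseline, no std normalization
    if epsilon_clip:                # 0 or None skips clamping entirely
        adv = adv.clamp(-epsilon_clip, epsilon_clip)
    return adv
```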
### Number of Samples

We do not apply the importance sampling ratio because the policy changes slowly with LLMs, and the ratio is close to 1.0. This avoids numerical instability from multiplying many small probabilities.
`magrpo.num_turns` is the number of turns in training and evaluation, and `magrpo.num_generations` is the number of generations G sampled per agent at each turn. Leaf counts (total joint trajectories at the current turn) grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned` and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
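For example, with two agents (N = 2) and the config defaults `num_generations: 4` and `num_turns: 2`, `aligned` yields 4² = 16 leaf joint trajectories and `cross` yields 4^(2·2) = 256, while the per-node sibling set has size 4 versus 4² = 16.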

### Number of Turns
### Termination

`magrpo.num_turns` determines the number of turns (default: 2). Leaf counts grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned`, and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly instead of expanding the full Monte Carlo tree. At each node (branch, turn), we compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue.
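As a sketch (hypothetical helper, not the trainer's actual function), the per-node check could look like:

```python
def branch_terminates(sibling_rewards: list, termination_threshold: float) -> bool:
    """Stop expanding a branch when its siblings' mean immediate reward clears the threshold."""
    mean_reward = sum(sibling_rewards) / len(sibling_rewards)
    return mean_reward > termination_threshold
```

With `termination_threshold: -0.2` (the value in the MAGRPO configs), a branch whose sibling joint actions average an immediate reward above -0.2 stops expanding at that turn, and its returns are backpropagated from the truncated subtree.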

### Early Termination
### New Prompts

`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly, instead of expanding the full Monte Carlo tree. At each node (branch, turn), compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue.

### 2+Turn Prompt

`external.original_prompt` and `external.previous_response` both default as `true`. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to `false` (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
`external.original_prompt` and `external.previous_response` both default to true. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to false (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
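A sketch of how these two flags could shape a 2+ turn prompt; the exact wording and assembly in the repo may differ:

```python
def build_followup_prompt(original_prompt: str, previous_response: str, feedback: str = "",
                          include_original: bool = True,   # external.original_prompt
                          include_previous: bool = True):  # external.previous_response
    parts = []
    if include_original:
        parts.append(original_prompt)
    if include_previous:
        parts.append("Previous response:\n" + previous_response)
    if feedback:  # e.g. diagnostics attached by 'level_feedback'
        parts.append("Feedback:\n" + feedback)
    parts.append("Revise your previous response.")
    return "\n\n".join(parts)
```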

### External Modes

`external.mode` is set to 'level_feedback' by default. This gives additional information from external to prompts in the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include:
`external.mode` emulates the environment transition and is set to 'level_feedback' by default. It injects additional external information into the prompts of the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include:

- `expert_edits`: an LLM proposes edits; prompts include edit suggestions plus context.
- `level_passed` / `passed`: binary outcome oriented prompts with minimal context.
- `plain`: no diagnostics, but still includes previous response (unless disabled) and a "Revise ..." instruction.
- 'expert_edits': an LLM proposes edits; prompts include edit suggestions plus context.
- 'level_passed' / 'passed': binary, outcome-oriented prompts with minimal context.
- 'plain': no diagnostics, but still includes previous response (unless disabled) and a "revise your previous response" instruction.

A setting specific to 'level_feedback' is `external.sandbox_slice`, which controls how many eval tests are included in the feedback. By default, the sandbox executes only the first assert (sandbox_slice=1). To use all eval tests, set `external.sandbox_slice` to 0, None, or 'all'; negative values select asserts from the end. `external.sandbox_slice` only affects analysis-based modes ('level_feedback', 'level_passed', 'passed') and has no effect on 'expert_edits'.
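A sketch of how `external.sandbox_slice` could select the eval asserts to execute (illustrative helper, not the repo's actual implementation):

```python
def select_asserts(asserts: list, sandbox_slice=1):
    if sandbox_slice in (0, None, "all"):
        return asserts                    # run every eval test
    if sandbox_slice > 0:
        return asserts[:sandbox_slice]    # first N asserts (default: 1)
    return asserts[sandbox_slice:]        # negative values: the last |N| asserts
```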

Specific settings for 'expert_edits' is `external.expert_edits_model`, which controls which LLM to use for proposing edits. By default, it uses DeepSeek-Coder. You can also change it to Claude-3, GPT-4, once you have keys/tokens in your global environment variables.
A setting specific to 'expert_edits' is `external.expert_edits_model`, which selects the LLM used to propose edits. By default it uses DeepSeek-Coder; you can switch to Claude, GPT, or other models once the corresponding keys/tokens are set in your environment.

### Output

`output.save_model` is set to `false` by default because of the huge storage required by multiple LLMs. `verbose` is used for debug printing on cluster if set to be true, but it is default to be false and you can only see a tqdm bar that shows the training progress. You can also turn on `magrpo.log_code_levels` to log the level-rewards during training, but it will crazily slow down the training.
`output.save_model` is set to false by default because of the large storage required by multiple LLMs. `output.verbose` enables debug printing on the cluster when set to true; it defaults to false, in which case you only see a tqdm bar showing training progress.
3 changes: 1 addition & 2 deletions baselines/che_concat.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
12 changes: 9 additions & 3 deletions baselines/che_discuss.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
@@ -1009,6 +1008,9 @@ def evaluate_coophumaneval_two_round(


def main():
# --------------------------------------------------------------
# CLI: parse arguments
# --------------------------------------------------------------
parser = argparse.ArgumentParser(
description="CoopHumanEval Two-Round Model Evaluation"
)
@@ -1043,14 +1045,18 @@ def main():

args = parser.parse_args()

# Initialize two-round evaluator
# --------------------------------------------------------------
# Initialize evaluator
# --------------------------------------------------------------
evaluator = QwenCoopHumanEvalTwoRoundEvaluator(
aux_model_name=args.aux_model,
main_model_name=args.main_model,
device=args.device,
)

# --------------------------------------------------------------
# Run evaluation
# --------------------------------------------------------------
aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_two_round(
num_samples=args.samples,
num_generations=args.generations,
3 changes: 1 addition & 2 deletions baselines/che_sequential.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
12 changes: 9 additions & 3 deletions baselines/che_single_agent.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
@@ -731,6 +730,9 @@ def evaluate_coophumaneval_baseline(


def main():
# --------------------------------------------------------------
# CLI: parse arguments
# --------------------------------------------------------------
parser = argparse.ArgumentParser(
description="CoopHumanEval Single Agent Baseline Evaluation"
)
@@ -760,12 +762,16 @@ def main():

args = parser.parse_args()

# Initialize baseline evaluator
# --------------------------------------------------------------
# Initialize evaluator
# --------------------------------------------------------------
evaluator = QwenCoopHumanEvalSingleAgentBaseline(
model_name=args.model, device=args.device
)

# --------------------------------------------------------------
# Run evaluation
# --------------------------------------------------------------
aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_baseline(
num_samples=args.samples,
num_generations=args.generations,
3 changes: 1 addition & 2 deletions baselines/he_concat.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
12 changes: 9 additions & 3 deletions baselines/he_discuss.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
@@ -1012,6 +1011,9 @@ def evaluate_humaneval_two_round(


def main():
# --------------------------------------------------------------
# CLI: parse arguments
# --------------------------------------------------------------
parser = argparse.ArgumentParser(description="HumanEval Two-Round Model Evaluation")
parser.add_argument(
"--aux-model", default="Qwen/Qwen2.5-Coder-3B", help="Auxiliary model name"
@@ -1044,14 +1046,18 @@ def main():

args = parser.parse_args()

# Initialize two-round evaluator
# --------------------------------------------------------------
# Initialize evaluator
# --------------------------------------------------------------
evaluator = QwenHumanEvalTwoRoundEvaluator(
aux_model_name=args.aux_model,
main_model_name=args.main_model,
device=args.device,
)

# --------------------------------------------------------------
# Run evaluation
# --------------------------------------------------------------
aggregated_metrics, sample_results = evaluator.evaluate_humaneval_two_round(
num_samples=args.samples,
num_generations=args.generations,
3 changes: 1 addition & 2 deletions baselines/he_sequential.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
Expand Down
12 changes: 9 additions & 3 deletions baselines/he_single_agent.py
@@ -3,8 +3,7 @@
import re
import signal
import time
from collections import defaultdict
from math import comb


import numpy as np
import torch
@@ -734,6 +733,9 @@ def evaluate_humaneval_baseline(


def main():
# --------------------------------------------------------------
# CLI: parse arguments
# --------------------------------------------------------------
parser = argparse.ArgumentParser(
description="HumanEval Single Agent Baseline Evaluation"
)
@@ -763,12 +765,16 @@ def main():

args = parser.parse_args()

# Initialize baseline evaluator
# --------------------------------------------------------------
# Initialize evaluator
# --------------------------------------------------------------
evaluator = QwenHumanEvalSingleAgentBaseline(
model_name=args.model, device=args.device
)

# --------------------------------------------------------------
# Run evaluation
# --------------------------------------------------------------
aggregated_metrics, sample_results = evaluator.evaluate_humaneval_baseline(
num_samples=args.samples,
num_generations=args.generations,
1 change: 0 additions & 1 deletion configs/grpo_che_config.yaml
@@ -47,7 +47,6 @@ grpo:
discount: 0.9
termination_threshold: -0.1
reward_shift: -2.1
normalize_advantage: false
epsilon_clip: null

# wandb
1 change: 0 additions & 1 deletion configs/grpo_he_config.yaml
@@ -47,7 +47,6 @@ grpo:
discount: 0.9
termination_threshold: -0.1
reward_shift: -2.1
normalize_advantage: false
epsilon_clip: null

# wandb
58 changes: 58 additions & 0 deletions configs/grpo_mbpp_config.yaml
@@ -0,0 +1,58 @@
# model
model:
name: "Qwen/Qwen2.5-Coder-3B"
type: "qwen"
temperature: 0.7
top_p: 0.9
max_length: 2048
tokenizer_kwargs:
trust_remote_code: true
model_kwargs:
trust_remote_code: true
torch_dtype: "bfloat16"

# dataset
dataset:
name: "OpenMLRL/MBPP"
type: "mbpp"
train_split: "test[15:65]"
eval_split: "test[:15]"

# output
output:
base_dir: "output"
save_final_model: false
verbose: false

# external
external:
mode: "level_feedback"
sandbox_slice: 1
original_prompt: true
previous_response: true

# grpo
grpo:
num_turns: 2
num_train_epochs: 8
per_device_train_batch_size: 1
learning_rate: 3.0e-5
logging_steps: 50
save_steps: 200
num_generations: 4
max_new_tokens: 256
joint_mode: aligned
temperature: 0.8
top_p: 0.95
discount: 0.9
termination_threshold: -0.1
reward_shift: -2.1
epsilon_clip: null

# wandb
wandb:
project: "mlrl"
entity: "nu-llpr"
name: "grpo_mbpp"
dir: "output"
tags: ["grpo", "mbpp"]
1 change: 0 additions & 1 deletion configs/magrpo_che_config.yaml
@@ -48,7 +48,6 @@ magrpo:
discount: 0.9
termination_threshold: -0.2
reward_shift: -4
normalize_advantage: false
epsilon_clip: null

# wandb
1 change: 0 additions & 1 deletion configs/magrpo_he_config.yaml
@@ -46,7 +46,6 @@ magrpo:
discount: 0.9
termination_threshold: -0.2
reward_shift: -4
normalize_advantage: false
epsilon_clip: null

# wandb