92 changes: 28 additions & 64 deletions README.md
@@ -1,101 +1,65 @@
# LLM Collaboration with MARL

This repository contains training scripts and configurations for the paper "LLM Collaboration with Multi‑Agent Reinforcement Learning".
- [Benchmarks](#benchmarks)
- [Training Scripts](#training-scripts)
- [Default Configs](#default-configs)
- [Parameter Overrides](#parameter-overrides)
- [Multi-Turn Settings](#multi-turn-settings)
- [2+Turn Prompt Composition](#2turn-prompt-composition)
- [External Modes](#external-modes)
- [Sandbox Tests](#sandbox-tests)
Training scripts and configs for _"LLM Collaboration with Multi‑Agent Reinforcement Learning"_.

## Benchmarks

- HumanEval (HE): 164 problems on split `test`
- CoopHumanEval (CHE): 82 problems on split `test`
- MBPP: 427 problems on split `sanitized`
- HumanEval: 164 problems on split `test`
- CoopHumanEval: 82 problems on split `test`

## Training Scripts

### Default Configs

```bash
# Single-agent HumanEval (GRPO)
python LLM_Collaboration_with_MARL/train_grpo.py \
--config LLM_Collaboration_with_MARL/configs/grpo_he_config.yaml

# Multi-agent CoopHumanEval (MAGRPO)
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/magrpo_che_config.yaml

# Multi-turn HumanEval (MT-MAGRPO)
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml
```

### Parameter Overrides

You can override any configuration parameter using `--override`:
You can always override any configuration parameter using `--override`:

```bash
# Change model
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
--override model_name='bigcode/starcoder2-3b'
--override model.name='bigcode/starcoder2-3b' magrpo.num_turns=1
```

# Modify training params
python LLM_Collaboration_with_MARL/train_grpo.py \
--config LLM_Collaboration_with_MARL/configs/grpo_che_config.yaml \
--override grpo.num_train_epochs=20 grpo.learning_rate=3e-5
## Settings

# Multi-turn override example
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \
--override dataset.train_split='test[16:]' dataset.eval_split='test[:16]' \
magrpo.num_turns=2
### Joint Action Modes

# Enable code-level training metrics (expensive; default is off)
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
--override magrpo.log_code_levels=true
```
## Multi-Turn Settings
`magrpo.joint_mode` determines how each agent's K generations are combined into joint actions at each turn. Two modes are supported: with 'aligned' (the default), each agent's k-th generation is paired with the other agents' k-th generations to form a joint action; with 'cross', all combinations of the agents' K generations are used, giving K^N joint actions for N agents.

### 2+Turn Prompt Composition
Since the number of samples also grows exponentially with the number of turns, aligned joint is **more flexible** (the number of samples does not have to be a perfect power) and hence faster to train in wall-clock time. However, cross joint is more sample efficient (much lower VRAM than running 'aligned' with num_generations = K^N), and it also performs better because the value estimation is more accurate.
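For example, to switch a run from the default aligned mode to cross joint actions, an override like the one below should work (a sketch: the mode names follow the shipped configs, and the dotted path assumes overrides map onto the YAML nesting, as in the examples above).

```bash
# Form all K^N combinations of agent generations instead of aligned pairs
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
    --override magrpo.joint_mode='cross'
```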

To save memory usage, 2+ turn prompts **include the previous response without the original first‑turn problem prompt by default**. You can add the original prompt to match the concept of observation-action history in MARL.
### Number of Turns

```bash
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
--override magrpo.external_original_prompt=True magrpo.external_previous_response=True
```
`magrpo.num_turns` sets the number of turns (`magrpo.num_turns=2` by default). The number of samples grows exponentially with the number of turns: under cross joint each branch expands into K^N joint actions per turn (K^(NT) after T turns), while under aligned joint it expands into K per turn (K^T after T turns).
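For instance, with K = 4 generations per agent and N = 2 agents, a 2-turn run expands to 4^4 = 256 samples under cross joint but only 4^2 = 16 under aligned joint.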

### External Modes
### Early Termination

Multi-turn training supports external transition modes for 2nd+ turns, set via `external.mode`:
`magrpo.termination_threshold` is used to incentivize agents to find high-reward solutions quickly, instead of expanding the full Monte Carlo tree.

- `level_feedback` **(default)**: Detailed diagnostics (impl found, syntax with line/col, per-test pass/fail errors, aux usage).
- Requires `external.expert_model` in config when using `expert_edits` (e.g., `deepseek-coder`, Claude, etc.). This parameter is ignored for other modes (`level_feedback`, `level_passed`, `passed`, `plain`).
- Requires corresponding API keys in env vars.
- `level_passed`: Binary passed signals (impl found, syntax, tests summary, aux usage).
- `passed`: A binary signal — "All levels passed" or "Not all levels passed".
- `plain`: No signals or diagnostics.
At each node (branch, turn), the mean immediate reward is computed **across the sibling joint actions** at that node. If the mean exceeds the threshold, that branch stops expanding at this turn and training backpropagates from the truncated subtree; other branches continue.
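As a sketch, lowering the threshold makes branches terminate earlier, while raising it lets them keep expanding; the value below is illustrative (the shipped configs use negative thresholds because rewards are shifted by `reward_shift`), and the dotted path assumes the same override convention as above.

```bash
# Stop expanding a branch once its sibling joint actions average above -0.05
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/magrpo_che_config.yaml \
    --override magrpo.termination_threshold=-0.05
```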

```bash
# HumanEval with detailed feedback signals
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
--override external.mode='level_feedback'
```
### Multi-Turn Prompt

### Sandbox Tests
`external.original_prompt` and `external.previous_response` both default to `true`: prompts at turn 2 and later include both the original first-turn problem prompt and the previous response, preserving the full context. You can shorten the context by setting either to `false` (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
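A hedged example of that token-saving setup, assuming boolean overrides are accepted in this lowercase form:

```bash
# Keep only the previous response in 2+ turn prompts to reduce tokens
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
    --override external.original_prompt=false external.previous_response=true
```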

The external modes obtain `entry_point` and tests via an internal resolver registered by the training script. **By default, sandbox executes only the first assert (`sandbox_slice=1`).** Use all eval tests by setting `external.sandbox_slice` to `0`, `None`, or `'all'`. A negative value uses the last N asserts. Note: `external.sandbox_slice` only affects analysis-based modes (`level_feedback`, `level_passed`, `passed`), and it has no effect on `expert_edits`.
### External Modes

```bash
# Add an external.sandbox_slice override
python LLM_Collaboration_with_MARL/train_magrpo.py \
--config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \
--override external.mode='level_feedback' external.sandbox_slice=-2
```
`external.mode` defaults to 'level_feedback'. The external mode injects additional information into the prompts of the following turns: 'level_feedback' attaches test-driven diagnostics, while the alternatives are 'expert_edits' (an LLM proposes edits), 'level_passed' / 'passed' (binary outcomes), and 'plain' (no signals).
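For example, to train with only binary per-level signals in the follow-up prompts:

```bash
# Replace detailed diagnostics with binary per-level pass signals
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
    --override external.mode='level_passed'
```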

A setting specific to 'level_feedback' is `external.sandbox_slice`, which controls how many eval tests are included in the feedback. By default, the sandbox executes only the first assert (`sandbox_slice=1`). To use all eval tests, set `external.sandbox_slice` to 0, None, or 'all'; a negative value uses the last N asserts. `external.sandbox_slice` only affects the analysis-based modes ('level_feedback', 'level_passed', 'passed') and has no effect on 'expert_edits'.
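For example, to run every eval test in the sandbox and include the results in the feedback:

```bash
# Use all eval tests rather than only the first assert
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \
    --override external.mode='level_feedback' external.sandbox_slice='all'
```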

A setting specific to 'expert_edits' is `external.expert_edits_model`, which selects the LLM used to propose edits. By default it uses DeepSeek-Coder; you can switch to Claude-3 or GPT-4 once the corresponding API keys/tokens are set in your environment variables.
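A sketch of enabling expert edits; the model identifier here is illustrative and must match a model that your API keys can reach:

```bash
# Let an external expert LLM propose edits between turns
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
    --override external.mode='expert_edits' external.expert_edits_model='deepseek-coder'
```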

### Output

`output.save_final_model` is set to `false` by default because saving multiple LLMs requires a large amount of storage. `output.verbose` enables debug printing on the cluster when set to `true`; it defaults to `false`, in which case only a tqdm bar shows the training progress. You can also turn on `magrpo.log_code_levels` to log the level rewards during training, but it slows training down considerably.
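A debugging-style run might enable both switches (expect a large slowdown from the code-level metrics); a sketch, assuming the dotted keys match the YAML nesting in the shipped configs:

```bash
# Verbose cluster logging plus per-level reward logging (slow)
python LLM_Collaboration_with_MARL/train_magrpo.py \
    --config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
    --override output.verbose=true magrpo.log_code_levels=true
```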
17 changes: 10 additions & 7 deletions configs/grpo_che_config.yaml
@@ -9,7 +9,7 @@ model:
trust_remote_code: true
model_kwargs:
trust_remote_code: true
torch_dtype: "auto"
torch_dtype: "bfloat16"

# dataset
dataset:
@@ -20,8 +20,9 @@ dataset:

# output
output:
base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
base_dir: "output"
save_final_model: false
verbose: false

# external
external:
@@ -32,23 +33,25 @@ external:

# grpo
grpo:
num_train_epochs: 16
num_turns: 2
num_train_epochs: 8
per_device_train_batch_size: 1
learning_rate: 1.0e-5
learning_rate: 2.0e-5
logging_steps: 50
save_steps: 200
num_generations: 4
max_new_tokens: 256
joint_mode: cross
joint_mode: aligned
temperature: 0.8
top_p: 0.95
discount: 0.9
termination_threshold: -0.1
reward_shift: -2.1

# wandb
wandb:
project: "mlrl"
entity: "nu-llpr"
name: "grpo_coophumaneval"
dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
tags: ["grpo", "coophumaneval", "single-agent"]
dir: "output"
tags: ["grpo", "coophumaneval"]
17 changes: 10 additions & 7 deletions configs/grpo_he_config.yaml
@@ -9,7 +9,7 @@ model:
trust_remote_code: true
model_kwargs:
trust_remote_code: true
torch_dtype: "auto"
torch_dtype: "bfloat16"

# dataset
dataset:
@@ -20,8 +20,9 @@ dataset:

# output
output:
base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
base_dir: "output"
save_final_model: false
verbose: false

# external
external:
@@ -32,23 +33,25 @@ external:

# grpo
grpo:
num_train_epochs: 8
num_turns: 2
num_train_epochs: 6
per_device_train_batch_size: 1
learning_rate: 1.0e-5
learning_rate: 2.0e-5
logging_steps: 50
save_steps: 200
num_generations: 4
max_new_tokens: 256
joint_mode: cross
joint_mode: aligned
temperature: 0.8
top_p: 0.95
discount: 0.9
termination_threshold: -0.1
reward_shift: -2.1

# wandb
wandb:
project: "mlrl"
entity: "nu-llpr"
name: "grpo_humaneval"
dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
tags: ["grpo", "humaneval", "single-agent"]
dir: "output"
tags: ["grpo", "humaneval"]
13 changes: 8 additions & 5 deletions configs/magrpo_che_config.yaml
@@ -9,7 +9,7 @@ model:
trust_remote_code: true
model_kwargs:
trust_remote_code: true
torch_dtype: "auto"
torch_dtype: "bfloat16"

# dataset
dataset:
@@ -20,8 +20,9 @@ dataset:

# output
output:
base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
base_dir: "output"
save_final_model: false
verbose: false

# external
external:
@@ -32,7 +33,8 @@ external:

# magrpo
magrpo:
num_train_epochs: 16
num_turns: 2
num_train_epochs: 8
per_device_train_batch_size: 1
learning_rate: 2.0e-5
logging_steps: 50
@@ -41,15 +43,16 @@ magrpo:
max_new_tokens: 256
temperature: 0.8
top_p: 0.95
joint_mode: cross
joint_mode: aligned
num_agents: 2
discount: 0.9
termination_threshold: -0.2
reward_shift: -4

# wandb
wandb:
project: "mlrl"
entity: "nu-llpr"
name: "magrpo_coophumaneval"
dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
dir: "output"
tags: ["magrpo", "coophumaneval", "multi-agent"]
13 changes: 8 additions & 5 deletions configs/magrpo_he_config.yaml
@@ -9,7 +9,7 @@ model:
trust_remote_code: true
model_kwargs:
trust_remote_code: true
torch_dtype: "auto"
torch_dtype: "bfloat16"

# dataset
dataset:
@@ -20,8 +20,9 @@ dataset:

# output
output:
base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
base_dir: "output"
save_final_model: false
verbose: false

# external
external:
@@ -32,22 +33,24 @@ external:

# magrpo
magrpo:
num_train_epochs: 8
num_turns: 2
num_train_epochs: 6
per_device_train_batch_size: 1
learning_rate: 2.0e-5
logging_steps: 50
save_steps: 200
num_generations: 4
max_new_tokens: 256
joint_mode: cross
joint_mode: aligned
num_agents: 2
discount: 0.9
termination_threshold: -0.2
reward_shift: -4

# wandb
wandb:
project: "mlrl"
entity: "nu-llpr"
name: "magrpo_humaneval"
dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
dir: "output"
tags: ["magrpo", "humaneval", "multi-agent"]
55 changes: 0 additions & 55 deletions configs/mt_grpo_che_config.yaml

This file was deleted.
