Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ data:

collab:
mode: TAKE_JOB # ONE | TAKE_JOB
num_agents: 3 # used when mode=TAKE_JOB
num_agents: 2 # used when mode=TAKE_JOB

external:
mode: code_feedback # plain | plain_simple | code_feedback
Expand All @@ -24,15 +24,15 @@ trainer:
num_train_epochs: 3
per_device_train_batch_size: 1
# Learning rate for optimizer (alias: lr)
learning_rate: 1.7e-5
learning_rate: 1e-5
logging_steps: 50
save_steps: 200
num_generations: 3
# Per-agent generation cap; increase if outputs truncate.
max_new_tokens: 660
max_new_tokens: 600
temperature: 0.25
top_p: 0.90
num_turns: 2
num_turns: 1
# PPO-related (CoMLRL MAGRPO) options
# Whether to normalize advantages when updating policy
# normalize_advantage: true
Expand Down
29 changes: 10 additions & 19 deletions rewards/CE_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,20 +394,7 @@ def collect_calls(fn: "ast.FunctionDef") -> Set[str]:
_count_pass_syntax = 0

def get_reward_function(strategy, num_agents: int) -> Callable[..., List[float]]:
"""Return a reward function implementing the redesigned lv1+lv2+lv3 scoring.

- V = total number of class methods requiring implementation
- lv1 = 2 * (|union of chosen methods across agents| / V)
Special case: if coverage < 1/2 then reward = 0 for this sample
- lv2 = total-picks control:
Let S = Σ_i |A_i| be the total number of functions generated by all agents.
* If S ≥ 2V+2: terminate this sample early with total reward = -INF (=-1)
* If 0 <= S <= V: lv2(S) = 2 - 3 * ((S - V)^2) / V^2 (assuming V>0)
* If V < S <= 2V+2: lv2(S) = 2 - 3 * ((S - V)^2) / (V + 1)^2
- lv3 = balance based on variance of |A_i| around t = V/N, with
MSD = (1/N) * Σ (s_i - t)^2 and MSD_max = (1/N) * V^2 * (1 - 1/N),
R_bal = max(0, 1 - MSD/(MSD_max + eps))
Total reward = lv1 + lv2 + lv3
"""Return a reward function for the current lv1+lv2+lv3 scoring scheme."""

def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
Expand Down Expand Up @@ -447,9 +434,12 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
INF = 1
_count_total += 1

V_set: Set[str] = set(method_names)
V = len(V_set)

# Early penalty: penalize by number of agents that picked zero functions OR all V methods (k * -INF * 0.5) and skip
try:
zeros = sum(1 for s in A_sets if (len(s) if s is not None else 0) == 0)
zeros = sum(1 for s in A_sets if (len(s) if s is not None else 0) in (0, V))
if zeros > 0:
rewards.append(-INF * 0.5 * zeros)
continue
Expand All @@ -460,8 +450,7 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
_count_pass_lv0 += 1

# New reward rules (lv1 + lv2)
V_set: Set[str] = set(method_names)
V = len(V_set)

if V <= 0:
rewards.append(-INF)
continue
Expand All @@ -477,11 +466,13 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
continue

lv1 = 2.0 * coverage_ratio
if coverage_ratio == 1.0:
lv1 += 0.5 # bonus for full coverage

# lv2: constrain total picks S = sum_i |A_i|
S_total = sum(len(s) for s in A_sets)
# Early termination if total picks reach or exceed 2V
if S_total > 2 * V + 2:
if S_total >= 2 * V:
rewards.append(-INF)
continue

Expand Down Expand Up @@ -522,7 +513,7 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
sum_J += J
N_pairs += 1
mean_J = (sum_J / N_pairs) if N_pairs > 0 else 0.0
jaccard_term = 1.0 * (1.0 - 1.5 * mean_J)
jaccard_term = 1.0 * (1.0 - 2.0 * mean_J)
# if jaccard_term > 2.0:
# jaccard_term = 2.0
# elif jaccard_term < -2.0:
Expand Down
6 changes: 6 additions & 0 deletions utils/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def build_take_job_prompt(
SKELETON START
{skeleton.strip()}
SKELETON END

As a final reminder, please select a **non-empty, proper subset** of {v_braced} to implement.

We recommend choosing a consecutive block of methods that either starts at the beginning of {v_braced} or ends at its last method (DO NOT limit yourself to only the beginning).

Take particular care not to select all methods for implementation!
"""
).strip()

Expand Down