diff --git a/README.md b/README.md index a20ccba..57d797a 100644 --- a/README.md +++ b/README.md @@ -28,44 +28,40 @@ python LLM_Collaboration_with_MARL/train_magrpo.py \ ## Settings -### Joint Action Modes +### Joint Action -`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: `align` (default), which pairs the g‑th generation of every agent to form G joint actions per node; and `cross`, which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): `align` → G^T; `cross` → (G^N)^T = G^{N·T}. +`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: 'aligned' (default), which pairs the g‑th generation of every agent to form G joint actions per node; and 'cross', which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): 'aligned' → G^T; 'cross' → G^{N·T}. Aligned is faster in wall‑time (fewer sibling evaluations per node), while cross is more sample‑efficient (better value estimation) without extra VRAM because it reuses the same G generations per agent and only crosses them within the node. We never cross across different nodes/prompts; this preserves causal state consistency (actions are conditioned on the same prompts), keeps siblings comparable for the baseline/advantage, maintains correct credit assignment (log‑probs matched to rewards from the same state), and remains computationally tractable. -### Advantage Calculation +### Advantage -`magrpo.normalize_advantage` is false by default. When true, compute z-scored advantages over sibling returns; when false, use a mean baseline without normalization. +Advantages are used to optimize the agents' policies; they are computed against a mean baseline without standard‑deviation normalization, which keeps training unbiased (see [Dr. GRPO](https://arxiv.org/pdf/2503.20783)). We do not apply importance sampling ratios either, since training is on-policy (the same policy is used for sampling and training). -`magrpo.epsilon_clip` clamps the advantage to [-epsilon_clip, +epsilon_clip] after normalization (default: None). 0 or None skips clamping entirely. +### Number of Samples -We do not apply the importance sampling ratio because the policy changes slowly with LLMs, and the ratio is close to 1.0. This avoids numerical instability from multiplying many small probabilities. +`magrpo.num_turns` is the number of turns in training and evaluation, and `magrpo.num_generations` is the number of generations G sampled per agent at each turn. Leaf counts (the total number of samples at the current turn) grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned`, and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return). -### Number of Turns +### Termination -`magrpo.num_turns` determines the number of turns (default: 2). Leaf counts grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned`, and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
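The joint-mode combinatorics and the mean-baseline advantage described above are easy to see in code. The sketch below is not part of the patch; it is a minimal illustration, assuming each agent's generations are plain lists of strings, and the helper names `joint_actions` and `advantages` are hypothetical rather than functions from the repository.

```python
from itertools import product


def joint_actions(per_agent_generations, mode="aligned"):
    """Combine each agent's G generations into joint actions at one node.

    per_agent_generations: list of N lists, each holding G candidate responses.
    'aligned' pairs the g-th generation of every agent -> G joint actions;
    'cross' takes the Cartesian product within the node -> G**N joint actions.
    After T turns without early termination, leaves number G**T ('aligned')
    or G**(N*T) ('cross').
    """
    if mode == "aligned":
        return list(zip(*per_agent_generations))      # G tuples of N responses
    if mode == "cross":
        return list(product(*per_agent_generations))  # G**N tuples of N responses
    raise ValueError(f"unknown joint_mode: {mode}")


def advantages(sibling_returns):
    """Mean-baseline advantage A_i = Return_i - mean(sibling returns),
    with no standard-deviation normalization."""
    baseline = sum(sibling_returns) / len(sibling_returns)
    return [r - baseline for r in sibling_returns]


# Two agents (N=2) with three generations each (G=3):
gens = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
assert len(joint_actions(gens, "aligned")) == 3   # G
assert len(joint_actions(gens, "cross")) == 9     # G**N
print(advantages([2.0, 1.0, 0.0]))                # [1.0, 0.0, -1.0]
```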
+`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly instead of expanding the full Monte Carlo tree. At each node (branch, turn), we compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue. -### Early Termination -`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly, instead of expanding the full Monte Carlo tree. At each node (branch, turn), compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue. - -### 2+Turn Prompt - -`external.original_prompt` and `external.previous_response` both default as `true`. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to `false` (for example, keep only the previous response to reduce tokens while retaining the most recent interaction). +`external.original_prompt` and `external.previous_response` both default to true. Prompts from the second turn onward include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to false (for example, keep only the previous response to reduce tokens while retaining the most recent interaction). ### External Modes -`external.mode` is set to 'level_feedback' by default. This gives additional information from external to prompts in the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include: +`external.mode` emulates the environment transition and is set to 'level_feedback' by default. It injects additional external information into the prompts of the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include: -- `expert_edits`: an LLM proposes edits; prompts include edit suggestions plus context. -- `level_passed` / `passed`: binary outcome oriented prompts with minimal context. -- `plain`: no diagnostics, but still includes previous response (unless disabled) and a "Revise ..." instruction. +- 'expert_edits': an LLM proposes edits; prompts include edit suggestions plus context. +- 'level_passed' / 'passed': binary, outcome‑oriented prompts with minimal context. +- 'plain': no diagnostics, but still includes the previous response (unless disabled) and a "revise your previous response" instruction. Specific settings for 'level_feedback' is `external.sandbox_slice`, which controls how many eval tests to include in the feedback. By default, sandbox executes only the first assert (sandbox_slice=1). Use all eval tests by setting `external.sandbox_slice` to 0, None, or 'all'. Negative values use the last asserts. `external.sandbox_slice` only affects analysis-based modes ('level_feedback', 'level_passed', 'passed'), and it has no effect on 'expert_edits'. -Specific settings for 'expert_edits' is `external.expert_edits_model`, which controls which LLM to use for proposing edits. By default, it uses DeepSeek-Coder. You can also change it to Claude-3, GPT-4, once you have keys/tokens in your global environment variables.
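As a companion to the termination and `sandbox_slice` settings just described, the sketch below is not part of the patch; it is a hypothetical illustration with invented helper names (`branch_terminates`, `select_asserts`) of how the per-node stopping rule and the assert-slicing for 'level_feedback' could behave.

```python
def branch_terminates(sibling_rewards, termination_threshold):
    """Stop expanding a branch when the mean immediate reward across the
    node's sibling joint actions exceeds the threshold (illustrative only)."""
    mean_reward = sum(sibling_rewards) / len(sibling_rewards)
    return mean_reward > termination_threshold


def select_asserts(test_lines, sandbox_slice=1):
    """Pick which eval asserts feed the feedback prompt (illustrative only).
    sandbox_slice=1 -> first assert; 0/None/'all' -> all; negative -> last ones."""
    if sandbox_slice in (0, None, "all"):
        return test_lines
    if sandbox_slice < 0:
        return test_lines[sandbox_slice:]
    return test_lines[:sandbox_slice]


tests = ["assert f(1) == 2", "assert f(2) == 4", "assert f(3) == 6"]
print(select_asserts(tests, 1))    # first assert only (the default)
print(select_asserts(tests, -2))   # the last two asserts
print(branch_terminates([-0.1, -0.3, 0.0, -0.2], termination_threshold=-0.2))  # True
```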
+The setting specific to 'expert_edits' is `external.expert_edits_model`, which controls which LLM is used to propose edits. By default, it uses DeepSeek-Coder. You can also switch to Claude, GPT, or other models once you have the keys/tokens in your environment. ### Output -`output.save_model` is set to `false` by default because of the huge storage required by multiple LLMs. `verbose` is used for debug printing on cluster if set to be true, but it is default to be false and you can only see a tqdm bar that shows the training progress. You can also turn on `magrpo.log_code_levels` to log the level-rewards during training, but it will crazily slow down the training. +`output.save_model` is set to 'false' by default because storing multiple LLMs takes a huge amount of disk space. `output.verbose` enables debug printing on the cluster when set to true; it defaults to false, in which case only a tqdm bar shows the training progress. diff --git a/baselines/che_concat.py b/baselines/che_concat.py index 3badb96..c20fde0 100644 --- a/baselines/che_concat.py +++ b/baselines/che_concat.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch diff --git a/baselines/che_discuss.py b/baselines/che_discuss.py index 6878226..e2ecfc3 100644 --- a/baselines/che_discuss.py +++ b/baselines/che_discuss.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch @@ -1009,6 +1008,9 @@ def evaluate_coophumaneval_two_round( def main(): + # -------------------------------------------------------------- + # CLI: parse arguments + # -------------------------------------------------------------- parser = argparse.ArgumentParser( description="CoopHumanEval Two-Round Model Evaluation" ) @@ -1043,14 +1045,18 @@ def main(): args = parser.parse_args() - # Initialize two-round evaluator + # -------------------------------------------------------------- + # Initialize evaluator + # -------------------------------------------------------------- evaluator = QwenCoopHumanEvalTwoRoundEvaluator( aux_model_name=args.aux_model, main_model_name=args.main_model, device=args.device, ) + # -------------------------------------------------------------- # Run evaluation + # -------------------------------------------------------------- aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_two_round( num_samples=args.samples, num_generations=args.generations, diff --git a/baselines/che_sequential.py b/baselines/che_sequential.py index 70cf021..9ca1fca 100644 --- a/baselines/che_sequential.py +++ b/baselines/che_sequential.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch diff --git a/baselines/che_single_agent.py b/baselines/che_single_agent.py index ac4b8ff..42b34b1 100644 --- a/baselines/che_single_agent.py +++ b/baselines/che_single_agent.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch @@ -731,6 +730,9 @@ def evaluate_coophumaneval_baseline( def main(): + # -------------------------------------------------------------- + # CLI: parse arguments + # -------------------------------------------------------------- parser = argparse.ArgumentParser( description="CoopHumanEval Single Agent Baseline Evaluation" ) @@ -760,12 +762,16 @@ def main(): 
args = parser.parse_args() - # Initialize baseline evaluator + # -------------------------------------------------------------- + # Initialize evaluator + # -------------------------------------------------------------- evaluator = QwenCoopHumanEvalSingleAgentBaseline( model_name=args.model, device=args.device ) + # -------------------------------------------------------------- # Run evaluation + # -------------------------------------------------------------- aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_baseline( num_samples=args.samples, num_generations=args.generations, diff --git a/baselines/he_concat.py b/baselines/he_concat.py index a598320..d51494c 100644 --- a/baselines/he_concat.py +++ b/baselines/he_concat.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch diff --git a/baselines/he_discuss.py b/baselines/he_discuss.py index 4bf227e..9c4768f 100644 --- a/baselines/he_discuss.py +++ b/baselines/he_discuss.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch @@ -1012,6 +1011,9 @@ def evaluate_humaneval_two_round( def main(): + # -------------------------------------------------------------- + # CLI: parse arguments + # -------------------------------------------------------------- parser = argparse.ArgumentParser(description="HumanEval Two-Round Model Evaluation") parser.add_argument( "--aux-model", default="Qwen/Qwen2.5-Coder-3B", help="Auxiliary model name" @@ -1044,14 +1046,18 @@ def main(): args = parser.parse_args() - # Initialize two-round evaluator + # -------------------------------------------------------------- + # Initialize evaluator + # -------------------------------------------------------------- evaluator = QwenHumanEvalTwoRoundEvaluator( aux_model_name=args.aux_model, main_model_name=args.main_model, device=args.device, ) + # -------------------------------------------------------------- # Run evaluation + # -------------------------------------------------------------- aggregated_metrics, sample_results = evaluator.evaluate_humaneval_two_round( num_samples=args.samples, num_generations=args.generations, diff --git a/baselines/he_sequential.py b/baselines/he_sequential.py index 61dacbb..b07f97d 100644 --- a/baselines/he_sequential.py +++ b/baselines/he_sequential.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch diff --git a/baselines/he_single_agent.py b/baselines/he_single_agent.py index 4996bed..0dbb8f6 100644 --- a/baselines/he_single_agent.py +++ b/baselines/he_single_agent.py @@ -3,8 +3,7 @@ import re import signal import time -from collections import defaultdict -from math import comb + import numpy as np import torch @@ -734,6 +733,9 @@ def evaluate_humaneval_baseline( def main(): + # -------------------------------------------------------------- + # CLI: parse arguments + # -------------------------------------------------------------- parser = argparse.ArgumentParser( description="HumanEval Single Agent Baseline Evaluation" ) @@ -763,12 +765,16 @@ def main(): args = parser.parse_args() - # Initialize baseline evaluator + # -------------------------------------------------------------- + # Initialize evaluator + # -------------------------------------------------------------- evaluator = QwenHumanEvalSingleAgentBaseline( model_name=args.model, 
device=args.device ) + # -------------------------------------------------------------- # Run evaluation + # -------------------------------------------------------------- aggregated_metrics, sample_results = evaluator.evaluate_humaneval_baseline( num_samples=args.samples, num_generations=args.generations, diff --git a/configs/grpo_che_config.yaml b/configs/grpo_che_config.yaml index 411e444..ed9f6d9 100644 --- a/configs/grpo_che_config.yaml +++ b/configs/grpo_che_config.yaml @@ -47,7 +47,6 @@ grpo: discount: 0.9 termination_threshold: -0.1 reward_shift: -2.1 - normalize_advantage: false epsilon_clip: null # wandb diff --git a/configs/grpo_he_config.yaml b/configs/grpo_he_config.yaml index c944c29..60740bb 100644 --- a/configs/grpo_he_config.yaml +++ b/configs/grpo_he_config.yaml @@ -47,7 +47,6 @@ grpo: discount: 0.9 termination_threshold: -0.1 reward_shift: -2.1 - normalize_advantage: false epsilon_clip: null # wandb diff --git a/configs/grpo_mbpp_config.yaml b/configs/grpo_mbpp_config.yaml new file mode 100644 index 0000000..64de420 --- /dev/null +++ b/configs/grpo_mbpp_config.yaml @@ -0,0 +1,58 @@ +# model +model: + name: "Qwen/Qwen2.5-Coder-3B" + type: "qwen" + temperature: 0.7 + top_p: 0.9 + max_length: 2048 + tokenizer_kwargs: + trust_remote_code: true + model_kwargs: + trust_remote_code: true + torch_dtype: "bfloat16" + +# dataset +dataset: + name: "OpenMLRL/MBPP" + type: "mbpp" + train_split: "test[15:65]" + eval_split: "test[:15]" + +# output +output: + base_dir: "output" + save_final_model: false + verbose: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true + +# grpo +grpo: + num_turns: 2 + num_train_epochs: 8 + per_device_train_batch_size: 1 + learning_rate: 3.0e-5 + logging_steps: 50 + save_steps: 200 + num_generations: 4 + max_new_tokens: 256 + joint_mode: aligned + temperature: 0.8 + top_p: 0.95 + discount: 0.9 + termination_threshold: -0.1 + reward_shift: -2.1 + epsilon_clip: null + +# wandb +wandb: + project: "mlrl" + entity: "nu-llpr" + name: "grpo_mbpp" + dir: "output" + tags: ["grpo", "mbpp"] diff --git a/configs/magrpo_che_config.yaml b/configs/magrpo_che_config.yaml index 44c3a63..88de6a2 100644 --- a/configs/magrpo_che_config.yaml +++ b/configs/magrpo_che_config.yaml @@ -48,7 +48,6 @@ magrpo: discount: 0.9 termination_threshold: -0.2 reward_shift: -4 - normalize_advantage: false epsilon_clip: null # wandb diff --git a/configs/magrpo_he_config.yaml b/configs/magrpo_he_config.yaml index 2eda90a..5f75554 100644 --- a/configs/magrpo_he_config.yaml +++ b/configs/magrpo_he_config.yaml @@ -46,7 +46,6 @@ magrpo: discount: 0.9 termination_threshold: -0.2 reward_shift: -4 - normalize_advantage: false epsilon_clip: null # wandb diff --git a/configs/magrpo_mbpp_config.yaml b/configs/magrpo_mbpp_config.yaml new file mode 100644 index 0000000..bbfb7b9 --- /dev/null +++ b/configs/magrpo_mbpp_config.yaml @@ -0,0 +1,59 @@ +# model +model: + name: "Qwen/Qwen2.5-Coder-3B" + type: "qwen" + temperature: 0.7 + top_p: 0.9 + max_length: 2048 + tokenizer_kwargs: + trust_remote_code: true + model_kwargs: + trust_remote_code: true + torch_dtype: "bfloat16" + +# dataset +dataset: + name: "OpenMLRL/MBPP" + type: "mbpp" + train_split: "test[15:65]" + eval_split: "test[:15]" + +# output +output: + base_dir: "output" + save_final_model: false + verbose: false + +# external +external: + mode: "level_feedback" + sandbox_slice: 1 + original_prompt: true + previous_response: true + +# magrpo +magrpo: + num_turns: 2 + 
num_train_epochs: 8 + per_device_train_batch_size: 1 + learning_rate: 3.0e-5 + logging_steps: 50 + save_steps: 200 + num_generations: 4 + max_new_tokens: 256 + temperature: 0.8 + top_p: 0.95 + joint_mode: aligned + num_agents: 2 + discount: 0.9 + termination_threshold: -0.2 + reward_shift: -4 + epsilon_clip: null + +# wandb +wandb: + project: "mlrl" + entity: "nu-llpr" + name: "magrpo_mbpp" + dir: "output" + tags: ["magrpo", "mbpp", "multi-agent"] diff --git a/data/README.md b/data/README.md deleted file mode 100644 index a62c0db..0000000 --- a/data/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# MLRL Datasets - -This directory contains dataset processing instructions for the MLRL experiments. - -## CoopHumanEval - -This dataset is based on the HumanEval benchmark and is used for the code generation task. diff --git a/data/create_che.py b/data/create_che.py index a2a1ce3..b61fafc 100644 --- a/data/create_che.py +++ b/data/create_che.py @@ -1,6 +1,6 @@ import json -import pandas as pd + from datasets import Dataset, DatasetDict, Features, Value from huggingface_hub import HfApi, login @@ -15,6 +15,10 @@ def upload_json_to_hf_dataset(json_file_path, dataset_name, hf_token=None): hf_token (str, optional): Hugging Face API token. If None, will prompt for login. """ + # ------------------------------------------------------------------ + # Load raw JSON and normalize schema + # ------------------------------------------------------------------ + # Step 1: Load the JSON file print(f"Loading JSON file from {json_file_path}...") with open(json_file_path, "r") as f: @@ -52,7 +56,9 @@ def upload_json_to_hf_dataset(json_file_path, dataset_name, hf_token=None): # Create the dataset with features dataset = Dataset.from_dict(dataset_dict, features=features) - # Step 3: Login to Hugging Face + # ------------------------------------------------------------------ + # Authenticate to Hugging Face + # ------------------------------------------------------------------ if hf_token: login(token=hf_token) else: @@ -67,7 +73,9 @@ def upload_json_to_hf_dataset(json_file_path, dataset_name, hf_token=None): df = dataset.to_pandas() dataset_size = sum(df.memory_usage(deep=True)) - # Step 6: Create README content BEFORE pushing + # ------------------------------------------------------------------ + # Build dataset card (README) with metadata + # ------------------------------------------------------------------ readme_content = f"""--- dataset_info: features: diff --git a/data/create_che_sol.py b/data/create_che_sol.py index e17b0f1..ed3fd78 100644 --- a/data/create_che_sol.py +++ b/data/create_che_sol.py @@ -1,5 +1,5 @@ import json -import os + import re from datetime import datetime from typing import Any, Dict @@ -137,18 +137,20 @@ def generate_completion( def main(): - # Load the full dataset + # ------------------------------------------------------------------ + # Load dataset + # ------------------------------------------------------------------ print("Loading CoopHumanEval dataset...") dataset_name = "LovelyBuggies/CoopHumanEval" dataset = load_dataset(dataset_name, split="test") print(f"Total examples: {len(dataset)}") - # Model names + # ------------------------------------------------------------------ + # Load models + # ------------------------------------------------------------------ aux_model_name = "LovelyBuggies/2xQwen2.5-Coder-3B-Griffin-Aux" main_model_name = "LovelyBuggies/2xQwen2.5-Coder-3B-Griffin-Main" - - # Load models print("Loading models...") agent_0, tokenizer_0 = 
load_agent_from_hf(aux_model_name) # Aux model agent_1, tokenizer_1 = load_agent_from_hf(main_model_name) # Main model @@ -157,7 +159,9 @@ def main(): print("Failed to load models. Exiting.") return - # Process all examples + # ------------------------------------------------------------------ + # Generate solutions for all examples + # ------------------------------------------------------------------ results = [] print("Generating solutions...") @@ -212,7 +216,9 @@ def main(): } results.append(result) + # ------------------------------------------------------------------ # Save results + # ------------------------------------------------------------------ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_filename = f"coophumaneval_mlrl_solutions_{timestamp}.json" diff --git a/data/create_mbpp.py b/data/create_mbpp.py new file mode 100644 index 0000000..39859a5 --- /dev/null +++ b/data/create_mbpp.py @@ -0,0 +1,200 @@ +import json + +import pandas as pd +from datasets import Dataset, DatasetDict, Features, Value +from huggingface_hub import HfApi, login + + +def upload_json_to_hf_dataset(json_file_path, dataset_name, hf_token=None): + """ + Upload a JSON file to Hugging Face as a dataset with proper test split and dataset card. + + Args: + json_file_path (str): Path to the JSON file + dataset_name (str): Name of the dataset on Hugging Face (e.g., "OpenMLRL/MBPP") + hf_token (str, optional): Hugging Face API token. If None, will prompt for login. + """ + + # Step 1: Load the JSON file + print(f"Loading JSON file from {json_file_path}...") + with open(json_file_path, "r") as f: + data = json.load(f) + + print(f"Loaded {len(data)} examples from JSON file") + + # Step 2: Convert to Hugging Face Dataset + print("Converting to Hugging Face Dataset format...") + + # Create a dictionary with lists for each field + dataset_dict = { + "task_id": [], + "prompt": [], + "test": [], + "entry_point": [], + } + + for item in data: + dataset_dict["task_id"].append(item["task_id"]) + dataset_dict["prompt"].append(item["prompt"]) + dataset_dict["test"].append(item["test"]) + dataset_dict["entry_point"].append(item["entry_point"]) + + # Define features explicitly + features = Features( + { + "task_id": Value("string"), + "prompt": Value("string"), + "test": Value("string"), + "entry_point": Value("string"), + } + ) + + # Create the dataset with features + dataset = Dataset.from_dict(dataset_dict, features=features) + + # Step 3: Login to Hugging Face + if hf_token: + login(token=hf_token) + else: + print("Please login to Hugging Face:") + login() + + # Step 4: Create a DatasetDict with a TEST split (not train) + dataset_dict = DatasetDict({"test": dataset}) + + # Step 5: Calculate dataset size + # Convert dataset to pandas to estimate size + df = dataset.to_pandas() + dataset_size = sum(df.memory_usage(deep=True)) + + # Step 6: Create README content BEFORE pushing + readme_content = f"""--- +dataset_info: + features: + - name: task_id + dtype: string + - name: prompt + dtype: string + - name: test + dtype: string + - name: entry_point + dtype: string + splits: + - name: test + num_bytes: {dataset_size} + num_examples: {len(data)} + download_size: {dataset_size} + dataset_size: {dataset_size} +configs: +- config_name: default + data_files: + - split: test + path: data/test-* +license: apache-2.0 +task_categories: +- text-generation +language: +- en +tags: +- code +- python +- programming +- code-generation +pretty_name: MBPP +size_categories: +- 1K i . 
Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, 101, 2, 3, 100, 4, 5 ], 7, 4, 6) == 11\n assert candidate([1, 101, 2, 3, 100, 4, 5 ], 7, 2, 5) == 7\n assert candidate([11, 15, 19, 21, 26, 28, 31], 7, 2, 4) == 71", + "entry_point": "max_sum_increasing_subseq" + }, + { + "task_id": "MBPP/167", + "prompt": "def colon_tuplex(tuplex,m,n):\n '''Write a function to get a colon of a tuple. Return the modified tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate((\"HELLO\", 5, [], True) ,2,50)==(\"HELLO\", 5, [50], True)\n assert candidate((\"HELLO\", 5, [], True) ,2,100)==((\"HELLO\", 5, [100],True))\n assert candidate((\"HELLO\", 5, [], True) ,2,500)==(\"HELLO\", 5, [500], True)", + "entry_point": "colon_tuplex" + }, + { + "task_id": "MBPP/168", + "prompt": "def large_product(nums1, nums2, N):\n '''Write a function to find the specified number of largest products from two given lists, selecting one factor from each list. Return the list of largest products.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, 2, 3, 4, 5, 6],[3, 6, 8, 9, 10, 6],3)==[60, 54, 50]\n assert candidate([1, 2, 3, 4, 5, 6],[3, 6, 8, 9, 10, 6],4)==[60, 54, 50, 48]\n assert candidate([1, 2, 3, 4, 5, 6],[3, 6, 8, 9, 10, 6],5)==[60, 54, 50, 48, 45]", + "entry_point": "large_product" + }, + { + "task_id": "MBPP/169", + "prompt": "def maximum(a,b):\n '''Write a python function to find the maximum of two numbers. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(5,10) == 10\n assert candidate(-1,-2) == -1\n assert candidate(9,7) == 9", + "entry_point": "maximum" + }, + { + "task_id": "MBPP/170", + "prompt": "def string_to_tuple(str1):\n '''Write a function to convert a given string to a tuple of characters. Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"python 3.0\")==('p', 'y', 't', 'h', 'o', 'n', '3', '.', '0')\n assert candidate(\"item1\")==('i', 't', 'e', 'm', '1')\n assert candidate(\"15.10\")==('1', '5', '.', '1', '0')", + "entry_point": "string_to_tuple" + }, + { + "task_id": "MBPP/171", + "prompt": "def set_left_most_unset_bit(n):\n '''Write a python function to set the left most unset bit. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(10) == 14\n assert candidate(12) == 14\n assert candidate(15) == 15", + "entry_point": "set_left_most_unset_bit" + }, + { + "task_id": "MBPP/172", + "prompt": "def volume_cone(r,h):\n '''Write a function to find the volume of a cone. Return the calculated number.\n '''\n", + "test": "def check(candidate):\n assert math.candidate(candidate(5,12), 314.15926535897927, rel_tol=0.001)\n assert math.candidate(candidate(10,15), 1570.7963267948965, rel_tol=0.001)\n assert math.candidate(candidate(19,17), 6426.651371693521, rel_tol=0.001)", + "entry_point": "volume_cone" + }, + { + "task_id": "MBPP/173", + "prompt": "def highest_Power_of_2(n):\n '''Write a python function to find the highest power of 2 that is less than or equal to n. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(10) == 8\n assert candidate(19) == 16\n assert candidate(32) == 32", + "entry_point": "highest_Power_of_2" + }, + { + "task_id": "MBPP/174", + "prompt": "def find_lucas(n):\n '''Write a function to find the n'th lucas number. 
Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(9) == 76\n assert candidate(4) == 7\n assert candidate(3) == 4", + "entry_point": "find_lucas" + }, + { + "task_id": "MBPP/175", + "prompt": "def add_string(list_, string):\n '''Write a function to apply a given format string to all of the elements in a list. Return the modified list.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,4],'temp{0}')==['temp1', 'temp2', 'temp3', 'temp4']\n assert candidate(['a','b','c','d'], 'python{0}')==[ 'pythona', 'pythonb', 'pythonc', 'pythond']\n assert candidate([5,6,7,8],'string{0}')==['string5', 'string6', 'string7', 'string8']", + "entry_point": "add_string" + }, + { + "task_id": "MBPP/176", + "prompt": "def convert_list_dictionary(l1, l2, l3):\n '''Write a function to convert more than one list to nested dictionary. Return the list of nested dictionaries.\n '''\n", + "test": "def check(candidate):\n assert candidate([\"S001\", \"S002\", \"S003\", \"S004\"],[\"Adina Park\", \"Leyton Marsh\", \"Duncan Boyle\", \"Saim Richards\"] ,[85, 98, 89, 92])==[{'S001': {'Adina Park': 85}}, {'S002': {'Leyton Marsh': 98}}, {'S003': {'Duncan Boyle': 89}}, {'S004': {'Saim Richards': 92}}]\n assert candidate([\"abc\",\"def\",\"ghi\",\"jkl\"],[\"python\",\"program\",\"language\",\"programs\"],[100,200,300,400])==[{'abc':{'python':100}},{'def':{'program':200}},{'ghi':{'language':300}},{'jkl':{'programs':400}}]\n assert candidate([\"A1\",\"A2\",\"A3\",\"A4\"],[\"java\",\"C\",\"C++\",\"DBMS\"],[10,20,30,40])==[{'A1':{'java':10}},{'A2':{'C':20}},{'A3':{'C++':30}},{'A4':{'DBMS':40}}]", + "entry_point": "convert_list_dictionary" + }, + { + "task_id": "MBPP/177", + "prompt": "def get_max_sum(n):\n '''Write a function to find the maximum sum possible by using the given equation f(n) = max( (f(n/2) + f(n/3) + f(n/4) + f(n/5)), n). Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(60) == 106\n assert candidate(10) == 12\n assert candidate(2) == 2", + "entry_point": "get_max_sum" + }, + { + "task_id": "MBPP/178", + "prompt": "def max_length_list(input_list):\n '''Write a function to find the list with maximum length. Return the maximum length and list as a tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate([[0], [1, 3], [5, 7], [9, 11], [13, 15, 17]])==(3, [13, 15, 17])\n assert candidate([[1,2,3,4,5],[1,2,3,4],[1,2,3],[1,2],[1]])==(5,[1,2,3,4,5])\n assert candidate([[3,4,5],[6,7,8,9],[10,11,12]])==(4,[6,7,8,9])", + "entry_point": "max_length_list" + }, + { + "task_id": "MBPP/179", + "prompt": "def check_distinct(test_tup):\n '''Write a function to check if given tuple contains no duplicates. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate((1, 4, 5, 6, 1, 4)) == False\n assert candidate((1, 4, 5, 6)) == True\n assert candidate((2, 3, 4, 5, 6)) == True", + "entry_point": "check_distinct" + }, + { + "task_id": "MBPP/180", + "prompt": "def first_non_repeating_character(str1):\n '''Write a python function to find the first non-repeated character in a given string. 
Return the first non-repeated character or None.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"abcabc\") == None\n assert candidate(\"abc\") == \"a\"\n assert candidate(\"ababc\") == \"c\"", + "entry_point": "first_non_repeating_character" + }, + { + "task_id": "MBPP/181", + "prompt": "def check_char(string):\n '''Write a function to check whether the given string starts and ends with the same character or not. Return Valid or Invalid.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"abba\") == \"Valid\"\n assert candidate(\"a\") == \"Valid\"\n assert candidate(\"abcd\") == \"Invalid\"", + "entry_point": "check_char" + }, + { + "task_id": "MBPP/182", + "prompt": "def median_numbers(a,b,c):\n '''Write a function to find the median of three numbers. Return the calculated number.\n '''\n", + "test": "def check(candidate):\n assert candidate(25,55,65)==55.0\n assert candidate(20,10,30)==20.0\n assert candidate(15,45,75)==45.0", + "entry_point": "median_numbers" + }, + { + "task_id": "MBPP/183", + "prompt": "def sum_of_digits(nums):\n '''Write a function to compute the sum of digits of each number of a given list. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([10,2,56])==14\n assert candidate([[10,20,4,5,'b',70,'a']])==19\n assert candidate([10,20,-4,5,-70])==19", + "entry_point": "sum_of_digits" + }, + { + "task_id": "MBPP/184", + "prompt": "def bitwise_xor(test_tup1, test_tup2):\n '''Write a function to perform the mathematical bitwise xor operation across the given tuples. Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate((10, 4, 6, 9), (5, 2, 3, 3)) == (15, 6, 5, 10)\n assert candidate((11, 5, 7, 10), (6, 3, 4, 4)) == (13, 6, 3, 14)\n assert candidate((12, 6, 8, 11), (7, 4, 5, 6)) == (11, 2, 13, 13)", + "entry_point": "bitwise_xor" + }, + { + "task_id": "MBPP/185", + "prompt": "def extract_freq(test_list):\n '''Write a function to extract the number of unique tuples in the given list. Return the count as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([(3, 4), (1, 2), (4, 3), (5, 6)] ) == 3\n assert candidate([(4, 15), (2, 3), (5, 4), (6, 7)] ) == 4\n assert candidate([(5, 16), (2, 3), (6, 5), (6, 9)] ) == 4", + "entry_point": "extract_freq" + }, + { + "task_id": "MBPP/186", + "prompt": "def add_nested_tuples(test_tup1, test_tup2):\n '''Write a function to perform index wise addition of tuple elements in the given two nested tuples. Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate(((1, 3), (4, 5), (2, 9), (1, 10)), ((6, 7), (3, 9), (1, 1), (7, 3))) == ((7, 10), (7, 14), (3, 10), (8, 13))\n assert candidate(((2, 4), (5, 6), (3, 10), (2, 11)), ((7, 8), (4, 10), (2, 2), (8, 4))) == ((9, 12), (9, 16), (5, 12), (10, 15))\n assert candidate(((3, 5), (6, 7), (4, 11), (3, 12)), ((8, 9), (5, 11), (3, 3), (9, 5))) == ((11, 14), (11, 18), (7, 14), (12, 17))", + "entry_point": "add_nested_tuples" + }, + { + "task_id": "MBPP/187", + "prompt": "def minimum(a,b):\n '''Write a python function to find the minimum of two numbers. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(1,2) == 1\n assert candidate(-5,-4) == -5\n assert candidate(0,0) == 0", + "entry_point": "minimum" + }, + { + "task_id": "MBPP/188", + "prompt": "def check_tuplex(tuplex,tuple1):\n '''Write a function to check whether an element exists within a tuple. 
Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate((\"w\", 3, \"r\", \"e\", \"s\", \"o\", \"u\", \"r\", \"c\", \"e\"),'r')==True\n assert candidate((\"w\", 3, \"r\", \"e\", \"s\", \"o\", \"u\", \"r\", \"c\", \"e\"),'5')==False\n assert candidate((\"w\", 3, \"r\", \"e\", \"s\", \"o\", \"u\", \"r\", \"c\",\"e\"),3)==True", + "entry_point": "check_tuplex" + }, + { + "task_id": "MBPP/189", + "prompt": "def find_Parity(x):\n '''Write a python function to find whether the parity of a given number is odd. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate(12) == False\n assert candidate(7) == True\n assert candidate(10) == False", + "entry_point": "find_Parity" + }, + { + "task_id": "MBPP/190", + "prompt": "def rearrange_bigger(n):\n '''Write a function to create the next bigger number by rearranging the digits of a given number. Return the rearranged number or False.\n '''\n", + "test": "def check(candidate):\n assert candidate(12)==21\n assert candidate(10)==False\n assert candidate(102)==120", + "entry_point": "rearrange_bigger" + }, + { + "task_id": "MBPP/191", + "prompt": "def k_smallest_pairs(nums1, nums2, k):\n '''Write a function to find k number of smallest pairs which consist of one element from the first array and one element from the second array. Return the list of smallest pairs.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,3,7],[2,4,6],2)==[[1, 2], [1, 4]]\n assert candidate([1,3,7],[2,4,6],1)==[[1, 2]]\n assert candidate([1,3,7],[2,4,6],7)==[[1, 2], [1, 4], [3, 2], [1, 6], [3, 4], [3, 6], [7, 2]]", + "entry_point": "k_smallest_pairs" + }, + { + "task_id": "MBPP/192", + "prompt": "def min_product_tuple(list1):\n '''Write a function to find the minimum product from the pairs of tuples within a given list. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([(2, 7), (2, 6), (1, 8), (4, 9)] )==8\n assert candidate([(10,20), (15,2), (5,10)] )==30\n assert candidate([(11,44), (10,15), (20,5), (12, 9)] )==100", + "entry_point": "min_product_tuple" + }, + { + "task_id": "MBPP/193", + "prompt": "def min_val(listval):\n '''Write a function to find the minimum value in a given heterogeneous list. Return the minimum value.\n '''\n", + "test": "def check(candidate):\n assert candidate(['Python', 3, 2, 4, 5, 'version'])==2\n assert candidate(['Python', 15, 20, 25])==15\n assert candidate(['Python', 30, 20, 40, 50, 'version'])==20", + "entry_point": "min_val" + }, + { + "task_id": "MBPP/194", + "prompt": "def snake_to_camel_1(word):\n '''Write a function to convert the given snake case string to camel case string. Return the converted string.\n '''\n", + "test": "def check(candidate):\n assert candidate('android_tv') == 'AndroidTv'\n assert candidate('google_pixel') == 'GooglePixel'\n assert candidate('apple_watch') == 'AppleWatch'", + "entry_point": "snake_to_camel_1" + }, + { + "task_id": "MBPP/195", + "prompt": "def remove_odd(l):\n '''Write a python function to remove odd numbers from a given list. Return the modified list.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3]) == [2]\n assert candidate([2,4,6]) == [2,4,6]\n assert candidate([10,20,3]) == [10,20]", + "entry_point": "remove_odd" + }, + { + "task_id": "MBPP/196", + "prompt": "def extract_nth_element(list1, n):\n '''Write a function to extract the nth element from a given list of tuples. 
Return the list of nth elements.\n '''\n", + "test": "def check(candidate):\n assert candidate([('Greyson Fulton', 98, 99), ('Brady Kent', 97, 96), ('Wyatt Knott', 91, 94), ('Beau Turnbull', 94, 98)] ,0)==['Greyson Fulton', 'Brady Kent', 'Wyatt Knott', 'Beau Turnbull']\n assert candidate([('Greyson Fulton', 98, 99), ('Brady Kent', 97, 96), ('Wyatt Knott', 91, 94), ('Beau Turnbull', 94, 98)] ,2)==[99, 96, 94, 98]\n assert candidate([('Greyson Fulton', 98, 99), ('Brady Kent', 97, 96), ('Wyatt Knott', 91, 94), ('Beau Turnbull', 94, 98)],1)==[98, 97, 91, 94]", + "entry_point": "extract_nth_element" + }, + { + "task_id": "MBPP/197", + "prompt": "def overlapping(list1,list2):\n '''Write a python function to check whether any value in a sequence exists in a sequence or not. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,4,5],[6,7,8,9]) == False\n assert candidate([1,2,3],[4,5,6]) == False\n assert candidate([1,4,5],[1,4,5]) == True", + "entry_point": "overlapping" + }, + { + "task_id": "MBPP/198", + "prompt": "def max_Product(arr):\n '''Write a python function to find a pair with highest product from a given array of integers. Return the pair as a tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,4,7,0,8,4]) == (7,8)\n assert candidate([0,-1,-2,-4,5,0,-6]) == (-4,-6)\n assert candidate([1,2,3]) == (2,3)", + "entry_point": "max_Product" + }, + { + "task_id": "MBPP/199", + "prompt": "def group_tuples(Input):\n '''Write a function to find common first element in given list of tuple. Return the grouped list of tuples.\n '''\n", + "test": "def check(candidate):\n assert candidate([('x', 'y'), ('x', 'z'), ('w', 't')]) == [('x', 'y', 'z'), ('w', 't')]\n assert candidate([('a', 'b'), ('a', 'c'), ('d', 'e')]) == [('a', 'b', 'c'), ('d', 'e')]\n assert candidate([('f', 'g'), ('f', 'g'), ('h', 'i')]) == [('f', 'g', 'g'), ('h', 'i')]", + "entry_point": "group_tuples" + }, + { + "task_id": "MBPP/200", + "prompt": "def Find_Max(lst):\n '''Write a python function to find the element of a list having maximum length. Return the element with maximum length.\n '''\n", + "test": "def check(candidate):\n assert candidate([['A'],['A','B'],['A','B','C']]) == ['A','B','C']\n assert candidate([[1],[1,2],[1,2,3]]) == [1,2,3]\n assert candidate([[1,1],[1,2,3],[1,5,6,1]]) == [1,5,6,1]", + "entry_point": "Find_Max" + }, + { + "task_id": "MBPP/201", + "prompt": "def round_and_sum(list1):\n '''Write a function to round every number of a given list of numbers and print the total sum multiplied by the length of the list. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([22.4, 4.0, -16.22, -9.10, 11.00, -12.22, 14.20, -5.20, 17.50])==243\n assert candidate([5,2,9,24.3,29])==345\n assert candidate([25.0,56.7,89.2])==513", + "entry_point": "round_and_sum" + }, + { + "task_id": "MBPP/202", + "prompt": "def cube_Sum(n):\n '''Write a python function to find the cube sum of first n even natural numbers. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate(2) == 72\n assert candidate(3) == 288\n assert candidate(4) == 800", + "entry_point": "cube_Sum" + }, + { + "task_id": "MBPP/203", + "prompt": "def concatenate_tuple(test_tup):\n '''Write a function to concatenate each element of tuple by the delimiter. 
Return the resulting string.\n '''\n", + "test": "def check(candidate):\n assert candidate((\"ID\", \"is\", 4, \"UTS\") ) == 'ID-is-4-UTS'\n assert candidate((\"QWE\", \"is\", 4, \"RTY\") ) == 'QWE-is-4-RTY'\n assert candidate((\"ZEN\", \"is\", 4, \"OP\") ) == 'ZEN-is-4-OP'", + "entry_point": "concatenate_tuple" + }, + { + "task_id": "MBPP/204", + "prompt": "def find_Average_Of_Cube(n):\n '''Write a python function to find the average of cubes of first n natural numbers. Return the calculated number.\n '''\n", + "test": "def check(candidate):\n assert candidate(2) == 4.5\n assert candidate(3) == 12\n assert candidate(1) == 1", + "entry_point": "find_Average_Of_Cube" + }, + { + "task_id": "MBPP/205", + "prompt": "def extract_rear(test_tuple):\n '''Write a function to extract only the rear index element of each string in the given tuple. Return the list of rear elements.\n '''\n", + "test": "def check(candidate):\n assert candidate(('Mers', 'for', 'Vers') ) == ['s', 'r', 's']\n assert candidate(('Avenge', 'for', 'People') ) == ['e', 'r', 'e']\n assert candidate(('Gotta', 'get', 'go') ) == ['a', 't', 'o']", + "entry_point": "extract_rear" + }, + { + "task_id": "MBPP/206", + "prompt": "def count_element_in_list(list1, x):\n '''Write a function to count the number of sublists containing a particular element. Return the count as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([[1, 3], [5, 7], [1, 11], [1, 15, 7]],1)==3\n assert candidate([['A', 'B'], ['A', 'C'], ['A', 'D', 'E'], ['B', 'C', 'D']],'A')==3\n assert candidate([['A', 'B'], ['A', 'C'], ['A', 'D', 'E'], ['B', 'C', 'D']],'E')==1", + "entry_point": "count_element_in_list" + }, + { + "task_id": "MBPP/207", + "prompt": "def filter_oddnumbers(nums):\n '''Write a function to filter odd numbers. Return the list of odd numbers.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])==[1,3,5,7,9]\n assert candidate([10,20,45,67,84,93])==[45,67,93]\n assert candidate([5,7,9,8,6,4,3])==[5,7,9,3]", + "entry_point": "filter_oddnumbers" + }, + { + "task_id": "MBPP/208", + "prompt": "def change_date_format(dt):\n '''Write a function to convert a date of yyyy-mm-dd format to dd-mm-yyyy format. Return the converted string.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"2026-01-02\") == '02-01-2026'\n assert candidate(\"2020-11-13\") == '13-11-2020'\n assert candidate(\"2021-04-26\") == '26-04-2021'", + "entry_point": "change_date_format" + }, + { + "task_id": "MBPP/209", + "prompt": "def shell_sort(my_list):\n '''Write a function to sort the given array by using shell sort. Return the sorted list.\n '''\n", + "test": "def check(candidate):\n assert candidate([12, 23, 4, 5, 3, 2, 12, 81, 56, 95]) == [2, 3, 4, 5, 12, 12, 23, 56, 81, 95]\n assert candidate([24, 22, 39, 34, 87, 73, 68]) == [22, 24, 34, 39, 68, 73, 87]\n assert candidate([32, 30, 16, 96, 82, 83, 74]) == [16, 30, 32, 74, 82, 83, 96]", + "entry_point": "shell_sort" + }, + { + "task_id": "MBPP/210", + "prompt": "def and_tuples(test_tup1, test_tup2):\n '''Write a function to extract the elementwise and tuples from the given two tuples. 
Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate((10, 4, 6, 9), (5, 2, 3, 3)) == (0, 0, 2, 1)\n assert candidate((1, 2, 3, 4), (5, 6, 7, 8)) == (1, 2, 3, 0)\n assert candidate((8, 9, 11, 12), (7, 13, 14, 17)) == (0, 9, 10, 0)", + "entry_point": "and_tuples" + }, + { + "task_id": "MBPP/211", + "prompt": "def parabola_directrix(a, b, c):\n '''Write a function to find the directrix of a parabola. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(5,3,2)==-198\n assert candidate(9,8,4)==-2336\n assert candidate(2,4,6)==-130", + "entry_point": "parabola_directrix" + }, + { + "task_id": "MBPP/212", + "prompt": "def common_element(list1, list2):\n '''Write a function that takes two lists and returns true if they have at least one common element. Return True if the condition is met, None otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,4,5], [5,6,7,8,9])==True\n assert candidate([1,2,3,4,5], [6,7,8,9])==None\n assert candidate(['a','b','c'], ['d','b','e'])==True", + "entry_point": "common_element" + }, + { + "task_id": "MBPP/213", + "prompt": "def median_trapezium(base1,base2,height):\n '''Write a function to find the median length of a trapezium. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(15,25,35)==20\n assert candidate(10,20,30)==15\n assert candidate(6,9,4)==7.5", + "entry_point": "median_trapezium" + }, + { + "task_id": "MBPP/214", + "prompt": "def check_greater(arr, number):\n '''Write a function to check whether the entered number is greater than the elements of the given array. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, 2, 3, 4, 5], 4) == False\n assert candidate([2, 3, 4, 5, 6], 8) == True\n assert candidate([9, 7, 4, 8, 6, 1], 11) == True", + "entry_point": "check_greater" + }, + { + "task_id": "MBPP/215", + "prompt": "def text_match_one(text):\n '''Write a function that matches a string that has an a followed by one or more b's. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"ac\")==False\n assert candidate(\"dc\")==False\n assert candidate(\"abba\")==True", + "entry_point": "text_match_one" + }, + { + "task_id": "MBPP/216", + "prompt": "def last_Digit(n):\n '''Write a python function to find the last digit of a given number. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(123) == 3\n assert candidate(25) == 5\n assert candidate(30) == 0", + "entry_point": "last_Digit" + }, + { + "task_id": "MBPP/217", + "prompt": "def neg_nos(list1):\n '''Write a python function to return the negative numbers in a list. Return the list of negative numbers.\n '''\n", + "test": "def check(candidate):\n assert candidate([-1,4,5,-6]) == [-1,-6]\n assert candidate([-1,-2,3,4]) == [-1,-2]\n assert candidate([-7,-6,8,9]) == [-7,-6]", + "entry_point": "neg_nos" + }, + { + "task_id": "MBPP/218", + "prompt": "def remove_odd_1(str1):\n '''Write a function to remove odd characters in a string. 
Return the modified string.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"python\")==(\"yhn\")\n assert candidate(\"program\")==(\"rga\")\n assert candidate(\"language\")==(\"agae\")", + "entry_point": "remove_odd_1" + }, + { + "task_id": "MBPP/219", + "prompt": "def count_bidirectional(test_list):\n '''Write a function to count bidirectional tuple pairs. Return the count as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([(5, 6), (1, 2), (6, 5), (9, 1), (6, 5), (2, 1)] ) == 3\n assert candidate([(5, 6), (1, 3), (6, 5), (9, 1), (6, 5), (2, 1)] ) == 2\n assert candidate([(5, 6), (1, 2), (6, 5), (9, 2), (6, 5), (2, 1)] ) == 4", + "entry_point": "count_bidirectional" + }, + { + "task_id": "MBPP/220", + "prompt": "def multiple_to_single(L):\n '''Write a function to join a list of multiple integers into a single integer. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([11, 33, 50])==113350\n assert candidate([-1,2,3,4,5,6])==-123456\n assert candidate([10,15,20,25])==10152025", + "entry_point": "multiple_to_single" + }, + { + "task_id": "MBPP/221", + "prompt": "def find_adverb_position(text):\n '''Write a function to find the first adverb and their positions in a given sentence. Return the position tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"clearly!! we can see the sky\")==(0, 7, 'clearly')\n assert candidate(\"seriously!! there are many roses\")==(0, 9, 'seriously')\n assert candidate(\"unfortunately!! sita is going to home\")==(0, 13, 'unfortunately')", + "entry_point": "find_adverb_position" + }, + { + "task_id": "MBPP/222", + "prompt": "def surfacearea_cube(l):\n '''Write a function to find the surface area of a cube of a given size. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(5)==150\n assert candidate(3)==54\n assert candidate(10)==600", + "entry_point": "surfacearea_cube" + }, + { + "task_id": "MBPP/223", + "prompt": "def positive_count(nums):\n '''Write a function to find the ration of positive numbers in an array of integers. Return the calculated number.\n '''\n", + "test": "def check(candidate):\n assert candidate([0, 1, 2, -1, -5, 6, 0, -3, -2, 3, 4, 6, 8])==0.54\n assert candidate([2, 1, 2, -1, -5, 6, 4, -3, -2, 3, 4, 6, 8])==0.69\n assert candidate([2, 4, -6, -9, 11, -12, 14, -5, 17])==0.56", + "entry_point": "positive_count" + }, + { + "task_id": "MBPP/224", + "prompt": "def largest_neg(list1):\n '''Write a python function to find the largest negative number from the given list. Return the largest negative number.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,-4,-6]) == -6\n assert candidate([1,2,3,-8,-9]) == -9\n assert candidate([1,2,3,4,-1]) == -1", + "entry_point": "largest_neg" + }, + { + "task_id": "MBPP/225", + "prompt": "def trim_tuple(test_list, K):\n '''Write a function to trim each tuple by k in the given tuple list. 
Return the string representation of the trimmed tuples.\n '''\n", + "test": "def check(candidate):\n assert candidate([(5, 3, 2, 1, 4), (3, 4, 9, 2, 1),(9, 1, 2, 3, 5), (4, 8, 2, 1, 7)], 2) == '[(2,), (9,), (2,), (2,)]'\n assert candidate([(5, 3, 2, 1, 4), (3, 4, 9, 2, 1), (9, 1, 2, 3, 5), (4, 8, 2, 1, 7)], 1) == '[(3, 2, 1), (4, 9, 2), (1, 2, 3), (8, 2, 1)]'\n assert candidate([(7, 8, 4, 9), (11, 8, 12, 4),(4, 1, 7, 8), (3, 6, 9, 7)], 1) == '[(8, 4), (8, 12), (1, 7), (6, 9)]'", + "entry_point": "trim_tuple" + }, + { + "task_id": "MBPP/226", + "prompt": "def index_multiplication(test_tup1, test_tup2):\n '''Write a function to perform index wise multiplication of tuple elements in the given two tuples. Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate(((1, 3), (4, 5), (2, 9), (1, 10)),((6, 7), (3, 9), (1, 1), (7, 3)) ) == ((6, 21), (12, 45), (2, 9), (7, 30))\n assert candidate(((2, 4), (5, 6), (3, 10), (2, 11)),((7, 8), (4, 10), (2, 2), (8, 4)) ) == ((14, 32), (20, 60), (6, 20), (16, 44))\n assert candidate(((3, 5), (6, 7), (4, 11), (3, 12)),((8, 9), (5, 11), (3, 3), (9, 5)) ) == ((24, 45), (30, 77), (12, 33), (27, 60))", + "entry_point": "index_multiplication" + }, + { + "task_id": "MBPP/227", + "prompt": "def count_Occurrence(tup, lst):\n '''Write a python function to count the occurence of all elements of list in a tuple. Return the count as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate(('a', 'a', 'c', 'b', 'd'),['a', 'b'] ) == 3\n assert candidate((1, 2, 3, 1, 4, 6, 7, 1, 4),[1, 4, 7]) == 6\n assert candidate((1,2,3,4,5,6),[1,2]) == 2", + "entry_point": "count_Occurrence" + }, + { + "task_id": "MBPP/228", + "prompt": "def cube_nums(nums):\n '''Write a function to find cubes of individual elements in a list. Return the list of cubes.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])==[1, 8, 27, 64, 125, 216, 343, 512, 729, 1000]\n assert candidate([10,20,30])==([1000, 8000, 27000])\n assert candidate([12,15])==([1728, 3375])", + "entry_point": "cube_nums" + }, + { + "task_id": "MBPP/229", + "prompt": "def cal_sum(n):\n '''Write a function to calculate the sum of perrin numbers. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate(9) == 49\n assert candidate(10) == 66\n assert candidate(11) == 88", + "entry_point": "cal_sum" + }, + { + "task_id": "MBPP/230", + "prompt": "def extract_string(str, l):\n '''Write a function to extract specified size of strings from a given list of string values. Return the list of extracted strings.\n '''\n", + "test": "def check(candidate):\n assert candidate(['Python', 'list', 'exercises', 'practice', 'solution'] ,8)==['practice', 'solution']\n assert candidate(['Python', 'list', 'exercises', 'practice', 'solution'] ,6)==['Python']\n assert candidate(['Python', 'list', 'exercises', 'practice', 'solution'] ,9)==['exercises']", + "entry_point": "extract_string" + }, + { + "task_id": "MBPP/231", + "prompt": "def remove_whitespaces(text1):\n '''Write a function to remove all whitespaces from the given string. 
Return the modified string.\n '''\n", + "test": "def check(candidate):\n assert candidate(' Google Flutter ') == 'GoogleFlutter'\n assert candidate(' Google Dart ') == 'GoogleDart'\n assert candidate(' iOS Swift ') == 'iOSSwift'", + "entry_point": "remove_whitespaces" + }, + { + "task_id": "MBPP/232", + "prompt": "def loss_amount(actual_cost,sale_amount):\n '''Write a function that gives loss amount on a sale if the given amount has loss else return 0. Return the loss amount as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate(1500,1200)==0\n assert candidate(100,200)==100\n assert candidate(2000,5000)==3000", + "entry_point": "loss_amount" + }, + { + "task_id": "MBPP/233", + "prompt": "def sumofFactors(n):\n '''Write a python function to find the sum of even factors of a number. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate(18) == 26\n assert candidate(30) == 48\n assert candidate(6) == 8", + "entry_point": "sumofFactors" + }, + { + "task_id": "MBPP/234", + "prompt": "def text_match_wordz(text):\n '''Write a function that matches a word containing 'z'. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"pythonz.\")==True\n assert candidate(\"xyz.\")==True\n assert candidate(\" lang .\")==False", + "entry_point": "text_match_wordz" + }, + { + "task_id": "MBPP/235", + "prompt": "def check_monthnumb_number(monthnum2):\n '''Write a function to check whether the given month number contains 31 days or not. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate(5)==True\n assert candidate(2)==False\n assert candidate(6)==False", + "entry_point": "check_monthnumb_number" + }, + { + "task_id": "MBPP/236", + "prompt": "def reverse_string_list(stringlist):\n '''Write a function to reverse each string in a given list of string values. Return the list of reversed strings.\n '''\n", + "test": "def check(candidate):\n assert candidate(['Red', 'Green', 'Blue', 'White', 'Black'])==['deR', 'neerG', 'eulB', 'etihW', 'kcalB']\n assert candidate(['john','amal','joel','george'])==['nhoj','lama','leoj','egroeg']\n assert candidate(['jack','john','mary'])==['kcaj','nhoj','yram']", + "entry_point": "reverse_string_list" + }, + { + "task_id": "MBPP/237", + "prompt": "def Find_Min(lst):\n '''Write a python function to find the sublist having minimum length. Return the sublist with minimum length.\n '''\n", + "test": "def check(candidate):\n assert candidate([[1],[1,2],[1,2,3]]) == [1]\n assert candidate([[1,1],[1,1,1],[1,2,7,8]]) == [1,1]\n assert candidate([['x'],['x','y'],['x','y','z']]) == ['x']", + "entry_point": "Find_Min" + }, + { + "task_id": "MBPP/238", + "prompt": "def rectangle_area(l,b):\n '''Write a function to find the area of a rectangle. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(10,20)==200\n assert candidate(10,5)==50\n assert candidate(4,2)==8", + "entry_point": "rectangle_area" + }, + { + "task_id": "MBPP/239", + "prompt": "def remove_uppercase(str1):\n '''Write a function to remove uppercase substrings from a given string. 
Return the modified string.\n '''\n", + "test": "def check(candidate):\n assert candidate('cAstyoUrFavoRitETVshoWs') == 'cstyoravoitshos'\n assert candidate('wAtchTheinTernEtrAdIo') == 'wtchheinerntrdo'\n assert candidate('VoicESeaRchAndreComMendaTionS') == 'oiceachndreomendaion'", + "entry_point": "remove_uppercase" + }, + { + "task_id": "MBPP/240", + "prompt": "def Extract(lst):\n '''Write a python function to get the first element of each sublist. Return the list of first elements.\n '''\n", + "test": "def check(candidate):\n assert candidate([[1, 2], [3, 4, 5], [6, 7, 8, 9]]) == [1, 3, 6]\n assert candidate([[1,2,3],[4, 5]]) == [1,4]\n assert candidate([[9,8,1],[1,2]]) == [9,1]", + "entry_point": "Extract" + }, + { + "task_id": "MBPP/241", + "prompt": "def upper_ctr(str):\n '''Write a python function to count the upper case characters in a given string. Return the count as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate('PYthon') == 1\n assert candidate('BigData') == 1\n assert candidate('program') == 0", + "entry_point": "upper_ctr" + }, + { + "task_id": "MBPP/242", + "prompt": "def combinations_list(list1):\n '''Write a function to find all possible combinations of the elements of a given list. Return the list of all combinations.\n '''\n", + "test": "def check(candidate):\n assert candidate(['orange', 'red', 'green', 'blue'])==[[], ['orange'], ['red'], ['red', 'orange'], ['green'], ['green', 'orange'], ['green', 'red'], ['green', 'red', 'orange'], ['blue'], ['blue', 'orange'], ['blue', 'red'], ['blue', 'red', 'orange'], ['blue', 'green'], ['blue', 'green', 'orange'], ['blue', 'green', 'red'], ['blue', 'green', 'red', 'orange']]\n assert candidate(['red', 'green', 'blue', 'white', 'black', 'orange'])==[[], ['red'], ['green'], ['green', 'red'], ['blue'], ['blue', 'red'], ['blue', 'green'], ['blue', 'green', 'red'], ['white'], ['white', 'red'], ['white', 'green'], ['white', 'green', 'red'], ['white', 'blue'], ['white', 'blue', 'red'], ['white', 'blue', 'green'], ['white', 'blue', 'green', 'red'], ['black'], ['black', 'red'], ['black', 'green'], ['black', 'green', 'red'], ['black', 'blue'], ['black', 'blue', 'red'], ['black', 'blue', 'green'], ['black', 'blue', 'green', 'red'], ['black', 'white'], ['black', 'white', 'red'], ['black', 'white', 'green'], ['black', 'white', 'green', 'red'], ['black', 'white', 'blue'], ['black', 'white', 'blue', 'red'], ['black', 'white', 'blue', 'green'], ['black', 'white', 'blue', 'green', 'red'], ['orange'], ['orange', 'red'], ['orange', 'green'], ['orange', 'green', 'red'], ['orange', 'blue'], ['orange', 'blue', 'red'], ['orange', 'blue', 'green'], ['orange', 'blue', 'green', 'red'], ['orange', 'white'], ['orange', 'white', 'red'], ['orange', 'white', 'green'], ['orange', 'white', 'green', 'red'], ['orange', 'white', 'blue'], ['orange', 'white', 'blue', 'red'], ['orange', 'white', 'blue', 'green'], ['orange', 'white', 'blue', 'green', 'red'], ['orange', 'black'], ['orange', 'black', 'red'], ['orange', 'black', 'green'], ['orange', 'black', 'green', 'red'], ['orange', 'black', 'blue'], ['orange', 'black', 'blue', 'red'], ['orange', 'black', 'blue', 'green'], ['orange', 'black', 'blue', 'green', 'red'], ['orange', 'black', 'white'], ['orange', 'black', 'white', 'red'], ['orange', 'black', 'white', 'green'], ['orange', 'black', 'white', 'green', 'red'], ['orange', 'black', 'white', 'blue'], ['orange', 'black', 'white', 'blue', 'red'], ['orange', 'black', 'white', 'blue', 'green'], ['orange', 'black', 'white', 'blue', 'green', 
'red']]\n assert candidate(['red', 'green', 'black', 'orange'])==[[], ['red'], ['green'], ['green', 'red'], ['black'], ['black', 'red'], ['black', 'green'], ['black', 'green', 'red'], ['orange'], ['orange', 'red'], ['orange', 'green'], ['orange', 'green', 'red'], ['orange', 'black'], ['orange', 'black', 'red'], ['orange', 'black', 'green'], ['orange', 'black', 'green', 'red']]", + "entry_point": "combinations_list" + }, + { + "task_id": "MBPP/243", + "prompt": "def max_subarray_product(arr):\n '''Write a function to find the maximum product subarray of the given array. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([1, -2, -3, 0, 7, -8, -2]) == 112\n assert candidate([6, -3, -10, 0, 2]) == 180\n assert candidate([-2, -40, 0, -2, -3]) == 80", + "entry_point": "max_subarray_product" + }, + { + "task_id": "MBPP/244", + "prompt": "def check_value(dict, n):\n '''Write a function to check if all values are same in a dictionary. Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate({'Cierra Vega': 12, 'Alden Cantrell': 12, 'Kierra Gentry': 12, 'Pierre Cox': 12},10)==False\n assert candidate({'Cierra Vega': 12, 'Alden Cantrell': 12, 'Kierra Gentry': 12, 'Pierre Cox': 12},12)==True\n assert candidate({'Cierra Vega': 12, 'Alden Cantrell': 12, 'Kierra Gentry': 12, 'Pierre Cox': 12},5)==False", + "entry_point": "check_value" + }, + { + "task_id": "MBPP/245", + "prompt": "def drop_empty(dict1):\n '''Write a function to drop empty items from a given dictionary. Return the filtered dictionary.\n '''\n", + "test": "def check(candidate):\n assert candidate({'c1': 'Red', 'c2': 'Green', 'c3':None})=={'c1': 'Red', 'c2': 'Green'}\n assert candidate({'c1': 'Red', 'c2': None, 'c3':None})=={'c1': 'Red'}\n assert candidate({'c1': None, 'c2': 'Green', 'c3':None})=={ 'c2': 'Green'}", + "entry_point": "drop_empty" + }, + { + "task_id": "MBPP/246", + "prompt": "def max_product(arr):\n '''Write a function to find the maximum product formed by multiplying numbers of an increasing subsequence of that array. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([3, 100, 4, 5, 150, 6]) == 3000\n assert candidate([4, 42, 55, 68, 80]) == 50265600\n assert candidate([10, 22, 9, 33, 21, 50, 41, 60]) == 2460", + "entry_point": "max_product" + }, + { + "task_id": "MBPP/247", + "prompt": "def add_pairwise(test_tup):\n '''Write a function to find the pairwise addition of the neighboring elements of the given tuple. Return the resulting tuple.\n '''\n", + "test": "def check(candidate):\n assert candidate((1, 5, 7, 8, 10)) == (6, 12, 15, 18)\n assert candidate((2, 6, 8, 9, 11)) == (8, 14, 17, 20)\n assert candidate((3, 7, 9, 10, 12)) == (10, 16, 19, 22)", + "entry_point": "add_pairwise" + }, + { + "task_id": "MBPP/248", + "prompt": "def find_remainder(arr, n):\n '''Write a python function to find the product of the array multiplication modulo n. Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate([ 100, 10, 5, 25, 35, 14 ],11) ==9\n assert candidate([1,1,1],1) == 0\n assert candidate([1,2,1],2) == 0", + "entry_point": "find_remainder" + }, + { + "task_id": "MBPP/249", + "prompt": "def check_Consecutive(l):\n '''Write a python function to check whether the given list contains consecutive numbers or not. 
Return True if the condition is met, False otherwise.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3,4,5]) == True\n assert candidate([1,2,3,5,6]) == False\n assert candidate([1,2,1]) == False", + "entry_point": "check_Consecutive" + }, + { + "task_id": "MBPP/250", + "prompt": "def tuple_intersection(test_list1, test_list2):\n '''Write a function to find the tuple intersection of elements in the given tuple list irrespective of their order. Return the set of intersecting tuples.\n '''\n", + "test": "def check(candidate):\n assert candidate([(3, 4), (5, 6), (9, 10), (4, 5)] , [(5, 4), (3, 4), (6, 5), (9, 11)]) == {(4, 5), (3, 4), (5, 6)}\n assert candidate([(4, 1), (7, 4), (11, 13), (17, 14)] , [(1, 4), (7, 4), (16, 12), (10, 13)]) == {(4, 7), (1, 4)}\n assert candidate([(2, 1), (3, 2), (1, 3), (1, 4)] , [(11, 2), (2, 3), (6, 2), (1, 3)]) == {(1, 3), (2, 3)}", + "entry_point": "tuple_intersection" + }, + { + "task_id": "MBPP/251", + "prompt": "def replace_char(str1,ch,newch):\n '''Write a function to replace characters in a string. Return the modified string.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"polygon\",'y','l')==(\"pollgon\")\n assert candidate(\"character\",'c','a')==(\"aharaater\")\n assert candidate(\"python\",'l','a')==(\"python\")", + "entry_point": "replace_char" + }, + { + "task_id": "MBPP/252", + "prompt": "def sort_counter(dict1):\n '''Write a function to sort a dictionary by value. Return the sorted list of tuples.\n '''\n", + "test": "def check(candidate):\n assert candidate({'Math':81, 'Physics':83, 'Chemistry':87})==[('Chemistry', 87), ('Physics', 83), ('Math', 81)]\n assert candidate({'Math':400, 'Physics':300, 'Chemistry':250})==[('Math', 400), ('Physics', 300), ('Chemistry', 250)]\n assert candidate({'Math':900, 'Physics':1000, 'Chemistry':1250})==[('Chemistry', 1250), ('Physics', 1000), ('Math', 900)]", + "entry_point": "sort_counter" + }, + { + "task_id": "MBPP/253", + "prompt": "def big_sum(nums):\n '''Write a python function to find the sum of the largest and smallest value in a given array. Return the sum as an integer.\n '''\n", + "test": "def check(candidate):\n assert candidate([1,2,3]) == 4\n assert candidate([-1,2,3,4]) == 3\n assert candidate([2,3,6]) == 8", + "entry_point": "big_sum" + }, + { + "task_id": "MBPP/254", + "prompt": "def is_lower(string):\n '''Write a python function to convert the given string to lower case. Return the converted string.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"InValid\") == \"invalid\"\n assert candidate(\"TruE\") == \"true\"\n assert candidate(\"SenTenCE\") == \"sentence\"", + "entry_point": "is_lower" + }, + { + "task_id": "MBPP/255", + "prompt": "def remove_lowercase(str1):\n '''Write a function to remove lowercase substrings from a given string. Return the modified string.\n '''\n", + "test": "def check(candidate):\n assert candidate(\"PYTHon\")==('PYTH')\n assert candidate(\"FInD\")==('FID')\n assert candidate(\"STRinG\")==('STRG')", + "entry_point": "remove_lowercase" + }, + { + "task_id": "MBPP/256", + "prompt": "def first_Digit(n):\n '''Write a python function to find the first digit of a given number. 
Return the calculated integer value.\n '''\n", + "test": "def check(candidate):\n assert candidate(123) == 1\n assert candidate(456) == 4\n assert candidate(12) == 1", + "entry_point": "first_Digit" + } +] \ No newline at end of file diff --git a/loggers/code_logger.py b/loggers/code_logger.py index 3b00b28..c9e99d1 100644 --- a/loggers/code_logger.py +++ b/loggers/code_logger.py @@ -1,6 +1,6 @@ import signal -import numpy as np + from rewards.code_utils import ( TimeoutException, diff --git a/plotting/plot_che.py b/plotting/plot_che.py index d4ae17b..ce3611c 100644 --- a/plotting/plot_che.py +++ b/plotting/plot_che.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import PyPDF2 -import seaborn as sns + import wandb from PyPDF2.generic import RectangleObject @@ -12,6 +12,10 @@ def plot_combined_three_panels(): """Plot all three panels: Single Turn, MT Turn 1, and MT Turn 2 with larger fonts.""" + # ------------------------------------------------------------------ + # W&B setup and project selection + # ------------------------------------------------------------------ + # Initialize wandb API api = wandb.Api() @@ -31,7 +35,9 @@ def plot_combined_three_panels(): print(f"Found {len(humaneval_runs)} CoopHumanEval runs") print(f"Found {len(mt_expert_runs)} Multi-Turn Expert HumanEval runs") - # Process single-turn data + # ------------------------------------------------------------------ + # Fetch and clean histories + # ------------------------------------------------------------------ single_turn_data = [] for run in humaneval_runs: try: diff --git a/plotting/plot_che_return.py b/plotting/plot_che_return.py index aa865e7..43ccaf2 100644 --- a/plotting/plot_che_return.py +++ b/plotting/plot_che_return.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import PyPDF2 -import seaborn as sns + import wandb from PyPDF2.generic import RectangleObject @@ -12,6 +12,10 @@ def plot_combined_two_panels(): """Plot two panels: Single Turn and Multi-Turn Average (Turn1+Turn2)/2 with larger fonts.""" + # ------------------------------------------------------------------ + # W&B setup and project selection + # ------------------------------------------------------------------ + # Initialize wandb API api = wandb.Api() @@ -31,7 +35,9 @@ def plot_combined_two_panels(): print(f"Found {len(humaneval_runs)} CoopHumanEval runs") print(f"Found {len(mt_expert_runs)} Multi-Turn Expert HumanEval runs") - # Process single-turn data + # ------------------------------------------------------------------ + # Fetch and clean histories + # ------------------------------------------------------------------ single_turn_data = [] for run in humaneval_runs: try: @@ -145,7 +151,9 @@ def plot_combined_two_panels(): df_mt["eval/turn_1/avg_total_reward"] + df_mt["eval/turn_2/avg_total_reward"] ) / 2 - # Create uniform x-axis for single-turn + # ------------------------------------------------------------------ + # Create uniform x-axis for single-turn and multi-turn + # ------------------------------------------------------------------ single_run_lengths = df_single.groupby("run_id").size() single_min_length = single_run_lengths.min() diff --git a/plotting/plot_he.py b/plotting/plot_he.py index a88381e..1c1ab84 100644 --- a/plotting/plot_he.py +++ b/plotting/plot_he.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import PyPDF2 -import seaborn as sns + import wandb from PyPDF2.generic import RectangleObject @@ -12,6 +12,10 @@ def plot_combined_three_panels(): """Plot all three panels: Single Turn, MT Turn 
1, and MT Turn 2 with larger fonts.""" + # ------------------------------------------------------------------ + # W&B setup and project selection + # ------------------------------------------------------------------ + # Initialize wandb API api = wandb.Api() @@ -31,7 +35,9 @@ def plot_combined_three_panels(): print(f"Found {len(humaneval_runs)} HumanEval runs") print(f"Found {len(mt_expert_runs)} Multi-Turn Expert HumanEval runs") - # Process single-turn data + # ------------------------------------------------------------------ + # Fetch and clean histories + # ------------------------------------------------------------------ single_turn_data = [] for run in humaneval_runs: try: diff --git a/plotting/plot_he_return.py b/plotting/plot_he_return.py index 582f7cb..f049fff 100644 --- a/plotting/plot_he_return.py +++ b/plotting/plot_he_return.py @@ -1,10 +1,10 @@ -import time + import matplotlib.pyplot as plt import numpy as np import pandas as pd import PyPDF2 -import seaborn as sns + import wandb from PyPDF2.generic import RectangleObject @@ -12,6 +12,10 @@ def plot_combined_two_panels(plot_id=1): """Plot two panels: Single Turn and Multi-Turn Average (Turn1+Turn2)/2 with larger fonts.""" + # ------------------------------------------------------------------ + # W&B setup and project selection + # ------------------------------------------------------------------ + # Initialize wandb API api = wandb.Api() @@ -31,7 +35,9 @@ def plot_combined_two_panels(plot_id=1): print(f"Found {len(humaneval_runs)} HumanEval runs") print(f"Found {len(mt_expert_runs)} Multi-Turn Expert HumanEval runs") - # Process single-turn data + # ------------------------------------------------------------------ + # Fetch and clean histories + # ------------------------------------------------------------------ single_turn_data = [] for run in humaneval_runs: try: diff --git a/plotting/plot_improve.py b/plotting/plot_improve.py index e05f081..519b085 100644 --- a/plotting/plot_improve.py +++ b/plotting/plot_improve.py @@ -3,13 +3,17 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns + import wandb def plot_mt_humaneval_reward(): """Plot Multi-turn HumanEval rewards from WandB runs with 2 curves: Turn 1 and Turn 2.""" + # ------------------------------------------------------------------ + # W&B setup and project selection + # ------------------------------------------------------------------ + # Initialize wandb API api = wandb.Api() @@ -25,7 +29,9 @@ def plot_mt_humaneval_reward(): print(f"Found {len(mt_humaneval_runs)} Multi-turn HumanEval runs") - # Collect data from all runs + # ------------------------------------------------------------------ + # Collect and normalize run histories + # ------------------------------------------------------------------ all_data = [] for run in mt_humaneval_runs: try: @@ -77,7 +83,9 @@ def plot_mt_humaneval_reward(): min_length = run_lengths.min() print(f"Run lengths: min={min_length}, max={run_lengths.max()}") - # Create uniform x-axis mapping for each run, truncated to minimum length + # ------------------------------------------------------------------ + # Create uniform x-axis per run (truncate to min length for alignment) + # ------------------------------------------------------------------ uniform_data = [] for run_id in df["run_id"].unique(): run_data = df[df["run_id"] == run_id].copy() @@ -113,8 +121,9 @@ def plot_mt_humaneval_reward(): "turn2_std", ] - # Scale the aggregated means and stds - # Assuming total reward range is 
0-4 as in the original code + # ------------------------------------------------------------------ + # Scale aggregated means and stds (total reward assumed in [0,4]) + # ------------------------------------------------------------------ total_min, total_max = 0, 4.0 # Scale Turn 1 @@ -129,7 +138,9 @@ def plot_mt_humaneval_reward(): ) aggregated["turn2_std_scaled"] = aggregated["turn2_std"] / (total_max - total_min) - # Create plot + # ------------------------------------------------------------------ + # Plot with bounded error bands + # ------------------------------------------------------------------ plt.style.use("default") fig, ax = plt.subplots(figsize=(7, 4)) diff --git a/plotting/posg_1.py b/plotting/posg_1.py index d8e918c..4f93c47 100644 --- a/plotting/posg_1.py +++ b/plotting/posg_1.py @@ -1,4 +1,4 @@ -from itertools import product + import matplotlib.pyplot as plt import numpy as np @@ -7,7 +7,9 @@ plt.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS", "DejaVu Sans"] plt.rcParams["axes.unicode_minus"] = False -# Create figure +# ------------------------------------------------------------------ +# Figure setup +# ------------------------------------------------------------------ fig, ax = plt.subplots(1, 1, figsize=(5, 5)) # Chicken Game payoff matrix with modified values diff --git a/plotting/posg_2.py b/plotting/posg_2.py index 129a4fb..330e0bb 100644 --- a/plotting/posg_2.py +++ b/plotting/posg_2.py @@ -1,4 +1,4 @@ -from itertools import product + import matplotlib.pyplot as plt import numpy as np diff --git a/train_grpo.py b/train_grpo.py index a2376ec..dfd4ba1 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -105,7 +105,7 @@ def execution_reward_single_agent(completions, batch_items=None): return raw_rewards -## Removed dead factory create_execution_reward_function (unused) + def get_formatter(dataset_type: str): @@ -118,6 +118,7 @@ def get_formatter(dataset_type: str): formatters_map = { "humaneval": complete_function_formatter, "coophumaneval": complete_function_formatter, + "mbpp": complete_function_formatter, } return formatters_map.get(dataset_type.lower(), complete_function_formatter) @@ -129,7 +130,7 @@ def get_reward_function(dataset_type: str): "dataset.type not specified in config. Please add 'type: humaneval/coophumaneval' to the dataset section." 
) - if dataset_type.lower() in ["humaneval", "coophumaneval"]: + if dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"]: return execution_reward_single_agent else: raise ValueError(f"Unknown dataset type: {dataset_type}") @@ -144,6 +145,9 @@ def main(): args = parser.parse_args() + # ------------------------------------------------------------------ + # Config: load YAML and apply overrides + # ------------------------------------------------------------------ if args.config: config = Config(args.config) else: @@ -157,14 +161,14 @@ def main(): overrides = parse_overrides(args.override) config.update(overrides) - - - # Load model configuration + # ------------------------------------------------------------------ + # Config: model, dataset, output + # ------------------------------------------------------------------ model_config = config.get_model_config() model_name = model_config.name - output_base_dir = config.get("output.base_dir") dataset_name = config.get("dataset.name") dataset_type = config.get("dataset.type") + output_base_dir = config.get("output.base_dir") # Try to infer dataset type from dataset name if not specified if dataset_type is None: @@ -172,6 +176,8 @@ def main(): dataset_type = "humaneval" elif "coophumaneval" in dataset_name.lower() or "coop" in dataset_name.lower(): dataset_type = "coophumaneval" + elif "mbpp" in dataset_name.lower(): + dataset_type = "mbpp" else: raise ValueError( f"Could not infer dataset type from dataset name '{dataset_name}'. Please specify 'type' in dataset config." @@ -180,13 +186,12 @@ def main(): train_split = config.get("dataset.train_split") eval_split = config.get("dataset.eval_split") - # Read GRPO section early (for multi-turn flags) + # ------------------------------------------------------------------ + # Config: GRPO training params and verbosity + # ------------------------------------------------------------------ grpo_config = config.get_section("grpo") if hasattr(config, "get_section") else {} - - # Determine single vs multi-turn num_turns = grpo_config.get("num_turns", 1) is_multi_turn = num_turns > 1 - output_verbose = config.get("output.verbose", True) if output_verbose: print(f"Multi-turn GRPO enabled: num_turns={num_turns}") if is_multi_turn else print( @@ -251,7 +256,9 @@ def main(): temperature = grpo_config.get("temperature", model_config.temperature) top_p = grpo_config.get("top_p", model_config.top_p) - # External configuration (mode, sandbox, expert model, context flags) + # ------------------------------------------------------------------ + # Config: External transitions (mode, sandbox, expert model, context flags) + # ------------------------------------------------------------------ external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} # Register external context resolver using dataset items (for external modes) @@ -278,7 +285,7 @@ def _normalize_prompt(p: str) -> str: else: sandbox_slice = None if _sandbox_val is None else 0 - import re as _re + # re already imported at module level def _make_sliced_assert_tests(test_code: str, n: int) -> str: if not isinstance(test_code, str) or not test_code.strip(): @@ -289,7 +296,7 @@ def _make_sliced_assert_tests(test_code: str, n: int) -> str: preamble = [] check_idx = None for idx, line in enumerate(lines): - if _re.match(r"\s*def\s+check\s*\(candidate\)\s*:\s*", line): + if re.match(r"\s*def\s+check\s*\(candidate\)\s*:\s*", line): check_idx = idx break preamble.append(line) @@ -340,6 +347,9 @@ def _resolver(prompt: str): 
external_ctx.set_context_resolver(_resolver) + # ------------------------------------------------------------------ + # Build training args + # ------------------------------------------------------------------ grpo_args = MAGRPOConfig( output_dir=output_dir, num_train_epochs=grpo_config.get("num_train_epochs", 10), @@ -356,18 +366,22 @@ def _resolver(prompt: str): discount=grpo_config.get("discount", 0.9), joint_mode=grpo_config.get("joint_mode", "aligned"), termination_threshold=grpo_config.get("termination_threshold", None), - # GRPO-style advantage params - normalize_advantage=grpo_config.get("normalize_advantage", False), epsilon_clip=grpo_config.get("epsilon_clip", None), ) + # ------------------------------------------------------------------ + # Formatters, rewards, and logging + # ------------------------------------------------------------------ formatter = get_formatter(dataset_type) reward_func = get_reward_function(dataset_type) + # ------------------------------------------------------------------ + # W&B configuration and tags + # ------------------------------------------------------------------ wandb_section = ( config.get_section("wandb") if hasattr(config, "get_section") else {} ) - model_short_name = model_name.split("/")[-1].lower() + # Model short name no longer used in W&B naming # Use different wandb name for multi-turn if is_multi_turn: wandb_name = wandb_section.get("name", f"mt_grpo_{dataset_type}") @@ -399,7 +413,7 @@ def _resolver(prompt: str): wandb_config = { "project": wandb_section.get("project", "mlrl"), "entity": wandb_section.get("entity", "nu-llpr"), - "name": f"{wandb_name}_{model_short_name}", + "name": f"{wandb_name}", "dir": wandb_section.get("dir", "../../../projects/bepg/sliu30"), "tags": tags, # Provide full sections for the trainer to log cleanly @@ -444,18 +458,24 @@ def _resolver(prompt: str): prev = reward_processor reward_processor = (lambda p=prev, s=shift_proc: (lambda x: s(p(x))))() - # Use agents=[model] to keep dtype and loading behavior aligned with MAGRPO + # ------------------------------------------------------------------ + # Build trainer kwargs (grouped: model/data, reward/formatting, logging, args) + # ------------------------------------------------------------------ trainer_kwargs = { + # Model / data "agents": [model], "num_agents": 1, - "reward_func": reward_func, - "formatters": formatter, - "args": grpo_args, + "tokenizer": tokenizer, "train_dataset": train_dataset, "eval_dataset": eval_dataset, - "tokenizer": tokenizer, + # Reward / formatting + "reward_func": reward_func, + "formatters": formatter, + # Logging / config "wandb_config": wandb_config, "dataset_type": dataset_type, + # Training args + "args": grpo_args, } if reward_processor is not None: @@ -465,7 +485,7 @@ def _resolver(prompt: str): if ( is_multi_turn and dataset_type - and dataset_type.lower() in ["humaneval", "coophumaneval"] + and dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"] ): expert_model = external_cfg.get("expert_model", "deepseek-coder") diff --git a/train_magrpo.py b/train_magrpo.py index 47f4593..31a515d 100644 --- a/train_magrpo.py +++ b/train_magrpo.py @@ -10,8 +10,8 @@ import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from pathlib import Path -from typing import Any, Dict, Optional + +from typing import Any, Dict from config import Config, add_config_args, parse_overrides from datasets import load_dataset @@ -110,7 +110,7 @@ def get_formatters(dataset_type: str, num_agents: int): For 
code tasks, use aux formatters for all agents except the last, which uses main. """ - if dataset_type.lower() in ["humaneval", "coophumaneval"] and num_agents == 2: + if dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"] and num_agents == 2: return [aux_function_formatter, main_function_formatter] raise NotImplementedError("Other number of agents have not been implemented yet") @@ -125,7 +125,7 @@ def get_logger_and_aggregator(dataset_type: str, is_multi_turn: bool = False): return None, None # Use unified multi-turn compatible logger/aggregator for code datasets - if dataset_type.lower() in ["humaneval", "coophumaneval"]: + if dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"]: return mt_humaneval_logger, aggregate_mt_humaneval_metrics_for_logging return None, None @@ -142,7 +142,7 @@ def get_reward_function(dataset_type: str, num_agents: int): "dataset.type not specified in config. Please add 'type: humaneval/coophumaneval' to the dataset section." ) - if dataset_type.lower() in ["humaneval", "coophumaneval"]: + if dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"]: def reward_wrapper(*agent_completions, batch_items=None, prompts=None): # agent_completions: tuple of lists (one list per agent), each list contains strings per completion @@ -189,6 +189,9 @@ def main(): args = parser.parse_args() + # ------------------------------------------------------------------ + # Config: load YAML and apply overrides + # ------------------------------------------------------------------ if args.config: config = Config(args.config) else: @@ -201,12 +204,14 @@ def main(): # Apply command-line overrides - # Load model configuration + # ------------------------------------------------------------------ + # Config: model, dataset, output + # ------------------------------------------------------------------ model_config = config.get_model_config() model_name = model_config.name - output_base_dir = config.get("output.base_dir") dataset_name = config.get("dataset.name") dataset_type = config.get("dataset.type") + output_base_dir = config.get("output.base_dir") # Try to infer dataset type from dataset name if not specified if dataset_type is None: @@ -214,6 +219,8 @@ def main(): dataset_type = "humaneval" elif "coophumaneval" in dataset_name.lower() or "coop" in dataset_name.lower(): dataset_type = "coophumaneval" + elif "mbpp" in dataset_name.lower(): + dataset_type = "mbpp" else: raise ValueError( f"Could not infer dataset type from dataset name '{dataset_name}'. Please specify 'type' in dataset config." 
@@ -223,15 +230,14 @@ def main(): train_split = config.get("dataset.train_split") eval_split = config.get("dataset.eval_split") - # Get MAGRPO configuration (works for both single and multi-turn) + # ------------------------------------------------------------------ + # Config: MAGRPO training params and verbosity + # ------------------------------------------------------------------ magrpo_config = ( config.get_section("magrpo") if hasattr(config, "get_section") else {} ) - - # Check if this is multi-turn training num_turns = magrpo_config.get("num_turns", 1) is_multi_turn = num_turns > 1 - output_verbose = config.get("output.verbose", True) if output_verbose: print(f"Multi-turn training enabled: num_turns={num_turns}") if is_multi_turn else print( @@ -291,7 +297,9 @@ def main(): temperature = magrpo_config.get("temperature", model_config.temperature) top_p = magrpo_config.get("top_p", model_config.top_p) - # External configuration (mode, sandbox, expert model, context flags) + # ------------------------------------------------------------------ + # Config: External transitions (mode, sandbox, expert model, context flags) + # ------------------------------------------------------------------ external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} # Register external context resolver using dataset items @@ -318,7 +326,7 @@ def _normalize_prompt(p: str) -> str: else: sandbox_slice = None if _sandbox_val is None else 0 - import re + # re already imported at module level def _make_sliced_assert_tests(test_code: str, n: int) -> str: if not isinstance(test_code, str) or not test_code.strip(): @@ -389,6 +397,9 @@ def _resolver(prompt: str): external_ctx.set_context_resolver(_resolver) # Use unified MAGRPOConfig which handles both single-turn and multi-turn + # ------------------------------------------------------------------ + # Build training args + # ------------------------------------------------------------------ magrpo_args = MAGRPOConfig( output_dir=output_dir, num_agents=magrpo_config.get("num_agents", 2), # Pass num_agents to the config @@ -408,22 +419,25 @@ def _resolver(prompt: str): discount=magrpo_config.get("discount", 0.9), joint_mode=magrpo_config.get("joint_mode", "aligned"), termination_threshold=magrpo_config.get("termination_threshold", None), - # GRPO-style advantage params - normalize_advantage=magrpo_config.get("normalize_advantage", False), epsilon_clip=magrpo_config.get("epsilon_clip", None), ) - # Get appropriate formatters and functions based on dataset type, agent count, and training mode + # ------------------------------------------------------------------ + # Formatters, rewards, and logging + # ------------------------------------------------------------------ formatters = get_formatters(dataset_type, config.get("magrpo.num_agents", 2)) reward_func = get_reward_function(dataset_type, config.get("magrpo.num_agents", 2)) eval_logger, eval_aggregator = get_logger_and_aggregator( dataset_type, is_multi_turn ) + # ------------------------------------------------------------------ + # W&B configuration and tags + # ------------------------------------------------------------------ wandb_section = ( config.get_section("wandb") if hasattr(config, "get_section") else {} ) - model_short_name = model_name.split("/")[-1].lower() + # Model short name no longer used in W&B naming # Use different wandb name for multi-turn if is_multi_turn: @@ -450,7 +464,7 @@ def _resolver(prompt: str): wandb_config = { "project": wandb_section.get("project", 
"mlrl"), "entity": wandb_section.get("entity", "nu-llpr"), - "name": f"{wandb_name}_{model_short_name}", + "name": f"{wandb_name}", "dir": wandb_section.get("dir", "../../../projects/bepg/sliu30"), "tags": tags, # Provide full sections for the trainer to log cleanly @@ -506,19 +520,26 @@ def _resolver(prompt: str): prev = reward_processor reward_processor = (lambda p=prev, s=shift_proc: (lambda x: s(p(x))))() + # ------------------------------------------------------------------ + # Build trainer kwargs (grouped: model/data, reward/formatting, logging, args) + # ------------------------------------------------------------------ trainer_kwargs = { + # Model / data "agents": agents, "num_agents": num_agents, - "reward_func": reward_func, - "formatters": formatters, - "args": magrpo_args, + "tokenizer": tokenizer, "train_dataset": train_dataset, "eval_dataset": eval_dataset, - "tokenizer": tokenizer, + # Reward / formatting + "reward_func": reward_func, + "formatters": formatters, + # Logging / eval / config "wandb_config": wandb_config, "eval_logger": eval_logger, "eval_aggregator": eval_aggregator, "dataset_type": dataset_type, + # Training args + "args": magrpo_args, } if reward_processor is not None: @@ -527,7 +548,7 @@ def _resolver(prompt: str): if ( is_multi_turn and dataset_type - and dataset_type.lower() in ["humaneval", "coophumaneval"] + and dataset_type.lower() in ["humaneval", "coophumaneval", "mbpp"] ): expert_model = external_cfg.get("expert_model", "deepseek-coder") # external_mode already loaded above