In [37]:
# # REFACTOR: Standard imports + asyncio and nest_asyncio for Jupyter compatibility.
# import pandas as pd
# import numpy as np
# import time
# import datetime
# import importlib
# import inspect
# import os
# import re
# import json
# import random
# import openai
# import google.generativeai as genai
# import anthropic
# import asyncio
# import nest_asyncio
# from openai import AsyncOpenAI # REFACTOR: Import async client
# from anthropic import AsyncClient # REFACTOR: Import async client
# from datasets import load_dataset
# from tqdm.notebook import tqdm # REFACTOR: Use notebook-friendly tqdm
# from dotenv import load_dotenv
# from typing import List, Dict, Any
# from pathlib import Path

# def find_project_root():
#     """Traverse upwards to find the project root, marked by the .git folder."""
#     current_path = Path.cwd()
#     while current_path != current_path.parent:
#         if (current_path / ".git").is_dir():
#             return current_path
#         current_path = current_path.parent
#     raise FileNotFoundError("Could not find project root. Is this a git repository?")

# # REFACTOR: Apply nest_asyncio to allow asyncio to run in a Jupyter notebook.
# # This must be done once per kernel.
# nest_asyncio.apply()

# # --- 1. Client and Model Configuration ---

# load_dotenv()

# # REFACTOR: Initialize asynchronous clients.
# # The synchronous clients are no longer needed for the generation script.
# openai_client_async = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# anthropic_client_async = AsyncClient(api_key=os.getenv("ANTHROPIC_API_KEY"))
# genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# # Define the project root as a global constant
# PROJECT_ROOT = find_project_root()
# print(f"Project root identified: {PROJECT_ROOT}")
# BASE_OUTPUT_DIR = PROJECT_ROOT / 'data' / 'code_gen_outputs_raw'

# # REFACTOR: Define rate limits here.
# # These values are placeholders. You MUST replace them with the actual
# # requests-per-minute (RPM) limits for your API keys to avoid errors.
# # A safe starting point is slightly below the documented RPM.
# API_CONCURRENCY_LIMITS = {
#     "openai": 10,    # e.g., limit to 10 concurrent requests for OpenAI
#     "anthropic": 10, # e.g., limit to 10 concurrent requests for Anthropic
#     "google": 10,    # e.g., limit to 10 concurrent requests for Google
# }

# # --- 2. System Prompt and Helper Functions (no changes needed) ---

# SYSTEM_PROMPT = "You are an expert Python programmer specializing in data formalization. Your role is to meticulously convert natural language math problems and their step-by-step solutions into a single, well-structured Python function. You will be presented with examples of the required format followed by a final task to complete."


# PROMPT_GUIDELINES = """### Guidelines

# 0. **Output wrapping**
#    Return the code inside a single ```python … ``` block, and nothing else.

# 1.  **Function Naming & Docstring:** The function must be named `solve`. It must begin with a docstring that has exactly two lines:
#     *   The first line must be: "Index: [Index]." using the index from the task header.
#     *   The second line must be a succinct, one-sentence description of what the function returns (e.g., "Returns: the total cost of wages and taxes.").

# 2.  **Function Arguments:** The function arguments must be derived from the 'Question' text. 
#     *   Create a distinct argument for every numerical value that is directly stated in the text.
#     *   The arguments should be created **in the same order in which they appear in the question**.
#     *   **Note:** Some of these arguments may end up not being used in the function body. This is expected. Do not worry about this and leave the unused arguments in the function signature.

# 3.  **Argument Formatting:** Each argument must include a type-hint (e.g., `int`, `float`) and a default value equal to its value in the 'Question'. You must also add a comment (`#`) next to each argument that quotes or refers to the phrase in the 'Question' it comes from. 

# 4.  **Function Body:** The body of the function should follow the logic of the provided 'Solution' dict, which contains the step-by-step solution to the problem. The keys of this dict are strings (e.g. `"L1"`, `"L2"`) which refer to the line number, and the values of the dict are the corresponding steps in the solution. 
#     * For every relevant line in the 'Solution', you must include a comment in the Python code that indicates the line number (key) from the 'Solution' dict.
#     * These comments should be formatted as `#: L<n>`, where `<n>` is the line number from the 'Solution' dict.
#     * Immediately follow the comment with the Python statement that performs the calculation.
#     * Steps in the solution should result in the creation of new, intermediate variables, which should be named descriptively based on the context of the calculation.
#     * Wherever possible, in your code try to use only the variables from the function arguments and the intermediate variables you created before, and try to avoid using hard-coded numbers in the calculations.

# 5.  **Calculator Annotations:** Pay close attention to the calculator annotations (e.g., `[[25*8=200]]`) in the 'Solution' as they reveal the precise mathematical operations to implement. **Note**: Some lines in the solution may not contain calculator annotations, but you should still pay attention to the logic and calculations described in those lines.

# 6.  **Final Answer:** Store the final answer in a variable named 'answer', and on the same line, add the comment `# FINAL ANSWER`. In the next line, return the 'answer' variable.

# 7. **No extra output:** Your output should end with the ``` closing the code block. Do not include any additional text, explanations, or comments outside of the code block."""

# gsm8k_train = load_dataset("gsm8k", "main", split="train")

# def build_solution_mapping(
#     index: int,
#     dataset: Any,
#     convert_brackets: bool = True,
# ):
#     """
#     Parameters
#     ----------
#     index : int
#         Position of the sample in the loaded dataset.
#     dataset : iterable / HuggingFace Dataset
#     convert_brackets : bool, default ``True``
#         If ``True`` replace every ``<< … >>`` calculator annotation with
#         the canonical ``[[ … ]]`` form so downstream code sees a single
#         bracket style.

#     Returns
#     -------
#     Dict[str, str]
#         Mapping ``{"L1": <first non-empty line>, "L2": <second>, …}``.

#     Notes
#     -----
#     * Blank lines in ``sample["answer"]`` are ignored.
#     * The line numbering reflects the *order* in the original solution
#       string; there is no semantic grouping beyond that.
#     """
#     # extract & split solution text
#     solution_text = dataset[index]["answer"]
#     lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

#     # Remove the last line if it matches the '####' answer pattern
#     if lines and re.match(r"^####\s*\d+(\.\d+)?$", lines[-1]):
#         lines = lines[:-1]

#     # optional bracket normalisation
#     if convert_brackets:
#         angle = re.compile(r"<<([^>]+)>>")
#         lines = [angle.sub(r"[[\1]]", ln) for ln in lines]
#     # build mapping
#     return {f"L{i}": line for i, line in enumerate(lines, 1)}

# def get_code_strings(indices: List[int], savepath: Path = PROJECT_ROOT / 'data' / 'code_examples'):
#     """
#     Reads code examples directly from .py files instead of importing them.
#     This is more robust and avoids Python's complex import path mechanics.
#     """
#     code_strings = {}
#     for idx in indices:
#         # Construct the full file path
#         filepath = os.path.join(savepath, f"_{idx}.py")
#         try:
#             # Read the entire content of the file
#             with open(filepath, 'r', encoding='utf-8') as f:
#                 code_strings[idx] = f.read()
#         except FileNotFoundError:
#             print(f"Error: Could not find example file for index {idx} at: {filepath}")
#             code_strings[idx] = f"# Error: Code for example {idx} not found."
#     return code_strings

# def format_prompt_query(index: int, code_strings: dict, with_code: bool = False):
#     sample = gsm8k_train[index]
#     question = sample["question"]
#     solution_mapping = build_solution_mapping(index, gsm8k_train)
#     solution = json.dumps(solution_mapping)
#     out = f"""*Index*: 
# {index}

# *Question*: 
# {question}

# *Solution*: 
# {solution}

# *Code*:"""
#     if with_code:
#         out += f"""\n```python
# {code_strings.get(index, "# Code not found")}
# ```"""
#     return out

# def craft_user_prompt(index: int, example_indices: List[int], code_examples: Dict[int, str]):
#     example_prompts = [
#         format_prompt_query(index=idx, code_strings=code_examples, with_code=True)
#         for idx in example_indices
#     ]
#     task_prompt = format_prompt_query(index=index, code_strings=code_examples)
#     full_prompt = "\n".join([
#         PROMPT_GUIDELINES,
#         "\n--- EXAMPLES ---\n",
#         "\n".join(example_prompts),
#         "--- TASK ---\n",
#         task_prompt
#     ])
#     return full_prompt


# # --- 3. Asynchronous API Calling Function ---

# # REFACTOR: This new async function replaces the original `call_model_api`.
# async def call_model_api_async(
#     provider: str,
#     model: str,
#     system_prompt: str,
#     user_prompt: str,
#     semaphore: asyncio.Semaphore
# ) -> str | None:
#     """
#     Asynchronously calls the appropriate LLM API using a semaphore for rate limiting.
#     """
#     async with semaphore: # Wait for an open slot to respect rate limits
#         try:
#             if provider == "google":
#                 gemini = genai.GenerativeModel(
#                     model_name=model,
#                     system_instruction=system_prompt
#                 )
#                 generation_config = genai.types.GenerationConfig(
#                     temperature=0.1,
#                     max_output_tokens=4000
#                 )
#                 # Use the async method for the Google SDK
#                 response = await gemini.generate_content_async(
#                     user_prompt,
#                     generation_config=generation_config
#                 )
#                 return response.text

#             elif provider == "anthropic":
#                 # Use the async client for Anthropic
#                 response = await anthropic_client_async.messages.create(
#                     model=model,
#                     max_tokens=4000,
#                     temperature=0.1,
#                     system=system_prompt,
#                     messages=[{"role": "user", "content": user_prompt}]
#                 )
#                 return response.content[0].text

#             elif provider == "openai":
#                 kwargs = {
#                     "model": model,
#                     "messages": [
#                         {"role": "system", "content": system_prompt},
#                         {"role": "user", "content": user_prompt}
#                     ]
#                 }
#                 if model not in ["o3-mini", "o4-mini"]:
#                     kwargs["temperature"] = 0.1
#                     kwargs["max_tokens"] = 4000
#                 # Use the async client for OpenAI
#                 response = await openai_client_async.chat.completions.create(**kwargs)
#                 return response.choices[0].message.content
            
#             else:
#                 print(f"Unknown provider: {provider}")
#                 return None

#         except Exception as e:
#             # Important to print the error to know why a call failed
#             print(f"An API error occurred for {provider} model {model}: {e}")
#             # Return the exception itself to be handled by the orchestrator
#             return e

# # --- 4. Parallel Orchestration Function ---

# # REFACTOR: This is the new main orchestrator.
# async def generate_GSM8K_code_parallel(
#     model_dict: Dict[str, List[str]],
#     indices_to_generate: List[int],
#     example_indices: List[int],
#     system_prompt: str = SYSTEM_PROMPT,
#     output_dir: str = BASE_OUTPUT_DIR  # Add this argument
# ):
#     """
#     Calls multiple LLM APIs in parallel for each problem index, saves the
#     raw output, and logs performance.
#     """
#     performance_data = []
#     os.makedirs(output_dir, exist_ok=True)

#     # REFACTOR: Create a dictionary of semaphores based on the defined limits.
#     semaphores = {
#         provider: asyncio.Semaphore(limit)
#         for provider, limit in API_CONCURRENCY_LIMITS.items()
#     }

#     # Use tqdm for the outer loop to track progress over problem indices
#     for index in tqdm(indices_to_generate, desc="Processing Problems"):
#         problem_dir = os.path.join(output_dir, str(index))
#         os.makedirs(problem_dir, exist_ok=True)

#         # Craft prompt once per problem
#         user_prompt = craft_user_prompt(
#             index=index,
#             example_indices=example_indices,
#             code_examples=get_code_strings(example_indices)
#         )

#         # REFACTOR: Create a list of async tasks for all models for the current index.
#         tasks = []
#         for provider, models in model_dict.items():
#             for model_name in models:
#                 task = asyncio.create_task(
#                     call_model_api_async(
#                         provider,
#                         model_name,
#                         system_prompt,
#                         user_prompt,
#                         semaphores[provider]
#                     )
#                 )
#                 # Store metadata with the task for later processing
#                 task.meta = {'provider': provider, 'model': model_name, 'index': index, 'start_time': time.time()}
#                 tasks.append(task)
        
#         # REFACTOR: Run all tasks concurrently and wait for them to complete.
#         # `return_exceptions=True` ensures that one failed API call doesn't stop others.
#         print(f"Index {index}: Launching {len(tasks)} API calls in parallel...")
#         results = await asyncio.gather(*tasks, return_exceptions=True)
#         print(f"Index {index}: All API calls completed.")

#         # REFACTOR: Process the results.
#         for task, result in zip(tasks, results):
#             meta = task.meta
#             time_taken = time.time() - meta['start_time']
            
#             # Check if the result is an exception
#             if isinstance(result, Exception):
#                 raw_response = None
#                 print(f"  -> Failed: {meta['provider']}_{meta['model']} ({time_taken:.2f}s). Error: {result}")
#             else:
#                 raw_response = result
#                 print(f"  -> Success: {meta['provider']}_{meta['model']} ({time_taken:.2f}s).")

#             performance_data.append({
#                 'provider': meta['provider'],
#                 'model': meta['model'],
#                 'index': meta['index'],
#                 'time_taken': time_taken,
#                 'status': 'Failed' if raw_response is None else 'Success'
#             })
            
#             if raw_response:
#                 output_filename = f"{meta['provider']}_{meta['model']}.txt"
#                 output_path = os.path.join(problem_dir, output_filename)
#                 try:
#                     with open(output_path, 'w', encoding='utf-8') as f:
#                         f.write(raw_response)
#                 except IOError as e:
#                     print(f"    Error: Failed to write file. Reason: {e}")

#     # Save performance data to CSV at the end
#     df = pd.DataFrame(performance_data)
#     csv_path = os.path.join(output_dir, 'generation_performance.csv')
#     df.to_csv(csv_path, index=False)
#     print(f"\nGeneration complete. Performance data saved to {csv_path}.")
#     return df

# # # Add any problem indices you have generated outputs for.
# # problem_indices_to_test = sorted([3331, 1647, 636, 399, 4670, 5918, 1531, 7364, 5464, 1205, 3518, 6732, 3779, 4483, 6237, 1202, 2345])

In [38]:
# def generate_and_print_sample_prompt(target_index: int, example_indices: List[int]):
#     """
#     Generates and prints the full user prompt for a single target index
#     for debugging and inspection.

#     Args:
#         target_index: The GSM8K index for the final task in the prompt.
#         example_indices: A list of GSM8K indices to use as few-shot examples.
#     """
#     print(f"--- Generating sample prompt for target index: {target_index} ---")

#     # 1. Load the code strings for the few-shot examples using the same
#     #    function as the main pipeline. This tests if the examples are loading correctly.
#     code_examples = get_code_strings(indices=example_indices)

#     # 2. Craft the full user prompt.
#     full_prompt = craft_user_prompt(
#         index=target_index,
#         example_indices=example_indices,
#         code_examples=code_examples
#     )
#     print(full_prompt)

# generate_and_print_sample_prompt(
#     target_index=6237,
#     example_indices=[310, 3822, 7371]
# )

In [39]:
# # --- 5. Execution ---

# # Define your parameters here
# indices = [310, 3822, 7371] # Use the indices of your few-shot examples
# indices_to_generate = list(range(10))

# model_dict = {
#   "anthropic": ["claude-3-5-haiku-20241022"], 
#   "openai": ["gpt-4.1-mini"],
#   "google": ["gemini-2.0-flash-thinking-exp", 
#              "gemini-2.5-flash-lite-preview-06-17"]
# }

# # REFACTOR: To run the async function, you must `await` it.
# # This will execute the entire parallel generation process.
# # The result (a pandas DataFrame with performance logs) will be stored in `perf_df`.

# perf_df = await generate_GSM8K_code_parallel(
#     model_dict=model_dict,
#     indices_to_generate=indices_to_generate,
#     example_indices=indices
# )

# # print("\nFinal Performance Summary:")
# # print(perf_df)

In [40]:
# REFACTOR: Standard imports + asyncio and nest_asyncio for Jupyter compatibility.
import pandas as pd
import numpy as np
import time
import datetime
import importlib
import inspect
import os
import re
import json
import random
import openai
import google.generativeai as genai
import anthropic
import asyncio
import nest_asyncio
from openai import AsyncOpenAI # REFACTOR: Import async client
from anthropic import AsyncClient # REFACTOR: Import async client
from datasets import load_dataset
from tqdm.notebook import tqdm # REFACTOR: Use notebook-friendly tqdm
from dotenv import load_dotenv
from typing import List, Dict, Any
from pathlib import Path

def find_project_root():
    """Traverse upwards to find the project root, marked by the .git folder."""
    current_path = Path.cwd()
    while current_path != current_path.parent:
        if (current_path / ".git").is_dir():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError("Could not find project root. Is this a git repository?")

# REFACTOR: Apply nest_asyncio to allow asyncio to run in a Jupyter notebook.
# This must be done once per kernel.
nest_asyncio.apply()

# --- 1. Client and Model Configuration ---

load_dotenv()

# REFACTOR: Initialize asynchronous clients.
# The synchronous clients are no longer needed for the generation script.
openai_client_async = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
anthropic_client_async = AsyncClient(api_key=os.getenv("ANTHROPIC_API_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Define the project root as a global constant
PROJECT_ROOT = find_project_root()
print(f"Project root identified: {PROJECT_ROOT}")
BASE_OUTPUT_DIR = PROJECT_ROOT / 'data' / 'code_gen_outputs_raw'

# REFACTOR: Define rate limits here.
# These values are placeholders. You MUST replace them with the actual
# requests-per-minute (RPM) limits for your API keys to avoid errors.
# A safe starting point is slightly below the documented RPM.
API_CONCURRENCY_LIMITS = {
    "openai": 10,    # e.g., limit to 10 concurrent requests for OpenAI
    "anthropic": 10, # e.g., limit to 10 concurrent requests for Anthropic
    "google": 10,    # e.g., limit to 10 concurrent requests for Google
}

# --- 2. System Prompt and Helper Functions (no changes needed) ---

SYSTEM_PROMPT = "You are an expert Python programmer specializing in data formalization. Your role is to meticulously convert natural language math problems and their step-by-step solutions into a single, well-structured Python function. You will be presented with examples of the required format followed by a final task to complete."


PROMPT_GUIDELINES = """### Guidelines

0. **Output wrapping**
   Return the code inside a single ```python … ``` block, and nothing else.

1.  **Function Naming & Docstring:** The function must be named `solve`. It must begin with a docstring that has exactly two lines:
    *   The first line must be: "Index: [Index]." using the index from the task header.
    *   The second line must be a succinct, one-sentence description of what the function returns (e.g., "Returns: the total cost of wages and taxes.").

2.  **Function Arguments:** The function arguments must be derived from the 'Question' text. 
    *   Create a distinct argument for every numerical value that is directly stated in the text.
    *   The arguments should be created **in the same order in which they appear in the question**.
    *   **Note:** Some of these arguments may end up not being used in the function body. This is expected. Do not worry about this and leave the unused arguments in the function signature.

3.  **Argument Formatting:** Each argument must include a type-hint (e.g., `int`, `float`) and a default value equal to its value in the 'Question'. You must also add a comment (`#`) next to each argument that quotes or refers to the phrase in the 'Question' it comes from. 

4.  **Function Body:** The body of the function should follow the logic of the provided 'Solution' dict, which contains the step-by-step solution to the problem. The keys of this dict are strings (e.g. `"L1"`, `"L2"`) which refer to the line number, and the values of the dict are the corresponding steps in the solution. 
    * For every relevant line in the 'Solution', you must include a comment in the Python code that indicates the line number (key) from the 'Solution' dict.
    * These comments should be formatted as `#: L<n>`, where `<n>` is the line number from the 'Solution' dict.
    * Immediately follow the comment with the Python statement that performs the calculation.
    * Steps in the solution should result in the creation of new, intermediate variables, which should be named descriptively based on the context of the calculation.
    * Wherever possible, in your code try to use only the variables from the function arguments and the intermediate variables you created before, and try to avoid using hard-coded numbers in the calculations.

5.  **Calculator Annotations:** Pay close attention to the calculator annotations (e.g., `[[25*8=200]]`) in the 'Solution' as they reveal the precise mathematical operations to implement. **Note**: Some lines in the solution may not contain calculator annotations, but you should still pay attention to the logic and calculations described in those lines.

6.  **Final Answer:** Store the final answer in a variable named 'answer', and on the same line, add the comment `# FINAL ANSWER`. In the next line, return the 'answer' variable.

7. **No extra output:** Your output should end with the ``` closing the code block. Do not include any additional text, explanations, or comments outside of the code block."""

gsm8k_train = load_dataset("gsm8k", "main", split="train")

def build_solution_mapping(
    index: int,
    dataset: Any,
    convert_brackets: bool = True,
):
    """
    Parameters
    ----------
    index : int
        Position of the sample in the loaded dataset.
    dataset : iterable / HuggingFace Dataset
    convert_brackets : bool, default ``True``
        If ``True`` replace every ``<< … >>`` calculator annotation with
        the canonical ``[[ … ]]`` form so downstream code sees a single
        bracket style.

    Returns
    -------
    Dict[str, str]
        Mapping ``{"L1": <first non-empty line>, "L2": <second>, …}``.

    Notes
    -----
    * Blank lines in ``sample["answer"]`` are ignored.
    * The line numbering reflects the *order* in the original solution
      string; there is no semantic grouping beyond that.
    """
    # extract & split solution text
    solution_text = dataset[index]["answer"]
    lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

    # Remove the last line if it matches the '####' answer pattern
    if lines and re.match(r"^####\s*\d+(\.\d+)?$", lines[-1]):
        lines = lines[:-1]

    # optional bracket normalisation
    if convert_brackets:
        angle = re.compile(r"<<([^>]+)>>")
        lines = [angle.sub(r"[[\1]]", ln) for ln in lines]
    # build mapping
    return {f"L{i}": line for i, line in enumerate(lines, 1)}

def get_code_strings(indices: List[int], savepath: Path = PROJECT_ROOT / 'data' / 'code_examples'):
    """
    Reads code examples directly from .py files instead of importing them.
    This is more robust and avoids Python's complex import path mechanics.
    """
    code_strings = {}
    for idx in indices:
        # Construct the full file path
        filepath = os.path.join(savepath, f"_{idx}.py")
        try:
            # Read the entire content of the file
            with open(filepath, 'r', encoding='utf-8') as f:
                code_strings[idx] = f.read()
        except FileNotFoundError:
            print(f"Error: Could not find example file for index {idx} at: {filepath}")
            code_strings[idx] = f"# Error: Code for example {idx} not found."
    return code_strings

def format_prompt_query(index: int, code_strings: dict, with_code: bool = False):
    sample = gsm8k_train[index]
    question = sample["question"]
    solution_mapping = build_solution_mapping(index, gsm8k_train)
    solution = json.dumps(solution_mapping)
    out = f"""*Index*: 
{index}

*Question*: 
{question}

*Solution*: 
{solution}

*Code*:"""
    if with_code:
        out += f"""\n```python
{code_strings.get(index, "# Code not found")}
```"""
    return out

def craft_user_prompt(index: int, example_indices: List[int], code_examples: Dict[int, str]):
    example_prompts = [
        format_prompt_query(index=idx, code_strings=code_examples, with_code=True)
        for idx in example_indices
    ]
    task_prompt = format_prompt_query(index=index, code_strings=code_examples)
    full_prompt = "\n".join([
        PROMPT_GUIDELINES,
        "\n--- EXAMPLES ---\n",
        "\n".join(example_prompts),
        "--- TASK ---\n",
        task_prompt
    ])
    return full_prompt


# --- 3. Asynchronous API Calling Function ---

# REFACTOR: Simplified function that removes the internal try/except block.
# It now lets exceptions propagate to be caught by asyncio.gather.
async def call_model_api_async(
    provider: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    semaphore: asyncio.Semaphore
) -> tuple[str, dict]:
    """
    Asynchronously calls the appropriate LLM API and returns the text response
    and a dictionary of token usage statistics.
    Raises an exception on API failure.
    """
    async with semaphore:
        token_usage = {"input_tokens": 0, "output_tokens": 0}

        if provider == "google":
            gemini = genai.GenerativeModel(model_name=model, system_instruction=system_prompt)
            generation_config = genai.types.GenerationConfig(temperature=0.1, max_output_tokens=4000)
            response = await gemini.generate_content_async(user_prompt, generation_config=generation_config)
            text = response.text
            if response.usage_metadata:
                token_usage["input_tokens"] = response.usage_metadata.prompt_token_count
                token_usage["output_tokens"] = response.usage_metadata.candidates_token_count
            return text, token_usage

        elif provider == "anthropic":
            response = await anthropic_client_async.messages.create(
                model=model, max_tokens=4000, temperature=0.1,
                system=system_prompt,
                messages=[{"role": "user", "content": user_prompt}]
            )
            text = response.content[0].text
            if response.usage:
                token_usage["input_tokens"] = response.usage.input_tokens
                token_usage["output_tokens"] = response.usage.output_tokens
            return text, token_usage

        elif provider == "openai":
            kwargs = {
                "model": model,
                "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
            }
            if model not in ["o3-mini", "o4-mini"]:
                kwargs["temperature"] = 0.1
                kwargs["max_tokens"] = 4000
            response = await openai_client_async.chat.completions.create(**kwargs)
            text = response.choices[0].message.content
            if response.usage:
                token_usage["input_tokens"] = response.usage.prompt_tokens
                token_usage["output_tokens"] = response.usage.completion_tokens
            return text, token_usage
        
        else:
            raise ValueError(f"Unknown provider: {provider}")

# --- 4. Parallel Orchestration Function ---

# REFACTOR: Updated orchestrator with correct exception handling logic.
async def generate_GSM8K_code_parallel(
    model_dict: Dict[str, List[str]],
    indices_to_generate: List[int],
    example_indices: List[int],
    system_prompt: str = SYSTEM_PROMPT,
    output_dir: Path = BASE_OUTPUT_DIR
):
    """
    Calls multiple LLM APIs in parallel, saves the raw output, and logs
    performance including token usage and timestamps.
    """
    performance_data = []
    os.makedirs(output_dir, exist_ok=True)
    
    semaphores = {
        provider: asyncio.Semaphore(limit)
        for provider, limit in API_CONCURRENCY_LIMITS.items()
    }

    for index in tqdm(indices_to_generate, desc="Processing Problems"):
        problem_dir = output_dir / str(index)
        os.makedirs(problem_dir, exist_ok=True)

        user_prompt = craft_user_prompt(
            index=index,
            example_indices=example_indices,
            code_examples=get_code_strings(indices=example_indices)
        )

        tasks = []
        for provider, models in model_dict.items():
            for model_name in models:
                task = asyncio.create_task(
                    call_model_api_async(
                        provider, model_name, system_prompt, user_prompt, semaphores[provider]
                    )
                )
                task.meta = {'provider': provider, 'model': model_name, 'index': index, 'start_time': time.time()}
                tasks.append(task)
        
        print(f"Index {index}: Launching {len(tasks)} API calls in parallel...")
        results = await asyncio.gather(*tasks, return_exceptions=True)
        print(f"Index {index}: All API calls completed.")

        # This loop now correctly handles exceptions caught by asyncio.gather.
        for task, result in zip(tasks, results):
            meta = task.meta
            time_taken = time.time() - meta['start_time']
            completion_timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
            
            raw_response = None
            token_usage = {"input_tokens": 0, "output_tokens": 0}
            status = "Failed" # Default to Failed

            # Correctly check if the result is an exception object.
            if isinstance(result, Exception):
                print(f"  -> Failed: {meta['provider']}_{meta['model']} ({time_taken:.2f}s). Error: {result}")
                # We do not unpack 'result' here because it is an exception.
            
            else:
                # If it's not an exception, it must be the (text, usage) tuple.
                raw_response, token_usage = result # This unpacking is now safe.
                status = "Success"
                print(f"  -> Success: {meta['provider']}_{meta['model']} ({time_taken:.2f}s).")

            performance_data.append({
                'provider': meta['provider'],
                'model': meta['model'],
                'index': meta['index'],
                'time_taken': time_taken,
                'status': status,
                'completion_timestamp_utc': completion_timestamp,
                'input_tokens': token_usage.get('input_tokens', 0),
                'output_tokens': token_usage.get('output_tokens', 0),
            })
            
            if raw_response:
                output_filename = f"{meta['provider']}_{meta['model']}.txt"
                output_path = problem_dir / output_filename
                try:
                    with open(output_path, 'w', encoding='utf-8') as f:
                        f.write(raw_response)
                except IOError as e:
                    print(f"    Error: Failed to write file. Reason: {e}")

    df = pd.DataFrame(performance_data)
    csv_path = output_dir / 'generation_performance.csv'
    df.to_csv(csv_path, index=False)
    print(f"\nGeneration complete. Performance data saved to {csv_path}.")
    return df

# # Add any problem indices you have generated outputs for.
# problem_indices_to_test = sorted([3331, 1647, 636, 399, 4670, 5918, 1531, 7364, 5464, 1205, 3518, 6732, 3779, 4483, 6237, 1202, 2345])

Project root identified: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math


In [41]:
# --- 5. Execution ---

# Define your parameters here
indices = [310, 3822, 7371] # Use the indices of your few-shot examples
indices_to_generate = list(range(5))

model_dict = {
  "anthropic": ["claude-3-5-haiku-20241022"], 
  "openai": ["gpt-4.1-mini"],
  "google": ["gemini-2.0-flash-thinking-exp", 
             "gemini-2.5-flash-lite-preview-06-17"]
}

# REFACTOR: To run the async function, you must `await` it.
# This will execute the entire parallel generation process.
# The result (a pandas DataFrame with performance logs) will be stored in `perf_df`.

perf_df = await generate_GSM8K_code_parallel(
    model_dict=model_dict,
    indices_to_generate=indices_to_generate,
    example_indices=indices
)

# print("\nFinal Performance Summary:")
# print(perf_df)

Processing Problems:   0%|          | 0/5 [00:00<?, ?it/s]

Index 0: Launching 4 API calls in parallel...
Index 0: All API calls completed.
  -> Success: anthropic_claude-3-5-haiku-20241022 (4.32s).
  -> Success: openai_gpt-4.1-mini (4.32s).
  -> Success: google_gemini-2.0-flash-thinking-exp (4.32s).
  -> Success: google_gemini-2.5-flash-lite-preview-06-17 (4.32s).
Index 1: Launching 4 API calls in parallel...
Index 1: All API calls completed.
  -> Success: anthropic_claude-3-5-haiku-20241022 (2.60s).
  -> Success: openai_gpt-4.1-mini (2.60s).
  -> Success: google_gemini-2.0-flash-thinking-exp (2.60s).
  -> Success: google_gemini-2.5-flash-lite-preview-06-17 (2.60s).
Index 2: Launching 4 API calls in parallel...
Index 2: All API calls completed.
  -> Success: anthropic_claude-3-5-haiku-20241022 (3.74s).
  -> Success: openai_gpt-4.1-mini (3.74s).
  -> Success: google_gemini-2.0-flash-thinking-exp (3.74s).
  -> Success: google_gemini-2.5-flash-lite-preview-06-17 (3.74s).
Index 3: Launching 4 API calls in parallel...
Index 3: All API calls complet

In [42]:
def generate_and_print_sample_prompt(target_index: int, example_indices: List[int]):
    """
    Generates and prints the full user prompt for a single target index
    for debugging and inspection.

    Args:
        target_index: The GSM8K index for the final task in the prompt.
        example_indices: A list of GSM8K indices to use as few-shot examples.
    """
    print(f"--- Generating sample prompt for target index: {target_index} ---")

    # 1. Load the code strings for the few-shot examples using the same
    #    function as the main pipeline. This tests if the examples are loading correctly.
    code_examples = get_code_strings(indices=example_indices)

    # 2. Craft the full user prompt.
    full_prompt = craft_user_prompt(
        index=target_index,
        example_indices=example_indices,
        code_examples=code_examples
    )
    print(full_prompt)

generate_and_print_sample_prompt(
    target_index=6237,
    example_indices=[310, 3822, 7371]
)

--- Generating sample prompt for target index: 6237 ---
### Guidelines

0. **Output wrapping**
   Return the code inside a single ```python … ``` block, and nothing else.

1.  **Function Naming & Docstring:** The function must be named `solve`. It must begin with a docstring that has exactly two lines:
    *   The first line must be: "Index: [Index]." using the index from the task header.
    *   The second line must be a succinct, one-sentence description of what the function returns (e.g., "Returns: the total cost of wages and taxes.").

2.  **Function Arguments:** The function arguments must be derived from the 'Question' text. 
    *   Create a distinct argument for every numerical value that is directly stated in the text.
    *   The arguments should be created **in the same order in which they appear in the question**.
    *   **Note:** Some of these arguments may end up not being used in the function body. This is expected. Do not worry about this and leave the unused arguments