# Full augmentation pipeline

### Assembling a dataframe

In [2]:
import json
import re
import os
from tqdm import tqdm
import pandas as pd

# Function to read jsonl files
def read_jsonl(filename):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        filename: Path to a UTF-8 encoded file holding one JSON document per line.

    Returns:
        list: The parsed objects in file order. Blank lines are skipped
        silently; lines that are not valid JSON are reported on stdout
        (with their line number) and skipped.
    """
    data = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            # A blank line (typically a trailing newline at EOF) is not a
            # malformed record, so do not report it as a decoding error.
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as exc:
                print(f"Error decoding line {line_number} in {filename}: {exc}")
    return data

def extract_context(prompt):
    """Return the problem description that precedes the worked example.

    The description is everything before the first "Given the following
    output:" / "Given the following input:" marker, with the boilerplate
    opening sentence removed and surrounding whitespace stripped.
    Returns None when the prompt contains no such marker.
    """
    boundary = re.search(r"Given the following (output|input):", prompt)
    if boundary is None:
        return None

    context = prompt[:boundary.start()]

    # Drop the fixed template sentence that opens the prompt, when present.
    template_start = "You are given a question that requires some input and output variables as follows:\n\n"
    if context.startswith(template_start):
        context = context[len(template_start):]

    return context.strip()

def extract_reference_code(prompt):
    """Extract only the runnable Python code from the reference section.

    The reference section is everything after the first occurrence of the
    "Tip: Here is a reference code snippet ..." marker. Any introductory
    prose is dropped: the code is taken to start at the first line that
    begins with ``import``, ``from``, ``def``, ``class`` or a ``# `` comment.

    Args:
        prompt: Full prompt text.

    Returns:
        str | None: The code (possibly an empty string when no line looks
        like code), or None when the marker is absent.
    """
    marker = "Tip: Here is a reference code snippet for this question."

    # partition() keeps *everything* after the first marker; split(marker)[1]
    # would silently cut the code off if the marker text occurred again.
    _, found, tail = prompt.partition(marker)
    if not found:
        return None

    code_prefixes = ('import ', 'from ', 'def ', 'class ', '# ')
    lines = tail.strip().split('\n')
    for start, line in enumerate(lines):
        # The first line that looks like Python marks the start of the code.
        if line.startswith(code_prefixes):
            return '\n'.join(lines[start:])

    # Marker present but nothing after it looks like code.
    return ''

def process_data(input_file, max_rows=None):
    """Build a DataFrame of (context, reference_code) pairs from a JSONL dataset.

    Reads ``input_file``, optionally keeps only the first ``max_rows``
    entries, and extracts the problem context and reference code from each
    entry's 'prompt' field. Entries without a 'prompt' key, or for which
    either extraction yields nothing, are left out.
    """
    entries = read_jsonl(input_file)
    if max_rows is not None:
        entries = entries[:max_rows]

    extracted = []
    for entry in tqdm(entries):
        if 'prompt' not in entry:
            continue

        prompt_text = entry['prompt']
        context = extract_context(prompt_text)
        reference_code = extract_reference_code(prompt_text)

        # Keep the row only when both pieces were found and are non-empty.
        if not (context and reference_code):
            continue
        extracted.append({'context': context, 'reference_code': reference_code})

    return pd.DataFrame(extracted)

# --- Configurable parameters ---
# Upper bound on the number of JSONL records passed to process_data().
NUM_ROWS = None  # Set to None to process all rows
# OUTPUT_CSV = '../processed_data/extracted_contexts.csv'

# Setup paths (relative to the notebook's working directory)
data_dir = '../generated_data'
input_file = os.path.join(data_dir, 'ast-pyedur_full_subset.jsonl')

# Process data and create DataFrame with 'context' and 'reference_code' columns.
# `df` is reused by the later cells of this notebook.
df = process_data(input_file, max_rows=NUM_ROWS)

# Display first few rows
df.head()

100%|██████████| 10000/10000 [00:00<00:00, 23647.27it/s]


Unnamed: 0,context,reference_code
0,"In a 3D space, a drone is navigating towards a...",# import necessary packages\nimport numpy as n...
1,Given a differential equation and its initial ...,# import necessary packages\nimport math\nimpo...
2,"Given a list of integers, what are all the pos...",# import necessary packages\nimport json\n\n# ...
3,"Given a set of numerical data points, what are...",# import necessary packages\nimport math\nimpo...
4,"Given an even number greater than 2, what are ...",# import necessary packages\nfrom random impor...


### Generate input_generator prompt column

In [3]:
import re

# Function to extract input and output specifications from context
def extract_io_specs(context):
    """Pull the "Input:" and "Output:" specification text out of a problem context.

    The input spec runs from "Input:" up to "Output:" (or the end of the
    text); the output spec runs from "Output:" up to the first blank line
    (or the end of the text). Each is returned stripped, or as an empty
    string when its section is not found.

    Returns:
        tuple: (input_spec, output_spec)
    """
    section_patterns = (
        r'Input:\s*(.*?)(?=Output:|$)',
        r'Output:\s*(.*?)(?=\n\n|$)',
    )

    specs = []
    for pattern in section_patterns:
        found = re.search(pattern, context, re.DOTALL)
        specs.append(found.group(1).strip() if found else "")

    input_spec, output_spec = specs
    return input_spec, output_spec

# Base prompt template for the input-generator request. It is filled in with
# str.format(), so it must contain exactly the placeholders {input_spec},
# {output_spec} and {reference_code}; any other literal braces would need doubling.
inputgen_prompt_template = """
You are an expert programmer tasked with creating an input generator function for a given code snippet. This function will be used to generate test inputs for the code.

I'll provide you with a reference code implementation. Your job is to create a Python function called `input_generator()` that:

- You need to provide a function named `input_generator` that generates the input arguments for the `main_solution` function.
- The `input_generator` function should not require any input arguments, and each time it is called, it should return a set of input arguments that meet the requirements of the `main_solution` function.
- The output of `input_generator` should always be a dictionary because we always call by `**kwargs` in the `main_solution` function.
- Add some randomness in the `input_generator` function to ensure the input arguments are different each time it is called.
- Please try to make the generated input arguments as reasonable as possible, try to avoid generating too complex or too trivial input variables, also the size of the variables should be reasonable, like less than 1KB.

The input and output requirements for the main function are as follows:

Input:
{input_spec}

Output:
{output_spec}

The input generator should ONLY generate the inputs, not execute the main function or process any outputs.

Here is the reference code:    
```python	
{reference_code}
```
Please respond with ONLY the input_generator() function definition. Your response should start with "import" statements if needed, followed by the function definition. Do not include any explanations or other text.
"""

# Function to create a customized prompt for each row
def create_input_generator_prompt(row):
    """Render the input-generator prompt for one DataFrame row.

    Uses the row's 'context' to obtain the input/output specifications and
    the row's 'reference_code' as the code shown to the model.
    """
    specs = extract_io_specs(row['context'])
    return inputgen_prompt_template.format(
        input_spec=specs[0],
        output_spec=specs[1],
        reference_code=row['reference_code'],
    )

# Apply the function to each row (axis=1) and add the rendered prompt as a new column
df['input_generator_prompt'] = df.apply(create_input_generator_prompt, axis=1)

# Persist the DataFrame (now including the prompt column) without the index column
df.to_csv('../generated_data/augmented_data_with_prompts.csv', index=False) 
print(f"Saved DataFrame with {len(df)} rows to augmented_data_with_prompts.csv")

Saved DataFrame with 10000 rows to augmented_data_with_prompts.csv


### Use the DeepSeek API to generate input generators

In [4]:
# SECURITY: never commit an API key with the notebook. The key is read from the
# DEEPSEEK_API_KEY environment variable instead.
# NOTE(review): the key that used to be hard-coded on this line has been exposed
# in source control and should be revoked.
key = os.environ.get("DEEPSEEK_API_KEY", "")
if not key:
    print("WARNING: DEEPSEEK_API_KEY is not set; DeepSeek API calls will fail until it is exported.")

### Old sequential API call

In [None]:
# import time
# from openai import OpenAI

# def generate_input_generators(df, api_key, max_rows=None, temperature=0.3, batch_size=10, sleep_time=1):
#     """
#     Generate input generators for each row in the DataFrame using the Deepseek API.
    
#     Args:
#         df: DataFrame containing an 'input_generator_prompt' column
#         api_key: Deepseek API key
#         max_rows: Maximum number of rows to process (None for all rows)
#         temperature: Temperature setting for the API (0.0 to 1.0)
#         batch_size: Number of API calls before pausing to avoid rate limits
#         sleep_time: Seconds to sleep between batches
        
#     Returns:
#         DataFrame with a new 'input_generator' column containing the generated code
#     """
#     # Create a copy to avoid modifying the original DataFrame
#     result_df = df.copy()
    
#     # Limit rows if specified
#     if max_rows is not None:
#         result_df = result_df.iloc[:max_rows].copy()
    
#     # Initialize the API client
#     client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    
#     # Initialize the new column
#     result_df['input_generator'] = None
    
#     # Process each row with progress bar
#     for idx, row in tqdm(result_df.iterrows(), total=len(result_df), desc="Generating input generators"):
#         if idx > 0 and idx % batch_size == 0:
#             print(f"Pausing for {sleep_time} seconds to avoid rate limits...")
#             time.sleep(sleep_time)
        
#         prompt = row['input_generator_prompt']
#         if pd.isna(prompt) or not prompt.strip():
#             print(f"Skipping row {idx}: Empty prompt")
#             continue
            
#         try:
#             # Call the Deepseek API
#             response = client.chat.completions.create(
#                 model="deepseek-chat",
#                 messages=[
#                     {"role": "system", "content": "You are an expert Python programmer. Provide only valid, runnable Python code."},
#                     {"role": "user", "content": prompt}
#                 ],
#                 temperature=temperature,
#                 stream=False
#             )
            
#             # Extract the generated code
#             generated_code = response.choices[0].message.content
            
#             # Store the response in the DataFrame
#             result_df.at[idx, 'input_generator'] = generated_code
            
#         except Exception as e:
#             print(f"Error in row {idx}: {str(e)}")
    
#     # Save progress to CSV (in case of interruption)
#     result_df.to_csv('../generated_data/df_with_input_generators.csv', index=False)
#     print(f"Saved DataFrame with {len(result_df)} rows to df_with_input_generators.csv")
    
#     return result_df




In [24]:
import pandas as pd
import time
import json
import os
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

def call_deepseek_api_with_retry(client, prompt, temperature, max_retries=3, backoff_factor=2, timeout=30,
                                 model="deepseek-chat", request_delay=0.5):
    """
    Call the Deepseek API with retry logic and exponential backoff.
    
    Args:
        client: OpenAI client instance
        prompt: Prompt to send to the API
        temperature: Temperature setting (0.0 to 1.0)
        max_retries: Maximum number of retries on failure
        backoff_factor: Base of the exponential wait between retries
            (the wait before retry k is ``backoff_factor ** k`` seconds)
        timeout: Timeout in seconds for the API call
        model: Name of the chat model to request
        request_delay: Seconds to sleep before every API call to soften
            rate limiting (0 disables the pause)
        
    Returns:
        The generated text on success, a string starting with "# ERROR:"
        once all attempts have failed, or None when the prompt is missing
        (None/NaN/non-string) or blank.
    """
    # Missing DataFrame cells arrive as None/NaN rather than str; treat them
    # like blank prompts instead of crashing on .strip().
    if not isinstance(prompt, str) or not prompt.strip():
        return None

    messages = [
        {"role": "system", "content": "You are an expert Python programmer. Provide only valid, runnable Python code."},
        {"role": "user", "content": prompt}
    ]

    # One initial attempt plus up to max_retries retries.
    for attempt in range(max_retries + 1):
        try:
            # Add small delay before API call to prevent rate limiting
            if request_delay > 0:
                time.sleep(request_delay)

            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                stream=False,
                timeout=timeout  # Add timeout to prevent hanging requests
            )
            return response.choices[0].message.content

        except Exception as e:
            # Deliberately broad: any failure of the call is converted into a
            # retry, and finally into an error string, so that one bad request
            # never takes down the caller's batch.
            if attempt >= max_retries:
                return f"# ERROR: {str(e)}"

            # Calculate wait time with exponential backoff
            wait_time = backoff_factor ** (attempt + 1)
            print(f"API call failed, retrying in {wait_time:.1f}s... ({str(e)})")
            time.sleep(wait_time)

    # Only reached when max_retries is negative (no attempt was made).
    return "# ERROR: Maximum retries exceeded"

def generate_input_generators_parallel(df, api_key, max_rows=None, temperature=0.3, max_workers=5, 
                                      save_interval=25, output_dir='../generated_data'):
    """
    Generate input generators in parallel with improved robustness.

    Every prompt in the 'input_generator_prompt' column is sent to the API
    from a thread pool. Results are written back in the original row order,
    interim CSV snapshots are saved periodically, and a final CSV plus a
    success/error/empty summary are produced at the end.
    
    Args:
        df: DataFrame with input_generator_prompt column
        api_key: Deepseek API key
        max_rows: Maximum number of rows to process (None for all)
        temperature: Temperature setting for the API
        max_workers: Maximum number of parallel threads
        save_interval: Save interim results every N completed items
            (0 or None disables interim saves)
        output_dir: Directory to save output files
        
    Returns:
        DataFrame with generated input_generator column
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Prepare the result DataFrame (a copy, so the caller's frame is untouched)
    result_df = df.copy()
    if max_rows is not None:
        result_df = result_df.iloc[:max_rows].copy()

    # Initialize API client
    client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    prompts = result_df['input_generator_prompt'].tolist()

    # Pre-allocate results list: futures finish out of order, so each result
    # is stored at the position of the prompt it belongs to.
    results = [None] * len(prompts)
    completed_count = 0
    
    # Start time for logging
    start_time = time.time()
    
    print(f"Starting parallel API calls with {max_workers} workers, temperature={temperature}")
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks, remembering which row each future belongs to
        futures = {
            executor.submit(
                call_deepseek_api_with_retry, 
                client, 
                prompt, 
                temperature
            ): idx for idx, prompt in enumerate(prompts)
        }
        
        # Process results as they complete
        for future in tqdm(as_completed(futures), total=len(futures), desc="Parallel API calls"):
            idx = futures[future]
            try:
                results[idx] = future.result()
            except Exception as e:
                # A worker that raised must not abort the whole batch and
                # throw away every result collected so far; record it as an
                # error entry in the same format the API helper uses.
                results[idx] = f"# ERROR: {str(e)}"
            
            # Increment completed count
            completed_count += 1
            
            # Periodically save progress (guarded so save_interval=0 disables it
            # instead of raising ZeroDivisionError)
            if save_interval and completed_count % save_interval == 0:
                interim_df = result_df.copy()
                interim_df['input_generator'] = results
                
                # Save to CSV with timestamp
                timestamp = int(time.time())
                interim_file = os.path.join(
                    output_dir, 
                    f"df_with_input_generators_interim_{completed_count}_{timestamp}.csv"
                )
                
                interim_df.to_csv(interim_file, index=False)
                print(f"\nInterim progress saved ({completed_count}/{len(prompts)}) to {interim_file}")
                
                # Also log some stats
                elapsed = time.time() - start_time
                rate = completed_count / elapsed if elapsed > 0 else 0
                remaining = (len(prompts) - completed_count) / rate if rate > 0 else 0
                
                print(f"Elapsed: {elapsed:.1f}s | Rate: {rate:.2f} items/s | Est. remaining: {remaining:.1f}s")
    
    # Store results in DataFrame
    result_df['input_generator'] = results
    
    # Save final results
    final_file = os.path.join(output_dir, "df_with_input_generators.csv")
    result_df.to_csv(final_file, index=False)
    print(f"Saved DataFrame with {len(result_df)} rows to {final_file}")
    
    # Count successes and failures
    success_count = sum(1 for r in results if r is not None and not str(r).startswith("# ERROR:"))
    error_count = sum(1 for r in results if r is not None and str(r).startswith("# ERROR:"))
    empty_count = sum(1 for r in results if r is None)

    # An empty input must not crash the summary with a division by zero.
    total = len(results)

    def _percent(count):
        return count / total * 100 if total else 0.0
    
    print("Results summary:")
    print(f"  Success: {success_count} ({_percent(success_count):.1f}%)")
    print(f"  Errors: {error_count} ({_percent(error_count):.1f}%)")
    print(f"  Empty: {empty_count} ({_percent(empty_count):.1f}%)")
    
    return result_df

# Example usage: a small trial run on the first 10 rows of `df`.
# Note: with max_rows=10 and save_interval=50 the completed count never reaches
# a multiple of 50, so no interim CSV is written; only the final file is saved.
df_with_generators = generate_input_generators_parallel(
    df=df,
    api_key=key,
    max_rows=10,       # Process 10 rows (adjust as needed; None processes every row)
    temperature=0.3,     # Good balance for code generation
    max_workers=5,       # Adjust based on your system and API limits
    save_interval=50     # Save progress every 50 completions
)


Starting parallel API calls with 5 workers, temperature=0.3


Parallel API calls: 100%|██████████| 10/10 [01:12<00:00,  7.23s/it]

Saved DataFrame with 10 rows to ../generated_data\df_with_input_generators.csv
Results summary:
  Success: 10 (100.0%)
  Errors: 0 (0.0%)
  Empty: 0 (0.0%)





### Remove ```python tag

In [25]:
import pandas as pd
import re

def clean_code_block(code_text):
    """
    Strip Markdown code fences from a model response.

    1. Removes an opening fence at the start of the text, with or without a
       language tag (```python, ```python3, ```py or a bare ```).
    2. Removes a closing ``` at the end of the text.
    3. Strips surrounding whitespace.

    The code between the fences, including its newlines, is left untouched.
    Missing values (None/NaN) are returned unchanged.
    """
    if pd.isna(code_text):
        return code_text

    text = code_text.strip()

    # Remove the opening fence together with its optional language tag.
    # Matching only the literal "```python" would leave a bare "```" in place
    # and turn "```python3" into a stray "3" on the first line of the code.
    text = re.sub(r"^```[A-Za-z0-9_+-]*", "", text, count=1)

    # Remove ``` from the end
    text = re.sub(r"```$", "", text.strip())

    return text.strip()

# Load the DataFrame (if not already loaded), e.g. when resuming from the saved CSV
# df_with_generators = pd.read_csv('../generated_data/df_with_input_generators.csv')

# Apply the cleaning function to the input_generator column (overwrites the column in place)
df_with_generators['input_generator'] = df_with_generators['input_generator'].apply(clean_code_block)

# Save the cleaned DataFrame
df_with_generators.to_csv('../generated_data/df_with_clean_input_generators.csv', index=False)

# Display sample of cleaned data (first two rows)
print("Sample of cleaned data:")
print(df_with_generators.head(2))

Sample of cleaned data:
                                             context  \
0  In a 3D space, a drone is navigating towards a...   
1  Given a differential equation and its initial ...   

                                      reference_code  \
0  # import necessary packages\nimport numpy as n...   
1  # import necessary packages\nimport math\nimpo...   

                              input_generator_prompt  \
0  \nYou are an expert programmer tasked with cre...   
1  \nYou are an expert programmer tasked with cre...   

                                     input_generator  
0  import numpy as np\nimport random\n\ndef input...  
1  import random\nimport math\n\ndef input_genera...  


### Creating I/O pairs

In [26]:
"""
Note 1: The max_attempts parameter is currently hard-coded to 15 in the script template.
This may not be large enough in some cases, so the script may fail to generate the required number 
of examples.
Note 2: There is no guarantee that exactly num_examples examples will be generated. Errors can arise 
for many reasons (timeouts, etc.), so we need a larger subset.
Note 3: The current progress output produces more than 500 lines, so beyond a certain point progress can no longer be tracked.
-> Consider reducing the number of lines printed for the final run.
"""

import subprocess
import os
import signal
from tqdm import tqdm
import shutil
import sys
import traceback

def generate_io_pairs_robust(df, max_rows=None, num_examples=5, timeout_seconds=60, parallel=False, n_processes=4):
    """
    Generate I/O pairs using the robust approach, ensuring exactly num_examples examples when possible

    For every row a standalone driver script is rendered from a template
    (size checker + reference code + input generator + sampling loop), written
    to a temporary directory and executed in a separate Python subprocess with
    a timeout. The driver prints the collected pairs as JSON between
    "[JSON IOS START]" and "[JSON IOS END]" markers, which are parsed from the
    subprocess's stdout.
    
    Args:
        df: DataFrame with context, reference_code, and input_generator columns
        max_rows: Maximum number of rows to process (None for all rows)
        num_examples: Number of I/O examples to generate per problem (default: 5)
        timeout_seconds: Maximum execution time per problem in seconds
        parallel: Whether to process several rows concurrently
        n_processes: Number of rows handled concurrently if parallel=True. Each
            row already runs in its own subprocess, so the supervising workers
            are threads.
    
    Returns:
        A copy of the (possibly truncated) DataFrame with an added io_pairs
        column: a list of {"input": ..., "output": ...} dicts with at most
        num_examples entries, or None when nothing could be generated.
    """
    # Create a copy to avoid modifying the original DataFrame
    result_df = df.copy()
    
    # Limit rows if specified
    if max_rows is not None:
        result_df = result_df.iloc[:max_rows].copy()
    
    # Initialize new column
    result_df['io_pairs'] = None
    
    # Create temp directory for script execution
    temp_dir = os.path.join(os.getcwd(), "temp_scripts")
    os.makedirs(temp_dir, exist_ok=True)
    
    print(f"Generating exactly {num_examples} I/O pairs for {len(result_df)} rows...")
    
    # Define the template for the script - NOTE: doubled curly braces to escape them
    script_template = """
import json
import math
import random
import re
import collections
import itertools
import functools
import operator
import copy
import bisect
import sys
import nltk
from typing import List, Dict, Set, Tuple, Optional, Union, Any

# For size checking without pympler dependency
def strict_check_size(obj):
    # Check for dict type
    if isinstance(obj, dict):
        if len(obj) >= 20:  # Check dict has fewer than 20 key-value pairs
            return False
        # Recursively check keys and values
        for k, v in obj.items():
            if not strict_check_size(k) or not strict_check_size(v):
                return False

    # Check for list, tuple, or set
    elif isinstance(obj, (list, tuple, set)):
        if len(obj) >= 20:  # Check if the length is less than 20
            return False
        # Recursively check each element
        for item in obj:
            if not strict_check_size(item):
                return False

    # Check for string
    elif isinstance(obj, str):
        if len(obj) >= 100:  # Check if string length is less than 100 characters
            return False

    # If all checks are passed, return True
    return True

# Print debug info
print("DEBUG: Loading reference code and input generator...")

# Reference code (solution implementation)
{reference_code}

# Input generator
{input_generator}

print("DEBUG: Looking for main function...")

# Find the main function
main_fn = None
callable_funcs = []

for name in list(globals().keys()):
    if callable(globals()[name]) and name not in ['strict_check_size', 'input_generator']:
        callable_funcs.append(name)
        if name == 'main_solution':
            main_fn = globals()['main_solution']
            print(f"DEBUG: Found main_solution function")
            break
        elif 'main' in name.lower():
            main_fn = globals()[name]
            print(f"DEBUG: Found function with 'main' in name: {{name}}")
            break

if not main_fn and callable_funcs:
    # Try any other callable function that's not input_generator
    main_fn = globals()[callable_funcs[0]]
    print(f"DEBUG: Using alternative function: {{callable_funcs[0]}}")

if not main_fn:
    print("DEBUG: Cannot find main function. Available functions:")
    for name in callable_funcs:
        print(f"DEBUG: - {{name}}")
    print("DEBUG: ABORTING")
    sys.exit(1)

# Generate I/O pairs
diff_inputs = []
corr_outputs = []

print("DEBUG: Generating I/O pairs...")

# Maximum attempts to get the required number of examples
max_attempts = 15
attempts = 0

while len(diff_inputs) < {num_examples} and attempts < max_attempts:
    attempts += 1
    try:
        # Generate candidate input
        cand_input = input_generator()
        
        # Ensure inputs are unique and not too large
        if cand_input not in diff_inputs and strict_check_size(cand_input):
            try:
                # Call the function with the input
                print(f"DEBUG: Calling main function with input {{attempts}} ({{len(diff_inputs)}}/{{{num_examples}}} examples generated)")
                cand_output = main_fn(**cand_input)
                
                # Check if output is valid and not too large
                if strict_check_size(cand_output) and cand_output is not None:
                    diff_inputs.append(cand_input)
                    corr_outputs.append(cand_output)
                    print(f"DEBUG: Successfully generated example {{len(diff_inputs)}}/{{{num_examples}}}")
            except Exception as e:
                print(f"DEBUG: Error calling main function: {{str(e)}}")
                continue
            
    except Exception as e:
        print(f"DEBUG: Error generating input: {{str(e)}}")
        continue

print(f"DEBUG: Generated {{len(diff_inputs)}} I/O pairs after {{attempts}} attempts")

# Prepare the output - ensure we have exactly num_examples or fewer if we couldn't generate enough
if len(diff_inputs) > 0:
    assert len(diff_inputs) == len(corr_outputs)
    # Limit to exactly num_examples if we have more (shouldn't happen with the loop above, but just in case)
    if len(diff_inputs) > {num_examples}:
        diff_inputs = diff_inputs[:{num_examples}]
        corr_outputs = corr_outputs[:{num_examples}]
    
    iolist = [{{"input": diff_inputs[i], "output": corr_outputs[i]}} for i in range(len(diff_inputs))]
    # Print the result with markers for extraction
    print("[JSON IOS START]" + json.dumps(iolist) + "[JSON IOS END]")
else:
    print("DEBUG: Failed to generate any valid I/O pairs")
    print("[JSON IOS START][]" + "[JSON IOS END]")
"""
    
    def process_row(idx, row):
        """Render, run and parse the driver script for one row; return its I/O pairs or None."""
        reference_code = row.get('reference_code')
        input_generator = row.get('input_generator')

        # Missing cells arrive as None/NaN (not str); blank strings are just as
        # unusable, so all of them are treated as "missing".
        if not (isinstance(reference_code, str) and reference_code.strip()
                and isinstance(input_generator, str) and input_generator.strip()):
            print(f"[Row {idx}] Missing reference_code or input_generator")
            return None
        
        # Create the script content
        try:
            script_content = script_template.format(
                reference_code=reference_code,
                input_generator=input_generator,
                num_examples=num_examples
            )
        except Exception as e:
            print(f"[Row {idx}] Error formatting script: {str(e)}")
            return None
        
        # Write the script to a temporary file. The row label makes the file name
        # unique, so concurrent workers can safely share one directory.
        script_path = os.path.join(temp_dir, f"script_{idx}.py")
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(script_content)
        
        try:
            # Execute the script as a separate process with timeout
            # Start the subprocess in a new session (process group) so that the
            # whole group can be terminated if the script spawns children.
            print(f"[Row {idx}] Executing script...")
            if os.name == 'nt':  # Windows
                process = subprocess.Popen(
                    [sys.executable, script_path],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
                )
            else:  # Unix/Linux/Mac
                process = subprocess.Popen(
                    [sys.executable, script_path],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    start_new_session=True
                )
            
            try:
                stdout, stderr = process.communicate(timeout=timeout_seconds)
            except subprocess.TimeoutExpired:
                # Kill the process if it times out
                print(f"[Row {idx}] Process timed out after {timeout_seconds} seconds")
                if os.name == 'nt':  # Windows
                    subprocess.call(['taskkill', '/F', '/T', '/PID', str(process.pid)])
                else:  # Unix/Linux/Mac
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                    except ProcessLookupError:
                        # The process exited on its own between the timeout and the kill.
                        pass
                # Reap the child and drain its pipes so no zombie or open pipe is left behind.
                process.communicate()
                return None
            
            # Extract I/O pairs from the output
            start_marker = "[JSON IOS START]"
            end_marker = "[JSON IOS END]"
            
            if start_marker in stdout and end_marker in stdout:
                start_index = stdout.index(start_marker) + len(start_marker)
                end_index = stdout.index(end_marker)
                json_str = stdout[start_index:end_index].strip()
                
                # Parse the JSON data
                try:
                    io_pairs = json.loads(json_str)
                    if io_pairs:
                        print(f"[Row {idx}] Successfully generated {len(io_pairs)} I/O pairs")
                        return io_pairs
                    else:
                        print(f"[Row {idx}] Generated empty IO pairs list")
                        return None
                except json.JSONDecodeError:
                    print(f"[Row {idx}] JSON decode error")
                    print(f"[Row {idx}] JSON string: {json_str[:100]}...")
                    return None
            else:
                print(f"[Row {idx}] No JSON markers found in output")
                print(f"[Row {idx}] STDOUT: {stdout[:500]}...")
                print(f"[Row {idx}] STDERR: {stderr[:500]}...")
                return None
            
        except Exception as e:
            print(f"[Row {idx}] Error: {str(e)}")
            print(f"[Row {idx}] Traceback: {traceback.format_exc()}")
            return None
        finally:
            # Clean up the temporary script file (best effort)
            try:
                os.remove(script_path)
            except OSError:
                pass
    
    # Process each row
    if parallel and n_processes > 1:
        # Concurrent processing. process_row is a closure, which multiprocessing
        # cannot pickle and send to worker processes; since the real work already
        # happens in a subprocess per row, threads are sufficient to supervise them.
        from concurrent.futures import ThreadPoolExecutor

        rows = list(result_df.iterrows())
        with ThreadPoolExecutor(max_workers=n_processes) as executor:
            # executor.map yields results in submission order as they become
            # available, so the progress bar advances while rows finish.
            results = list(tqdm(
                executor.map(lambda pair: process_row(*pair), rows),
                total=len(rows)
            ))
        
        # Update the DataFrame with results, addressing rows by label exactly
        # like the sequential branch does
        for (idx, _), io_pairs in zip(rows, results):
            result_df.at[idx, 'io_pairs'] = io_pairs
    else:
        # Sequential processing
        for idx, row in tqdm(result_df.iterrows(), total=len(result_df)):
            io_pairs = process_row(idx, row)
            result_df.at[idx, 'io_pairs'] = io_pairs
    
    # Clean up the temporary directory (best effort)
    shutil.rmtree(temp_dir, ignore_errors=True)
    
    return result_df

def add_robust_io_pairs(df, num_examples=5, max_rows=None, timeout_seconds=60, parallel=False, n_processes=4):
    """
    Convenience entry point that forwards every option to generate_io_pairs_robust.

    Args:
        df: DataFrame with context, reference_code, and input_generator columns
        num_examples: How many I/O examples to request for each row (default: 5)
        max_rows: Upper bound on the number of rows to process (None means no limit)
        timeout_seconds: Per-row execution time budget, in seconds
        parallel: True to process rows with a multiprocessing pool
        n_processes: Pool size used when parallel is True

    Returns:
        The DataFrame returned by generate_io_pairs_robust, which carries the
        generated examples in its io_pairs column
    """
    # Collect the options once so the delegation below stays a single call.
    generation_options = {
        'max_rows': max_rows,
        'num_examples': num_examples,
        'timeout_seconds': timeout_seconds,
        'parallel': parallel,
        'n_processes': n_processes,
    }
    augmented_df = generate_io_pairs_robust(df, **generation_options)
    return augmented_df

In [27]:
# Notebook cell: generate I/O pairs for the first MAX_ROWS problems, write the
# result to CSV, and print one row's pairs as a quick sanity check.

# Define your parameters
NUM_EXAMPLES = 5  # Number of I/O pairs to generate per problem
MAX_ROWS = 10  # Start with a small number for testing (None for all rows)
TIMEOUT = 120  # Maximum seconds per problem <- 30 had 2 hours of runtime for 3000 rows
USE_PARALLEL = False  # Set to True for faster processing if you have multiple cores
N_PROCESSES = 4  # Number of parallel processes if USE_PARALLEL=True

# Add robust I/O pairs to the DataFrame
df_with_io = add_robust_io_pairs(
    df_with_generators,
    num_examples=NUM_EXAMPLES,
    max_rows=MAX_ROWS,
    timeout_seconds=TIMEOUT,
    parallel=USE_PARALLEL,
    n_processes=N_PROCESSES,
)

# Save to jsonl file
# output_jsonl = '../generated_data/augmented_data.jsonl'
# df_with_io.to_json(output_jsonl, orient='records', lines=True, force_ascii=False)
# print(f"Processed {len(df_with_io)} rows. Output written to {output_jsonl}")

# Save to csv file
output_csv = '../generated_data/augmented_data_10.csv'
df_with_io.to_csv(output_csv, index=False, encoding='utf-8')
print(f"Processed {len(df_with_io)} rows. Output written to {output_csv}")

# Show a sample of the results.
# The sample index is a *position* (clamped to the last row), so it is read
# with .iloc rather than the label-based .at: that keeps the preview working
# when the frame's index is not 0..n-1. The emptiness guard avoids looking up
# position -1 when no rows were produced.
if len(df_with_io) > 0:
    sample_idx = min(5, len(df_with_io) - 1)  # Get a valid position
    sample_io_pairs = df_with_io['io_pairs'].iloc[sample_idx]
    if sample_io_pairs:
        print("\nSample I/O pairs for row", sample_idx)
        print(json.dumps(sample_io_pairs, indent=2))

        print("\nSample I/O prompt:")

Generating exactly 5 I/O pairs for 10 rows...


  0%|          | 0/10 [00:00<?, ?it/s]

[Row 0] Executing script...


 10%|█         | 1/10 [00:03<00:30,  3.38s/it]

[Row 0] Successfully generated 5 I/O pairs
[Row 1] Executing script...


 20%|██        | 2/10 [00:06<00:26,  3.25s/it]

[Row 1] Successfully generated 5 I/O pairs
[Row 2] Executing script...


 30%|███       | 3/10 [00:09<00:22,  3.16s/it]

[Row 2] Successfully generated 5 I/O pairs
[Row 3] Executing script...


 40%|████      | 4/10 [00:12<00:18,  3.07s/it]

[Row 3] Successfully generated 4 I/O pairs
[Row 4] Executing script...


 50%|█████     | 5/10 [00:15<00:15,  3.10s/it]

[Row 4] Successfully generated 5 I/O pairs
[Row 5] Executing script...


 60%|██████    | 6/10 [00:18<00:12,  3.09s/it]

[Row 5] Successfully generated 5 I/O pairs
[Row 6] Executing script...


 70%|███████   | 7/10 [00:21<00:09,  3.08s/it]

[Row 6] Successfully generated 5 I/O pairs
[Row 7] Executing script...


 80%|████████  | 8/10 [00:24<00:06,  3.07s/it]

[Row 7] Successfully generated 5 I/O pairs
[Row 8] Executing script...


 90%|█████████ | 9/10 [00:27<00:03,  3.08s/it]

[Row 8] Successfully generated 5 I/O pairs
[Row 9] Executing script...


100%|██████████| 10/10 [00:31<00:00,  3.12s/it]

[Row 9] Successfully generated 5 I/O pairs
Processed 10 rows. Output written to ../generated_data/augmented_data_10.csv

Sample I/O pairs for row 5
[
  {
    "input": {
      "location": [
        39,
        44
      ],
      "initialFoodLevel": 896,
      "food_locations": [
        [
          6,
          62
        ]
      ]
    },
    "output": {
      "food_level": 896,
      "ant_count": 0,
      "state": "HEALTHY"
    }
  },
  {
    "input": {
      "location": [
        41,
        6
      ],
      "initialFoodLevel": 155,
      "food_locations": [
        [
          57,
          47
        ],
        [
          38,
          92
        ],
        [
          76,
          25
        ],
        [
          64,
          75
        ],
        [
          24,
          16
        ],
        [
          95,
          26
        ],
        [
          85,
          78
        ]
      ]
    },
    "output": {
      "food_level": 155,
      "ant_count": 0,
      "state": "HEALTH




In [None]:


# def limit_io_pairs(df, max_examples=5):
#     """
#     Limit the number of I/O pairs in each row to max_examples.
    
#     Args:
#         df: DataFrame with an 'io_pairs' column
#         max_examples: Maximum number of examples to keep
        
#     Returns:
#         DataFrame with limited io_pairs
#     """
#     # Create a copy to avoid modifying the original DataFrame
#     result_df = df.copy()
    
#     # Function to limit the examples in a single row
#     def limit_examples(io_pairs):
#         if io_pairs is None:
#             return None
        
#         # If it's a list of dictionaries with 'input' and 'output' keys
#         if isinstance(io_pairs, list):
#             return io_pairs[:max_examples]
            
#         return io_pairs
    
#     # Apply the function to each row
#     result_df['io_pairs'] = result_df['io_pairs'].apply(limit_examples)
    
#     return result_df

# # Limit the number of I/O pairs to 5 per row
# df_with_limited_io = limit_io_pairs(df_with_io, max_examples=5)

# # Save to CSV
# output_csv = '../processed_data/df_with_limited_io_pairs.csv'
# df_with_limited_io.to_csv(output_csv, index=False)
# print(f"Processed {len(df_with_limited_io)} rows. Output written to {output_csv}")

# # Print example to verify
# if len(df_with_limited_io) > 0:
#     sample_idx = 0
#     sample_row = df_with_limited_io.iloc[sample_idx]
#     if sample_row['io_pairs'] is not None:
#         print(f"\nSample row {sample_idx} has {len(sample_row['io_pairs'])} I/O pairs:")
#         import json
#         print(json.dumps(sample_row['io_pairs'], indent=2)[:500] + "...")

Processed 2775 rows. Output written to ../processed_data/df_with_limited_io_pairs.csv

Sample row 0 has 5 I/O pairs:
[
  {
    "input": {
      "nums": [
        7,
        -9,
        6
      ]
    },
    "output": [
      [
        6,
        -9,
        7
      ],
      [
        -9,
        6,
        7
      ],
      [
        7,
        6,
        -9
      ],
      [
        6,
        7,
        -9
      ],
      [
        -9,
        7,
        6
      ],
      [
        7,
        -9,
        6
      ]
    ]
  },
  {
    "input": {
      "nums": [
        -3,
        -10
      ]
    },
    "output": [...
