In [67]:
import ast
import sys
from pathlib import Path
import shutil

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

def find_project_root():
    """Locate the project root by walking upwards from the current directory.

    Starting at ``Path.cwd()``, every directory up to and including the
    filesystem root is checked for a ``.git`` entry; the first directory that
    has one is returned.

    Returns
    -------
    Path
        The closest ancestor directory (or the working directory itself) that
        contains a ``.git`` entry.

    Raises
    ------
    FileNotFoundError
        If no directory on the path to the filesystem root contains ``.git``.
    """
    start = Path.cwd()
    # `parents` ends with the filesystem root, so the root itself is checked too.
    for candidate in (start, *start.parents):
        # `.git` is a directory in a regular clone but a plain file in git
        # worktrees and submodules, so test for existence rather than is_dir().
        if (candidate / ".git").exists():
            return candidate
    raise FileNotFoundError("Could not find project root. Is this a git repository?")


# Resolve the repository root once and derive the input/output data
# directories used by the formatting step from it.
PROJECT_ROOT = find_project_root()
BASE_INPUT_DIR = PROJECT_ROOT.joinpath('data', 'code_gen_outputs_cleaned')
BASE_OUTPUT_DIR = PROJECT_ROOT.joinpath('data', 'code_gen_outputs_formatted')

# Echo the resolved locations so the cell output documents where files go.
print(f"Project root found: {PROJECT_ROOT}")
print(f"Base input directory set to: {BASE_INPUT_DIR}")
print(f"Base output directory set to: {BASE_OUTPUT_DIR}")

# Provider name -> model identifiers available for that provider.
MODEL_DICT = dict(
    anthropic=["claude-3-5-haiku-20241022"],
    openai=["gpt-4.1-mini"],
    google=[
        "gemini-2.0-flash-thinking-exp",
        "gemini-2.5-flash-lite-preview-06-17",
        "gemini-2.5-flash",
    ],
)

# Flattened "<provider>_<model>" names, in MODEL_DICT order; these are used as
# the per-model file stems throughout the notebook.
MODELS = [
    "_".join((provider, model))
    for provider in MODEL_DICT
    for model in MODEL_DICT[provider]
]
print(f"Available models: {MODELS}")

# Problem indices processed by the batch helpers.
INDICES = [*range(100)]

Project root found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Base input directory set to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned
Base output directory set to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted
Available models: ['anthropic_claude-3-5-haiku-20241022', 'openai_gpt-4.1-mini', 'google_gemini-2.0-flash-thinking-exp', 'google_gemini-2.5-flash-lite-preview-06-17', 'google_gemini-2.5-flash']


In [68]:
import re

def get_code_lines_dict(
        problem_index: int,
        model_name: str,
        base_input_dir: Path = BASE_INPUT_DIR):
    """
    Read ``{base_input_dir}/{problem_index}/{model_name}.py`` and return its
    lines keyed by 0-based line number, with the trailing newline of each line
    removed.

    If the file does not exist, an error message is printed to stderr and
    None is returned.
    """
    import sys

    source_path = base_input_dir / str(problem_index) / f"{model_name}.py"
    if not source_path.exists():
        print(f"[Error] File not found: {source_path}", file=sys.stderr)
        return None

    with open(source_path, "r", encoding="utf-8") as handle:
        return dict(enumerate(raw.rstrip("\n") for raw in handle))


def replace_final_answer_comment(code_lines_dict):
    """
    Insert a '#: FA' marker line in front of every line whose text, after
    removing leading whitespace, starts with 'answer ='.

    The input is read by the keys 0..len-1; the result is a new dict that is
    renumbered from 0 so the inserted marker lines get their own keys.
    """
    marker = "#: FA"
    expanded = []
    for position in range(len(code_lines_dict)):
        text = code_lines_dict[position]
        if text.lstrip().startswith("answer ="):
            expanded.append(marker)
        expanded.append(text)
    return dict(enumerate(expanded))

def format_comment(comment_line):
    """
    Normalise a comment line to its canonical marker.

    After stripping surrounding whitespace:
      * a comment starting with '#:' and ending with 'FA' becomes '#: FA';
      * a comment of the form '#: ... Ln' (n = digits) becomes '#: Ln',
        dropping everything between the prefix and the line reference;
      * anything else yields a single newline character.
    """
    text = comment_line.strip()
    if text[:2] == "#:" and text[-2:] == "FA":
        return "#: FA"
    found = re.match(r"^#:(.*)L(\d+)\s*$", text)
    return "\n" if found is None else f"#: L{found.group(2)}"

def remove_trailing_comment(code_line):
    """
    Remove a trailing '#' comment from a single line of code.

    The line is scanned left to right while tracking whether the scan is
    inside a single- or double-quoted string literal (honouring backslash
    escapes), so a '#' that is part of a string is kept. At the first '#'
    outside a string, everything from that character on is dropped and the
    remaining code is returned with trailing whitespace removed.

    If the line contains no comment it is returned unchanged.
    """
    open_quote = None  # quote character of the literal being scanned, if any
    escaped = False    # True when the previous character was a backslash in a string
    for position, char in enumerate(code_line):
        if open_quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == open_quote:
                open_quote = None
        elif char in "'\"":
            open_quote = char
        elif char == "#":
            return code_line[:position].rstrip()
    return code_line

def remove_redundant_comments(lines):
    """
    Drop '#:' comment lines that are made redundant by the line after them.

    A line starting with '#:' is removed when the next line
      * also starts with '#:', or
      * starts with 'answer' (ignoring leading whitespace) while the comment
        itself is not '#: FA'.
    All other lines, including a comment on the last line, are kept in order.
    """
    kept = []
    total = len(lines)
    for position, current in enumerate(lines):
        following = lines[position + 1] if position + 1 < total else None
        if current.startswith("#:") and following is not None:
            if following.startswith("#:"):
                continue
            if current.strip() != "#: FA" and following.lstrip().startswith("answer"):
                continue
        kept.append(current)
    return kept

def _docstring_end_index(lines):
    """Return the index of the line closing the first triple-double-quoted docstring (0 if none)."""
    delimiter = '"""'
    for start, text in enumerate(lines):
        occurrences = text.count(delimiter)
        if occurrences == 0:
            continue
        if occurrences >= 2:
            # Opened and closed on the same line.
            return start
        # Only the opening delimiter is on this line: the docstring ends at the
        # next line that contains the delimiter again.
        for later in range(start + 1, len(lines)):
            if delimiter in lines[later]:
                return later
        # Unterminated docstring: treat the opening line as the end.
        return start
    return 0


def format_generated_code(code_lines_dict):
    """
    Clean and re-format one generated solution and return it as a single string.

    Steps:
      1. Insert a '#: FA' marker before every 'answer =' line.
      2. Split the code into the signature part (everything up to and including
         the line that closes the docstring) and the body.
      3. In the body, normalise comment lines with ``format_comment``, strip
         trailing comments from code lines, and drop blank lines.
      4. Remove redundant comment markers with ``remove_redundant_comments``.
      5. Re-assemble: the signature unchanged, then each body line indented by
         four spaces, with an empty line before every '#:' marker.

    Parameters
    ----------
    code_lines_dict : dict
        Mapping of 0-based line number to line text (as produced by
        ``get_code_lines_dict``).

    Returns
    -------
    str
        The formatted code, lines joined with a newline and without trailing
        blank lines.
    """
    # Format the final answer portion
    code_lines_dict = replace_final_answer_comment(code_lines_dict)
    lines = list(code_lines_dict.values())

    # Split at the end of the signature/docstring. If no docstring is found the
    # first line alone is treated as the signature.
    sig_end = _docstring_end_index(lines)
    signature = lines[:sig_end + 1]
    body = lines[sig_end + 1:]

    # Normalise comments, strip trailing comments, and drop anything that ends
    # up blank (blank source lines and comments that are not '#:' markers).
    processed = []
    for line in body:
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.startswith("#"):
            cleaned = format_comment(stripped)
        else:
            cleaned = remove_trailing_comment(line)
        if cleaned.strip():
            processed.append(cleaned)

    processed = remove_redundant_comments(processed)

    # Assemble the final code.
    # NOTE(review): every body line is re-indented to exactly four spaces, so
    # nested blocks would be flattened — presumably the generated bodies are
    # straight-line code; confirm against the inputs.
    final_lines = list(signature)
    indent = "    "
    for line in processed:
        if line.startswith("#:"):
            final_lines.append("")  # blank line before comment
            final_lines.append(f"{indent}{line}")
        else:
            final_lines.append(f"{indent}{line.strip()}")

    # Remove any extra blank lines at the end
    while final_lines and final_lines[-1].strip() == "":
        final_lines.pop()

    return "\n".join(final_lines)

def batch_format_and_write(
    indices,
    models = MODELS,
    base_input_dir = BASE_INPUT_DIR,
    base_output_dir = BASE_OUTPUT_DIR,
    get_code_lines_dict=get_code_lines_dict,
    format_generated_code=format_generated_code
):
    """
    Format the code of every (index, model) pair and write each result to
    {base_output_dir}/{index}/{model}.py, creating directories as needed.

    Pairs for which the loader returns None are reported with a warning and
    skipped. The loader and formatter are injectable through the last two
    parameters.
    """
    for idx in indices:
        for model in models:
            source_lines = get_code_lines_dict(idx, model, base_input_dir)
            if source_lines is None:
                print(f"[Warning] No code for index={idx}, model={model}")
                continue

            formatted_text = format_generated_code(source_lines)

            out_path = base_output_dir / str(idx) / f"{model}.py"
            out_path.parent.mkdir(parents=True, exist_ok=True)
            out_path.write_text(formatted_text, encoding="utf-8")
            print(f"Wrote formatted code to {out_path}")

# ---------------------------------------------------------------------- #

def format_code_wrapper(index: int):
    """
    Print the formatted code of every model in MODELS for one problem index.

    Nothing is written to disk. Models without a source file are reported on
    stderr.
    """
    for model in MODELS:
        lines_dict = get_code_lines_dict(index, model)
        if lines_dict is None:
            print(f"[Error] No code found for {model} at index {index}", file=sys.stderr)
            continue

        cleaned_code = format_generated_code(lines_dict)
        print(
            "="*50,
            f"Cleaned code for {model}",
            "-"*50,
            cleaned_code,
            sep="\n",
        )

In [69]:
# Format the cleaned code for every index in INDICES and every model in MODELS
# (the function defaults), writing the results under BASE_OUTPUT_DIR.
batch_format_and_write(INDICES)

Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/0/anthropic_claude-3-5-haiku-20241022.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/0/openai_gpt-4.1-mini.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/0/google_gemini-2.0-flash-thinking-exp.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/0/google_gemini-2.5-flash-lite-preview-06-17.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/0/google_gemini-2.5-flash.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/1/anthropic_claude-3-5-haiku-20241022.py
Wrote formatted code to /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/cod

[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/52/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/59/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/71/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/76/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/82/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/85/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25

In [70]:
def execution_trace(func):
    """
    Statically replay the top-level assignments of ``func`` and return the
    resulting variable values.

    The source of ``func`` is parsed; each argument that has a default is bound
    to its evaluated default, and then every top-level assignment statement in
    the function body (plain, augmented such as ``x += 1``, or annotated with a
    value) is executed in order. Other statements (docstring, ``return``,
    control flow, ...) are skipped, so values assigned inside nested blocks are
    not traced.

    Expressions are evaluated with a copy of the function's own module globals,
    so names imported in the function's module (e.g. ``math``) resolve.

    Parameters
    ----------
    func : function
        A function whose source can be retrieved with ``inspect.getsource``.

    Returns
    -------
    dict
        Mapping of variable name to value, including the defaulted arguments.

    Notes
    -----
    This executes code taken from ``func``'s source; only use it on code you
    are prepared to run.
    """
    import ast
    import inspect
    import textwrap

    # Dedent so that functions defined in an indented context still parse.
    src = textwrap.dedent(inspect.getsource(func))
    func_def = ast.parse(src).body[0]

    # Work on a copy so that the function's module is never modified.
    scope = dict(getattr(func, "__globals__", {}))

    env = {}
    # Defaults align with the *last* arguments, hence both lists are reversed.
    for arg, default in zip(
        func_def.args.args[::-1],
        func_def.args.defaults[::-1]
    ):
        env[arg.arg] = eval(compile(ast.Expression(default), '', 'eval'), scope)

    for stmt in func_def.body:
        if isinstance(stmt, ast.AnnAssign):
            if stmt.value is None:
                continue  # bare annotation, nothing is assigned
            # Rewrite `x: T = v` as `x = v` so no `__annotations__` entry is
            # created in the returned mapping.
            stmt = ast.copy_location(
                ast.Assign(targets=[stmt.target], value=stmt.value), stmt
            )
        elif not isinstance(stmt, (ast.Assign, ast.AugAssign)):
            continue
        code = compile(ast.Module(body=[stmt], type_ignores=[]), '', 'exec')
        exec(code, scope, env)

    # Return all variables, including arguments
    return env

def augment_code_with_trace(
    problem_index: int,
    model_name: str,
    input_dir: Path,
    output_dir: Path,
    verbose: bool = True
):
    """
    Augment a single Python file with execution-trace comments.

    The file ``{input_dir}/{problem_index}/{model_name}.py`` is loaded as a
    module, the variable values of its ``solve`` function are obtained with
    ``execution_trace``, and every body line that mentions a traced variable
    gets a trailing ``# eval: ...`` comment in which each variable name is
    replaced by its value. Comment lines are kept (preceded by a blank line),
    ``answer =`` lines are kept without an eval comment, and blank lines are
    dropped. The result is written to
    ``{output_dir}/{problem_index}/{model_name}.py``.

    Parameters
    ----------
    problem_index : int
        Index of the problem (name of the sub-directory).
    model_name : str
        File stem of the model's solution file.
    input_dir : Path
        Base directory containing the files to augment.
    output_dir : Path
        Base directory the augmented files are written to.
    verbose : bool
        If True, print progress and error messages.

    Returns
    -------
    bool
        True if the augmented file was written, False on any failure.

    Notes
    -----
    This imports and executes the input file; only run it on code you are
    prepared to execute.
    """
    # Imported locally so the function does not depend on an import that the
    # notebook only performs in a later cell.
    import importlib.util
    import re

    # Step 1: Get the individual lines of code
    code_lines_dict = get_code_lines_dict(problem_index, model_name, input_dir)
    if code_lines_dict is None:
        if verbose:
            print(f"Could not get code lines for problem {problem_index}, model {model_name}")
        return False

    try:
        # Load the module and get execution trace
        input_file = input_dir / str(problem_index) / f"{model_name}.py"
        spec = importlib.util.spec_from_file_location("module.name", input_file)
        if spec is None or spec.loader is None:
            if verbose:
                print(f"Could not load module spec for {input_file}")
            return False

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Get execution trace
        trace_dict = execution_trace(module.solve)

        # Step 2: Find the end of function signature (including docstring).
        # A docstring may open and close on one line, or open on one line and
        # close on a later one; in the latter case the closing line is the end.
        lines = list(code_lines_dict.values())
        sig_end = -1
        for i, line in enumerate(lines):
            delimiter_count = line.count('"""')
            if delimiter_count == 0:
                continue
            if delimiter_count >= 2:
                sig_end = i
            else:
                sig_end = next(
                    (j for j in range(i + 1, len(lines)) if '"""' in lines[j]),
                    i,  # unterminated docstring: fall back to the opening line
                )
            break

        if sig_end == -1:
            if verbose:
                print(f"Could not find end of docstring for {model_name}")
            return False

        # Split into signature and body
        signature_lines = lines[:sig_end + 1]
        body_lines = lines[sig_end + 1:]

        # Build one pattern that matches any traced variable as a whole
        # identifier. A single substitution pass avoids rewriting parts of
        # longer names and re-substituting text that was already inserted.
        variable_names = sorted(trace_dict, key=len, reverse=True)
        if variable_names:
            variable_pattern = re.compile(
                r"\b(?:" + "|".join(re.escape(name) for name in variable_names) + r")\b"
            )
        else:
            variable_pattern = None

        # Step 3-5: Process body lines
        processed_body = ""
        for line in body_lines:
            # Step 3: Add comment lines
            if line.strip().startswith('#'):
                processed_body += '\n'  # Add a blank line before comments
                processed_body += line + '\n'
                continue
            elif line.strip().startswith('answer ='):
                processed_body += line + '\n'  # Do not add eval comment for answer lines
                continue

            # Blank lines are dropped.
            if not line.strip():
                continue

            # Step 4: For code lines, replace each variable with its value
            if variable_pattern is None:
                eval_line, replacements = line, 0
            else:
                eval_line, replacements = variable_pattern.subn(
                    lambda match: str(trace_dict[match.group(0)]), line
                )

            # Step 5: Append eval comment if any substitution happened
            if replacements:
                processed_body += line + " # eval: " + eval_line.lstrip() + '\n'
            else:
                processed_body += line + '\n'

        # Step 6: Concatenate and save
        signature_code = '\n'.join(signature_lines) + '\n'
        final_code = signature_code + processed_body

        # Write to output file
        output_file = output_dir / str(problem_index) / f"{model_name}.py"
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(final_code)

        if verbose:
            print(f"Augmented {model_name} for problem {problem_index}")
        return True

    except Exception as e:
        # Best-effort batch step: report the failure and let the caller continue.
        if verbose:
            print(f"Error processing problem {problem_index}, model {model_name}: {e}")
        return False

def batch_augment_with_trace(
    indices: list,
    models: list = MODELS,
    input_dir: Path = PROJECT_ROOT / 'data' / 'code_gen_outputs_formatted',
    output_dir: Path = PROJECT_ROOT / 'data' / 'code_gen_outputs_traced',
    verbose: bool = True
):
    """
    Run ``augment_code_with_trace`` for every (index, model) combination.

    When ``verbose`` is set, a progress line is printed per problem index and
    a final summary reports how many files were augmented successfully out of
    the total attempted.
    """
    outcomes = []

    for idx in indices:
        if verbose:
            print(f"Processing problem {idx}...")
        for model in models:
            succeeded = augment_code_with_trace(idx, model, input_dir, output_dir, verbose)
            outcomes.append(bool(succeeded))

    if verbose:
        print(f"Completed: {sum(outcomes)}/{len(outcomes)} files processed successfully")

# ---------------------------------------------------------------------- #

def augment_trace_wrapper(index: int):
    """
    Display the execution trace of ``solve`` for every model at one problem index.

    For each model in MODELS the file ``BASE_OUTPUT_DIR/{index}/{model}.py`` is
    loaded as a module and the dict returned by ``execution_trace`` is shown
    with ``display``. Files that are missing or fail to load are reported and
    skipped.
    """
    # Imported locally so the function does not depend on an import that the
    # notebook only performs in a later cell; without it a fresh kernel would
    # hit a NameError that the handler below reports as a file-processing error.
    import importlib.util

    for model in MODELS:
        print(f"Execution trace for {model} with index {index}:")
        problem_dir = BASE_OUTPUT_DIR / f"{index}"
        py_file_path = problem_dir / f"{model}.py"
        try:
            spec = importlib.util.spec_from_file_location("module.name", py_file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except FileNotFoundError:
            print(f"File not found: {py_file_path}, skipping.")
            continue
        except Exception as e:
            print(f"Error processing {py_file_path}: {e}, skipping.")
            continue

        trace = execution_trace(module.solve)
        display(trace)
        print("\n" + "="*50 + "\n")

In [71]:
# Add "# eval:" trace comments for every index in INDICES, using the function's
# default directories (reads code_gen_outputs_formatted, writes code_gen_outputs_traced).
batch_augment_with_trace(INDICES)

Processing problem 0...
Augmented anthropic_claude-3-5-haiku-20241022 for problem 0
Augmented openai_gpt-4.1-mini for problem 0
Augmented google_gemini-2.0-flash-thinking-exp for problem 0
Augmented google_gemini-2.5-flash-lite-preview-06-17 for problem 0
Augmented google_gemini-2.5-flash for problem 0
Processing problem 1...
Augmented anthropic_claude-3-5-haiku-20241022 for problem 1
Augmented openai_gpt-4.1-mini for problem 1
Augmented google_gemini-2.0-flash-thinking-exp for problem 1
Augmented google_gemini-2.5-flash-lite-preview-06-17 for problem 1
Augmented google_gemini-2.5-flash for problem 1
Processing problem 2...
Augmented anthropic_claude-3-5-haiku-20241022 for problem 2
Augmented openai_gpt-4.1-mini for problem 2
Augmented google_gemini-2.0-flash-thinking-exp for problem 2
Augmented google_gemini-2.5-flash-lite-preview-06-17 for problem 2
Augmented google_gemini-2.5-flash for problem 2
Processing problem 3...
Augmented anthropic_claude-3-5-haiku-20241022 for problem 3
Augm

[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/52/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/59/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/71/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/76/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/82/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_formatted/85/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erd

In [72]:
import ast
import inspect
import importlib.util
import pandas as pd

def eval_default_expr(expr_node):
    """
    Evaluate the AST node of a default argument and return its value.

    Literal nodes (numbers, strings, containers of literals, ...) are handled
    by ``ast.literal_eval``. If that raises, the node is compiled as an
    expression and evaluated with ``eval`` instead, which covers cases such as
    arithmetic (e.g. 1/2, 2*3).

    NOTE: the fallback executes the expression, so this is only as safe as the
    code the node was parsed from.

    Parameters
    ----------
    expr_node : ast.AST
        The AST node representing the default value expression.

    Returns
    -------
    Any
        The evaluated value of the expression.
    """
    try:
        return ast.literal_eval(expr_node)
    except Exception:
        # Not a plain literal; fall through to full evaluation below.
        pass
    return eval(compile(ast.Expression(body=expr_node), "<ast>", "eval"))

def _annotation_label(annotation):
    """Render an argument annotation node as the string stored in the ``type`` column."""
    if annotation is None:
        return "unknown"
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Subscript):
        return ast.unparse(annotation)
    return str(ast.dump(annotation))


def _capture_solve_locals(solve_func, arg_values):
    """Call ``solve_func(*arg_values)`` and return its local variables as of its return."""
    captured = {}
    target_code = getattr(solve_func, "__code__", None)

    def profiler(frame_obj, event, arg):
        # Only record the frame of `solve_func` itself; other Python functions
        # returning while it runs must not leak their locals into the result.
        if event == "return" and (target_code is None or frame_obj.f_code is target_code):
            captured.update(frame_obj.f_locals)

    previous_profiler = sys.getprofile()
    sys.setprofile(profiler)
    try:
        solve_func(*arg_values)
    finally:
        # Restore whatever profiler was active before.
        sys.setprofile(previous_profiler)
    return captured


def extract_vars_to_df(
        index: int,
        models: list,
        output_dir=BASE_OUTPUT_DIR,
        verbose=True,
        save=True):
    """
    Extracts argument and variable information from the 'solve' function in generated Python files for multiple models.

    For each model, the file ``{output_dir}/{index}/{model}.py`` is loaded as a
    module and parsed, and the following is collected for its 'solve' function:
      - All argument names, types (from type hints, 'unknown' if absent), and
        evaluated default values. Every argument must have a default.
      - All other local variables of 'solve' at the moment it returns, obtained
        by calling it with the default values.
      - The role of each variable ('argument', 'intermediate', or 'final';
        'final' is the variable named 'answer').

    Returns a pandas DataFrame with columns: index, model, variable, role, type, value.
    Optionally saves the DataFrame as ``{index}_variables.csv`` inside the
    problem directory. Models whose file is missing or fails to process are
    skipped (and reported when ``verbose`` is set).

    Parameters
    ----------
    index : int
        The problem index to process.
    models : list
        List of model names to process for the given index.
    output_dir : Path, optional
        The base directory containing the generated Python files (default: BASE_OUTPUT_DIR).
    verbose : bool, optional
        If True, prints progress and error messages (default: True).
    save : bool, optional
        If True, saves the resulting DataFrame as a CSV file when it is not empty (default: True).

    Returns
    -------
    pd.DataFrame
        DataFrame containing variable information for all models for the given index.

    Notes
    -----
    This imports and executes the generated files; only run it on code you are
    prepared to execute.
    """
    rows = []
    # Bound before the loop so the save step never depends on the loop having run.
    problem_dir = output_dir / f"{index}"
    for model in models:
        py_file_path = problem_dir / f"{model}.py"
        try:
            spec = importlib.util.spec_from_file_location("module.name", py_file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if verbose:
                print(f"Processing ast for model {model}")
            # Explicit encoding: the formatting step writes these files as UTF-8.
            with open(py_file_path, "r", encoding="utf-8") as f:
                source = f.read()
            tree = ast.parse(source)

            result = {}

            for node in ast.walk(tree):
                if not (isinstance(node, ast.FunctionDef) and node.name == "solve"):
                    continue

                # Arguments
                args = node.args.args
                defaults = node.args.defaults
                if len(defaults) != len(args):
                    raise ValueError("All arguments must have default values.")
                for arg, default in zip(args, defaults):
                    result[arg.arg] = {
                        "role": "argument",
                        "type": _annotation_label(arg.annotation),
                        "value": eval_default_expr(default)
                    }

                # Intermediate/final variables
                solve_func = getattr(module, "solve")
                arg_values = [result[arg.arg]["value"] for arg in args]
                locals_dict = _capture_solve_locals(solve_func, arg_values)
                for var, val in locals_dict.items():
                    if var not in result:
                        role = "final" if var == "answer" else "intermediate"
                        result[var] = {
                            "role": role,
                            "type": type(val).__name__,
                            "value": val
                        }
                # Only the first 'solve' definition found is processed.
                break

            # Add to rows
            for var, info in result.items():
                rows.append({
                    "index": index,
                    "model": model,
                    "variable": var,
                    "role": info["role"],
                    "type": info["type"],
                    "value": info["value"]
                })
        except FileNotFoundError:
            if verbose:
                print(f"File not found: {py_file_path}, skipping.")
            continue
        except Exception as e:
            # Best-effort batch step: report the failure and move on.
            if verbose:
                print(f"Error processing {py_file_path}: {e}, skipping.")
            continue
    df = pd.DataFrame(rows, columns=["index", "model", "variable", "role", "type", "value"])

    # save if required
    if save and not df.empty:
        save_path = problem_dir / f"{index}_variables.csv"
        df.to_csv(save_path, index=False)
    return df

def batch_extract_vars_to_df(
    indices: list,
    models: list = MODELS,
    output_dir=BASE_OUTPUT_DIR,
    verbose=True,
    save=True
):
    """
    Call ``extract_vars_to_df`` once per problem index, forwarding the
    remaining options unchanged.

    Returns a dict that maps each index to the DataFrame produced for it.
    """
    shared_options = dict(
        models=models,
        output_dir=output_dir,
        verbose=verbose,
        save=save,
    )
    frames_by_index = {}
    for idx in indices:
        if verbose:
            print(f"Processing index {idx}...")
        frames_by_index[idx] = extract_vars_to_df(index=idx, **shared_options)
    return frames_by_index

In [73]:
# Build the per-index variable tables for all of INDICES; with the default
# save=True a CSV is also written for every index that yields rows.
vars_df_dict = batch_extract_vars_to_df(INDICES)

Processing index 0...
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Processing index 1...
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Processing index 2...
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Processing index 3...
Processing ast for model anthropic_claude-3-5-haiku-20241