In [2]:
import ast
import importlib.util
import inspect
import sys
from pathlib import Path

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

def find_project_root(start=None, marker=".git"):
    """
    Walk upwards from ``start`` to find the project root, identified by ``marker``.

    Parameters
    ----------
    start : str or Path, optional
        Directory to start searching from. Defaults to the current working directory.
    marker : str, optional
        Name of the entry that marks the project root (default ``".git"``). Any kind
        of filesystem entry counts: in git worktrees and submodules ``.git`` is a
        file rather than a folder.

    Returns
    -------
    Path
        The first directory, starting at ``start`` and moving towards the filesystem
        root (inclusive), that contains ``marker``.

    Raises
    ------
    FileNotFoundError
        If neither ``start`` nor any of its ancestors contains ``marker``.
    """
    start_path = Path.cwd() if start is None else Path(start).resolve()
    # `parents` ends at the filesystem root, so the root itself is checked too.
    for candidate in (start_path, *start_path.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError("Could not find project root. Is this a git repository?")


# Resolve the repository root once; every path used by this notebook hangs off it.
PROJECT_ROOT = find_project_root()
BASE_OUTPUT_DIR = PROJECT_ROOT.joinpath("data", "code_gen_outputs_cleaned")
print(f"Project root found: {PROJECT_ROOT}")
print(f"Base output directory set to: {BASE_OUTPUT_DIR}")

# Provider -> list of model identifiers whose generated solutions are analysed.
MODEL_DICT = {
    "anthropic": ["claude-3-5-haiku-20241022"],
    "openai": ["gpt-4.1-mini"],
    "google": [
        "gemini-2.0-flash-thinking-exp",
        "gemini-2.5-flash-lite-preview-06-17",
        "gemini-2.5-flash",
    ],
}

# Flatten into "<provider>_<model>" names, preserving provider and model order.
MODELS = [
    "_".join((provider_name, model_name))
    for provider_name, model_names in MODEL_DICT.items()
    for model_name in model_names
]
print(f"Available models: {MODELS}")

Project root found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Base output directory set to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned
Available models: ['anthropic_claude-3-5-haiku-20241022', 'openai_gpt-4.1-mini', 'google_gemini-2.0-flash-thinking-exp', 'google_gemini-2.5-flash-lite-preview-06-17', 'google_gemini-2.5-flash']


In [None]:
import ast
import importlib.util
import pandas as pd

def eval_default_expr(expr_node, namespace=None):
    """
    Evaluate a default-argument AST node to obtain its value.

    Tries ``ast.literal_eval`` first, which handles plain literals (numbers,
    strings, tuples, ...) without executing code. If that fails, the expression is
    compiled and evaluated, which covers arithmetic such as ``1/2`` or ``2*3`` and,
    when ``namespace`` is given, names defined there (e.g. ``math.pi`` after
    ``import math`` in the module the node came from).

    Parameters
    ----------
    expr_node : ast.expr
        The AST node representing the default value expression.
    namespace : dict, optional
        Globals used to resolve names during the fallback evaluation, e.g.
        ``vars(module)`` for the module the node was parsed from. Defaults to this
        notebook's globals.

    Returns
    -------
    Any
        The evaluated value of the expression.

    Notes
    -----
    The fallback uses ``eval``; only pass nodes from code you are prepared to run.
    """
    try:
        return ast.literal_eval(expr_node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a plain literal; fall through to full evaluation below.
        pass
    # Nodes built by hand (rather than by ast.parse) lack lineno/col_offset,
    # which compile() requires.
    expression = ast.fix_missing_locations(ast.Expression(body=expr_node))
    code = compile(expression, filename="<ast>", mode="eval")
    # Evaluate against an explicit globals mapping so this function's own locals
    # never leak into name resolution.
    return eval(code, globals() if namespace is None else namespace)

def _annotation_to_str(annotation):
    """Render an argument's annotation node as source text, or "unknown" if absent."""
    # ast.unparse renders every annotation form (Name, Subscript, Attribute such as
    # `decimal.Decimal`, string annotations, ...) uniformly as readable text.
    return "unknown" if annotation is None else ast.unparse(annotation)


def _find_solve_node(tree):
    """Return the first `def solve` node reached by ast.walk over `tree`, or None."""
    return next(
        (node for node in ast.walk(tree)
         if isinstance(node, ast.FunctionDef) and node.name == "solve"),
        None,
    )


def _capture_solve_locals(solve_func, args=(), kwargs=None):
    """
    Call ``solve_func(*args, **kwargs)`` and capture its local variables at return.

    A profile hook copies ``f_locals`` on the "return" event of solve's own frame.
    Helper functions called by ``solve`` also emit "return" events, so frames are
    filtered by code object; otherwise helpers' locals would be reported as
    solve's variables. If ``solve`` recurses, the outermost call returns last, so
    its locals are the ones kept.

    Returns
    -------
    tuple[dict, Any]
        The captured locals and the value ``solve_func`` returned.
    """
    captured = {}
    # Follow functools.wraps chains so a decorated solve is traced at its body.
    target_code = inspect.unwrap(solve_func).__code__

    def profiler(frame, event, _arg):
        if event == "return" and frame.f_code is target_code:
            captured.clear()
            captured.update(frame.f_locals)

    previous_profiler = sys.getprofile()
    sys.setprofile(profiler)
    try:
        returned = solve_func(*args, **(kwargs or {}))
    finally:
        # Restore whatever profiler was active before (e.g. a debugger's).
        sys.setprofile(previous_profiler)
    return captured, returned


def _extract_solve_vars(py_file_path, model, verbose):
    """
    Import one generated file and describe the variables of its `solve` function.

    Returns a dict mapping variable name -> {"role", "type", "value"}; empty when
    the file contains no `def solve`. Errors from importing, parsing or running
    the file propagate to the caller.
    """
    spec = importlib.util.spec_from_file_location("module.name", py_file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if verbose:
        print(f"Processing ast for model {model}")
    # Parse raw bytes so PEP 263 coding declarations are honoured, exactly as the
    # import above decoded the same file (no dependence on the locale encoding).
    solve_node = _find_solve_node(ast.parse(py_file_path.read_bytes()))
    if solve_node is None:
        if verbose:
            print(f"No solve() function found in {py_file_path}, skipping.")
        return {}

    solve_func = getattr(module, "solve")
    # Read default values from the live function: Python already evaluated them in
    # the generated module's own namespace, so defaults referencing its imports or
    # constants (e.g. `math.pi`) resolve correctly.
    parameters = inspect.signature(solve_func).parameters

    result = {}
    call_args, call_kwargs = [], {}
    arg_nodes = solve_node.args
    for arg in arg_nodes.posonlyargs + arg_nodes.args + arg_nodes.kwonlyargs:
        parameter = parameters[arg.arg]
        if parameter.default is inspect.Parameter.empty:
            raise ValueError("All arguments must have default values.")
        result[arg.arg] = {
            "role": "argument",
            "type": _annotation_to_str(arg.annotation),
            "value": parameter.default,
        }
        if parameter.kind is inspect.Parameter.KEYWORD_ONLY:
            call_kwargs[arg.arg] = parameter.default
        else:
            call_args.append(parameter.default)

    solve_locals, _ = _capture_solve_locals(solve_func, call_args, call_kwargs)
    for var, val in solve_locals.items():
        if var not in result:  # arguments were already recorded above
            result[var] = {
                "role": "final" if var == "answer" else "intermediate",
                "type": type(val).__name__,
                "value": val,
            }
    return result


def extract_vars_to_df(
        index: int,
        models: list,
        output_dir=BASE_OUTPUT_DIR,
        verbose=True,
        save=True):
    """
    Extract argument and variable information from the 'solve' function in generated Python files for multiple models.

    For each model, imports ``<output_dir>/<index>/<model>.py``, locates its
    ``solve`` function and collects:
      - every parameter (positional-only, regular and keyword-only): its name, its
        annotation rendered as source text ("unknown" if unannotated) and its
        default value;
      - every other local variable of ``solve`` at the moment it returns, with the
        role 'final' for ``answer`` and 'intermediate' otherwise, and the name of
        the value's runtime type.

    Files that are missing or fail to import/parse/run are skipped (with a message
    when ``verbose``).

    Parameters
    ----------
    index : int
        The problem index to process.
    models : list
        List of model names to process for the given index.
    output_dir : str or Path, optional
        The base directory containing the generated Python files (default: BASE_OUTPUT_DIR).
    verbose : bool, optional
        If True, prints progress and error messages (default: True).
    save : bool, optional
        If True and at least one row was extracted, saves the DataFrame as
        ``<output_dir>/<index>/<index>_variables.csv`` (default: True).

    Returns
    -------
    pd.DataFrame
        Columns: index, model, variable, role, type, value — one row per variable
        per successfully processed model.

    Raises
    ------
    ValueError
        Never for a single model: a ``solve`` parameter without a default raises
        ValueError internally, which is caught and reported like any other
        per-model error.
    """
    columns = ["index", "model", "variable", "role", "type", "value"]
    # All model files for one problem live in the same directory.
    problem_dir = Path(output_dir) / f"{index}"
    rows = []
    for model in models:
        py_file_path = problem_dir / f"{model}.py"
        try:
            solve_vars = _extract_solve_vars(py_file_path, model, verbose)
        except FileNotFoundError:
            if verbose:
                print(f"File not found: {py_file_path}, skipping.")
            continue
        except Exception as e:
            # Generated code is untrusted in quality; one broken file must not stop the batch.
            if verbose:
                print(f"Error processing {py_file_path}: {e}, skipping.")
            continue
        rows.extend(
            {"index": index, "model": model, "variable": var, **info}
            for var, info in solve_vars.items()
        )
    df = pd.DataFrame(rows, columns=columns)

    # save if required
    if save and not df.empty:
        df.to_csv(problem_dir / f"{index}_variables.csv", index=False)
    return df

In [29]:
# Number of problems to process; problem `index` is read from BASE_OUTPUT_DIR / f"{index}".
N_PROBLEMS = 100

vars_df_dict = {}
for index in range(N_PROBLEMS):
    print(f"Extracting variables for problem {index}")
    vars_df_dict[index] = extract_vars_to_df(index, MODELS)

Extracting variables for problem 0
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Extracting variables for problem 1
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Extracting variables for problem 2
Processing ast for model anthropic_claude-3-5-haiku-20241022
Processing ast for model openai_gpt-4.1-mini
Processing ast for model google_gemini-2.0-flash-thinking-exp
Processing ast for model google_gemini-2.5-flash-lite-preview-06-17
Processing ast for model google_gemini-2.5-flash
Extracting variables for problem 3
Proce

In [31]:
display(vars_df_dict[25])  # Displaying the DataFrame for problem index 25 (0-based, i.e. the 26th problem)

Unnamed: 0,index,model,variable,role,type,value
0,25,anthropic_claude-3-5-haiku-20241022,total_tennis_balls,argument,int,175.0
1,25,anthropic_claude-3-5-haiku-20241022,first_batch_balls,argument,int,100.0
2,25,anthropic_claude-3-5-haiku-20241022,first_batch_hit_fraction,argument,float,0.4
3,25,anthropic_claude-3-5-haiku-20241022,next_batch_balls,argument,int,75.0
4,25,anthropic_claude-3-5-haiku-20241022,next_batch_hit_fraction,argument,float,0.333333
5,25,anthropic_claude-3-5-haiku-20241022,first_batch_missed_balls,intermediate,float,60.0
6,25,anthropic_claude-3-5-haiku-20241022,next_batch_missed_balls,intermediate,float,50.0
7,25,anthropic_claude-3-5-haiku-20241022,total_missed_balls,intermediate,float,110.0
8,25,anthropic_claude-3-5-haiku-20241022,answer,final,float,110.0
9,25,openai_gpt-4.1-mini,total_balls,argument,int,175.0
