In [1]:
import ast
import sys
from pathlib import Path
import shutil

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

def find_project_root():
    """Traverse upwards to find the project root, marked by the .git folder."""
    current_path = Path.cwd()
    while current_path != current_path.parent:
        if (current_path / ".git").is_dir():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError("Could not find project root. Is this a git repository?")


PROJECT_ROOT = find_project_root()
BASE_OUTPUT_DIR = PROJECT_ROOT / 'data' / 'code_gen_outputs_cleaned'
print(f"Project root found: {PROJECT_ROOT}")
print(f"Base output directory set to: {BASE_OUTPUT_DIR}")

MODEL_DICT = {
  "anthropic": ["claude-3-5-haiku-20241022"], 
  "openai": ["gpt-4.1-mini"],
  "google": ["gemini-2.0-flash-thinking-exp", 
             "gemini-2.5-flash-lite-preview-06-17",
             "gemini-2.5-flash"]
}

MODELS = [f"{provider}_{model}" for provider, sublist in MODEL_DICT.items() for model in sublist]
print(f"Available models: {MODELS}")


# ---------------------------------------------------------------------- #
#  Reusable Function
# ---------------------------------------------------------------------- #

def generate_ast_for_model( 
    problem_index: int, 
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
) -> ast.AST | None:
    """
    Locates the Python file for a given model and problem, and returns its AST.

    The Abstract Syntax Tree (AST) is a structured, hierarchical representation
    of the code that can be programmatically analyzed and modified.

    Args:
        base_output_dir: The root directory for the cleaned code outputs
                         (e.g., '.../data/code_gen_outputs_cleaned').
        problem_index: The integer index of the GSM8K problem.
        model_name: The name of the model, which corresponds to the file stem
                    (e.g., 'openai_gpt-4.1-mini').

    Returns:
        The root ast.AST node for the file if successful, otherwise None.
    """
    problem_dir = base_output_dir / str(problem_index)
    file_path = problem_dir / f"{model_name}.py"

    if not file_path.exists():
        print(f"[AST Generator Error] File not found: {file_path}", file=sys.stderr)
        return None

    try:
        source_code = file_path.read_text(encoding="utf-8")
        tree = ast.parse(source_code)
        return tree
    except SyntaxError as e:
        print(f"[AST Generator Error] Syntax error in {file_path}: {e}", file=sys.stderr)
        return None
    except Exception as e:
        print(
            f"[AST Generator Error] Unexpected error processing {file_path}: {e!r}",
            file=sys.stderr,
        )
        return None

Project root found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Base output directory set to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned
Available models: ['anthropic_claude-3-5-haiku-20241022', 'openai_gpt-4.1-mini', 'google_gemini-2.0-flash-thinking-exp', 'google_gemini-2.5-flash-lite-preview-06-17', 'google_gemini-2.5-flash']


In [7]:
ast_dict = {}

for index in range(100):
    index_dict = {}
    for model in MODELS:
        ast_tree = generate_ast_for_model(index, model)
        if ast_tree is not None:
            index_dict[model] = ast_tree
        else:
            index_dict[model] = None
    ast_dict[index] = index_dict
    print(f'Finished processing problem {index}.')

Finished processing problem 0.
Finished processing problem 1.
Finished processing problem 2.
Finished processing problem 3.
Finished processing problem 4.
Finished processing problem 5.
Finished processing problem 6.
Finished processing problem 7.
Finished processing problem 8.
Finished processing problem 9.
Finished processing problem 10.
Finished processing problem 11.
Finished processing problem 12.
Finished processing problem 13.
Finished processing problem 14.
Finished processing problem 15.
Finished processing problem 16.
Finished processing problem 17.
Finished processing problem 18.
Finished processing problem 19.
Finished processing problem 20.
Finished processing problem 21.
Finished processing problem 22.
Finished processing problem 23.
Finished processing problem 24.
Finished processing problem 25.
Finished processing problem 26.
Finished processing problem 27.
Finished processing problem 28.
Finished processing problem 29.
Finished processing problem 30.
Finished processin

[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/52/anthropic_claude-3-5-haiku-20241022.py
[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/59/anthropic_claude-3-5-haiku-20241022.py
[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/71/anthropic_claude-3-5-haiku-20241022.py
[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/76/anthropic_claude-3-5-haiku-20241022.py
[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/82/anthropic_claude-3-5-haiku-20241022.py
[AST Generator Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/85/anthropic_claude-3-5-haiku-2024

In [9]:
import pickle
from pathlib import Path

def save_ast_dict(data: dict, filepath: Path) -> None:
    """
    Saves a dictionary of AST objects to a file using pickle.

    Args:
        data: The dictionary to save.
        filepath: The path to the output file.
    """
    try:
        # Create parent directories if they don't exist
        filepath.parent.mkdir(parents=True, exist_ok=True)
        
        # Open the file in binary write mode ('wb')
        with open(filepath, "wb") as f:
            pickle.dump(data, f)
        print(f"Successfully saved AST dictionary to: {filepath}")
    except Exception as e:
        print(f"Error saving file: {e!r}", file=sys.stderr)


def load_ast_dict(filepath: Path) -> dict | None:
    """
    Loads a dictionary of AST objects from a pickle file.

    Args:
        filepath: The path to the input file.

    Returns:
        The loaded dictionary, or None if an error occurs.
    """
    if not filepath.exists():
        print(f"Error: File not found at {filepath}", file=sys.stderr)
        return None
    try:
        # Open the file in binary read mode ('rb')
        with open(filepath, "rb") as f:
            data = pickle.load(f)
        print(f"Successfully loaded AST dictionary from: {filepath}")
        return data
    except Exception as e:
        print(f"Error loading file: {e!r}", file=sys.stderr)
        return None

In [10]:
# --- Example Usage ---
# Define a path within your project for the serialized data
ast_data_path = PROJECT_ROOT / "data" / "ast_data" / "ast_dict.pkl"

# Save the dictionary you created in the previous cell
save_ast_dict(ast_dict, ast_data_path)

# Load it back to confirm the process works
loaded_ast_dict = load_ast_dict(ast_data_path)

# Check if the loaded data is the same
if loaded_ast_dict is not None and list(loaded_ast_dict.keys()) == list(ast_dict.keys()):
    print("\nVerification successful: Loaded dictionary keys match original.")

Successfully saved AST dictionary to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/ast_data/ast_dict.pkl
Successfully loaded AST dictionary from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/ast_data/ast_dict.pkl

Verification successful: Loaded dictionary keys match original.


In [2]:
# --- Select a specific AST to inspect ---
# Let's choose a problem and model that we know has data.
# Replace with any index/model you want to explore.
TARGET_INDEX = 52
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

if loaded_ast_dict:
    # Safely get the AST, handling cases where it might not exist
    tree_to_explore = loaded_ast_dict.get(TARGET_INDEX, {}).get(TARGET_MODEL)
else:
    tree_to_explore = None
    print("AST dictionary not loaded. Please run the previous cell.", file=sys.stderr)

if tree_to_explore:
    print(f"--- Exploring AST for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

    # 1. The most direct method: ast.dump()
    # This gives a complete, unambiguous representation of the tree's structure.
    # It's dense, but it's the ground truth of what's in the tree.
    print("\n1. Full Tree Structure using ast.dump():")
    # The `indent` argument makes the output readable
    print(ast.dump(tree_to_explore, indent=4))

    # 2. Programmatic Traversal: ast.walk()
    # This function lets you iterate over every single node in the tree.
    # It's how you would programmatically find specific elements.
    print("\n2. Finding all Binary Operations using ast.walk():")
    binary_ops_found = 0
    for node in ast.walk(tree_to_explore):
        if isinstance(node, ast.BinOp):
            binary_ops_found += 1
            # ast.dump() is useful for inspecting even small sub-nodes
            print(f"  - Found a BinOp node: {ast.dump(node)}")
    if not binary_ops_found:
        print("  - No binary operations found in this tree.")

    # 3. Converting back to source code: ast.unparse()
    # This demonstrates that the AST contains all the information needed
    # to reconstruct the original code.
    print("\n3. Reconstructed Source Code using ast.unparse():")
    reconstructed_code = ast.unparse(tree_to_explore)
    print(reconstructed_code)

    # 4. (Optional) For better visualization, install and use `astpretty`
    # It provides a cleaner, more readable dump than the default.
    # You may need to install it first by running:
    # !pip install astpretty
    try:
        import astpretty
        print("\n4. Visualizing with `astpretty` (more readable dump):")
        astpretty.pprint(tree_to_explore)
    except ImportError:
        print("\n4. `astpretty` not installed. Skipping. (Run '!pip install astpretty' to use)")
else:
    print(f"Could not find a valid AST for index {TARGET_INDEX}, model '{TARGET_MODEL}'.")

NameError: name 'loaded_ast_dict' is not defined

In [None]:
# import re
# from pathlib import Path
# import sys
# from dataclasses import dataclass
# import libcst as cst
# from libcst.metadata import PositionProvider, MetadataWrapper

# # ---------------------------------------------------------------------- #
# #  Define a custom data structure to hold the parsed line information
# # ---------------------------------------------------------------------- #

# @dataclass
# class CSTStatement:
#     """
#     A structured representation of a single logical line of code from the CST.
    
#     This object links a statement node with its associated comment and the
#     extracted solution step (e.g., 'L1').
#     """
#     line_number: int         # Physical line number in the source file
#     solution_step: str | None # The extracted step, like 'L1', 'L2', etc.
#     code_str: str            # The string representation of the code on the line
#     comment_str: str | None  # The string representation of the comment
#     cst_node: cst.CSTNode    # The actual libcst node for potential modification

# # ---------------------------------------------------------------------- #
# #  Reusable Function using libcst
# # ---------------------------------------------------------------------- #

# # Regex to find the solution step comment, e.g., "#: L1"
# _TRACE_RE = re.compile(r"^#:\s*L(\d+)")

# def generate_cst_data_for_model(
#     problem_index: int,
#     model_name: str,
#     base_output_dir: Path = BASE_OUTPUT_DIR,
# ) -> list[CSTStatement] | None:
#     """
#     Parses a model's Python file using libcst to preserve comments and
#     returns a list of structured statement objects.

#     Args:
#         problem_index: The integer index of the GSM8K problem.
#         model_name: The name of the model corresponding to the file stem.
#         base_output_dir: The root directory for the cleaned code outputs.

#     Returns:
#         A list of CSTStatement objects, one for each relevant line in the
#         'solve' function's body, or None if an error occurs.
#     """
#     print(f"[DEBUG] Entered generate_cst_data_for_model with problem_index={problem_index}, model_name={model_name}")
#     problem_dir = base_output_dir / str(problem_index)
#     file_path = problem_dir / f"{model_name}.py"
#     print(f"[DEBUG] Looking for file at: {file_path}")

#     if not file_path.exists():
#         print(f"[CST Parser Error] File not found: {file_path}", file=sys.stderr)
#         return None

#     try:
#         source_code = file_path.read_text(encoding="utf-8")
#         print(f"[DEBUG] Successfully read source code ({len(source_code)} chars)")
#         module_cst = cst.parse_module(source_code)
#         print(f"[DEBUG] Parsed module_cst")
#         wrapper = MetadataWrapper(module_cst)
#         print(f"[DEBUG] Created MetadataWrapper")
#         # Resolve the position metadata for all nodes in the module once
#         position_map = wrapper.resolve(PositionProvider)
#         print(f"[DEBUG] Resolved position_map with {len(position_map)} entries")
#         print("[DEBUG] position_map keys (types):", set(type(k) for k in position_map.keys()))
#         print("[DEBUG] position_map keys (repr):", [repr(k) for k in list(position_map.keys())[:10]])
#     except Exception as e:
#         print(f"[CST Parser Error] Failed to parse {file_path}: {e!r}", file=sys.stderr)
#         return None

#     statements = []

#     # Iterate through the top-level nodes to find the 'solve' function
#     found_solve = False
#     for node in module_cst.body:
#         print(f"[DEBUG] Top-level node: {type(node)}")
#         if isinstance(node, cst.FunctionDef) and node.name.value == "solve":
#             print(f"[DEBUG] Found 'solve' function")
#             found_solve = True
#             function_body = node.body
#             if not isinstance(function_body, cst.IndentedBlock):
#                 print(f"[DEBUG] Function body is not an IndentedBlock, skipping")
#                 continue
#             # Iterate through each statement in the function body
#             for i, stmt_line in enumerate(function_body.body):
#                 print(f"[DEBUG] Statement {i}: {type(stmt_line)}")
#                 if isinstance(stmt_line, cst.SimpleStatementLine):
#                     # --- Extract comment from leading_lines or trailing_whitespace ---
#                     comment_text = None
#                     # Check leading_lines for a comment (step comments like #: L1)
#                     for leading in stmt_line.leading_lines:
#                         if leading.comment is not None:
#                             comment_text = leading.comment.value.strip()
#                             break
#                     # If not found in leading_lines, check trailing_whitespace
#                     if comment_text is None and stmt_line.trailing_whitespace.comment:
#                         comment_text = stmt_line.trailing_whitespace.comment.value.strip()
#                     # --- Extract solution step if present ---
#                     solution_step = None
#                     if comment_text:
#                         match = _TRACE_RE.match(comment_text)
#                         if match:
#                             solution_step = f"L{match.group(1)}"
#                     if stmt_line.body:
#                         code_str = module_cst.code_for_node(stmt_line.body[0])
#                         print(f"[DEBUG] Checking stmt_line {i}: {repr(stmt_line)}")
#                         found = False
#                         for k in position_map:
#                             if isinstance(k, cst.SimpleStatementLine) and k.deep_equals(stmt_line):
#                                 pos = position_map[k]
#                                 found = True
#                                 break
#                         if not found:
#                             print(f"[DEBUG] Statement {i} not in position_map (by deep_equals), skipping")
#                             continue
#                         print(f"[DEBUG] Appending CSTStatement for line {pos.start.line}, step={solution_step}, code={code_str}")
#                         statements.append(CSTStatement(
#                             line_number=pos.start.line,
#                             solution_step=solution_step,
#                             code_str=code_str,
#                             comment_str=comment_text,
#                             cst_node=stmt_line.body[0]
#                         ))
#             break
#     if not found_solve:
#         print(f"[DEBUG] Did not find 'solve' function in file.")
#     print(f"[DEBUG] Returning {len(statements)} statements")
#     return statements

In [16]:
# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Select a target to explore. Let's use the same one as before.
TARGET_INDEX = 8
TARGET_MODEL = "google_gemini-2.5-flash-lite-preview-06-17"

print(f"--- Parsing with libcst for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

# Call the new function
cst_data = generate_cst_data_for_model(TARGET_INDEX, TARGET_MODEL)

if cst_data:
    print("Successfully parsed file and extracted statements with comments.\n")
    print("--- Extracted Statement Data ---")
    for statement in cst_data:
        print(
            f"Line {statement.line_number:2d}: "
            f"Step='{statement.solution_step or 'N/A'}', "
            f"Code='{statement.code_str}', "
            f"Comment='{statement.comment_str or ''}'"
        )
    
    # You can also inspect the cst_node directly for a specific step
    print("\n--- Inspecting the CST node for step 'L2' ---")
    step_l2_node = next((s.cst_node for s in cst_data if s.solution_step == 'L2'), None)
    
    if step_l2_node:
        # For CST nodes, `step_l2_node.visit(YourVisitor())` is how you'd traverse
        # or `step_l2_node.with_changes(...)` is how you'd modify it.
        # For now, just printing the type and code is illustrative.
        print(f"Node Type: {type(step_l2_node)}")
        print(f"Code for node: {cst.parse_module('').code_for_node(step_l2_node)}")
else:
    print("Failed to retrieve CST data.")

--- Parsing with libcst for Problem 8, Model 'google_gemini-2.5-flash-lite-preview-06-17' ---
[DEBUG] Entered generate_cst_data_for_model with problem_index=8, model_name=google_gemini-2.5-flash-lite-preview-06-17
[DEBUG] Looking for file at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/8/google_gemini-2.5-flash-lite-preview-06-17.py
[DEBUG] Successfully read source code (787 chars)
[DEBUG] Parsed module_cst
[DEBUG] Created MetadataWrapper
[DEBUG] Resolved position_map with 270 entries
[DEBUG] position_map keys (types): {<class 'libcst._nodes.expression.Name'>, <class 'libcst._nodes.whitespace.Newline'>, <class 'libcst._nodes.expression.Param'>, <class 'libcst._nodes.op.Comma'>, <class 'libcst._nodes.statement.Expr'>, <class 'libcst._nodes.statement.Assign'>, <class 'libcst._nodes.whitespace.SimpleWhitespace'>, <class 'libcst._nodes.whitespace.ParenthesizedWhitespace'>, <class 'libcst._nodes.whitespace.TrailingWhitespace'>, <class 'libcst._nod

In [34]:
import re
from pathlib import Path
import sys
from dataclasses import dataclass
import libcst as cst
from libcst.metadata import PositionProvider, MetadataWrapper

# ---------------------------------------------------------------------- #
#  Define a custom data structure to hold the parsed line information
# ---------------------------------------------------------------------- #

@dataclass
class CSTStatement:
    """
    A structured representation of a single logical line of code from the CST.
    
    This object links a statement node with its associated comment and the
    extracted solution step (e.g., 'L1').
    """
    line_number: int         # Physical line number in the source file
    solution_step: str | None # The extracted step, like 'L1', 'L2', etc.
    code_str: str            # The string representation of the code on the line
    comment_str: str | None  # The string representation of the comment
    cst_node: cst.CSTNode    # The actual libcst node for potential modification

# ---------------------------------------------------------------------- #
#  Reusable Function using libcst
# ---------------------------------------------------------------------- #

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"^#:\s*L(\d+)")

def generate_cst_data_for_model(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
) -> list[CSTStatement] | None:
    """
    Parses a model's Python file using libcst to preserve comments and
    returns a list of structured statement objects.

    Args:
        problem_index: The integer index of the GSM8K problem.
        model_name: The name of the model corresponding to the file stem.
        base_output_dir: The root directory for the cleaned code outputs.

    Returns:
        A list of CSTStatement objects, one for each relevant line in the
        'solve' function's body, or None if an error occurs.
    """
    print(f"[DEBUG] Entered generate_cst_data_for_model with problem_index={problem_index}, model_name={model_name}")
    problem_dir = base_output_dir / str(problem_index)
    file_path = problem_dir / f"{model_name}.py"
    print(f"[DEBUG] Looking for file at: {file_path}")

    if not file_path.exists():
        print(f"[CST Parser Error] File not found: {file_path}", file=sys.stderr)
        return None

    try:
        source_code = file_path.read_text(encoding="utf-8")
        print(f"[DEBUG] Successfully read source code ({len(source_code)} chars)")
        module_cst = cst.parse_module(source_code)
        print(f"[DEBUG] Parsed module_cst")
        wrapper = MetadataWrapper(module_cst)
        print(f"[DEBUG] Created MetadataWrapper")
        position_map = wrapper.resolve(PositionProvider)
        print(f"[DEBUG] Resolved position_map with {len(position_map)} entries")
    except Exception as e:
        print(f"[CST Parser Error] Failed to parse {file_path}: {e!r}", file=sys.stderr)
        return None

    statements = []
    found_solve = False
    pending_step_comment = None

    for node in module_cst.body:
        print(f"[DEBUG] Top-level node: {type(node)}")
        if isinstance(node, cst.FunctionDef) and node.name.value == "solve":
            print(f"[DEBUG] Found 'solve' function")
            found_solve = True
            function_body = node.body
            if not isinstance(function_body, cst.IndentedBlock):
                print(f"[DEBUG] Function body is not an IndentedBlock, skipping")
                continue
            for i, stmt_line in enumerate(function_body.body):
                print(f"[DEBUG] Statement {i}:")
                if isinstance(stmt_line, cst.SimpleStatementLine):
                    # --- Gather all step comments in leading_lines ---
                    for leading in stmt_line.leading_lines:
                        if leading.comment is not None:
                            comment_val = leading.comment.value.strip()
                            match = _TRACE_RE.match(comment_val)
                            if match:
                                # If it's a step comment, store it for the next statement
                                pending_step_comment = comment_val
                            else:
                                # If it's a regular comment, you may want to handle it as well
                                pass  # (optional: collect or ignore)
                    # --- Extract trailing comment (for non-step comments) ---
                    comment_text = None
                    if stmt_line.trailing_whitespace.comment:
                        comment_text = stmt_line.trailing_whitespace.comment.value.strip()
                    # --- Extract solution step if pending ---
                    solution_step = None
                    if pending_step_comment:
                        match = _TRACE_RE.match(pending_step_comment)
                        if match:
                            solution_step = f"L{match.group(1)}"
                    if stmt_line.body:
                        code_str = module_cst.code_for_node(stmt_line.body[0])
                        print(f"[DEBUG] Checking stmt_line {i}")
                        found = False
                        for k in position_map:
                            if isinstance(k, cst.SimpleStatementLine) and k.deep_equals(stmt_line):
                                pos = position_map[k]
                                found = True
                                break
                        if not found:
                            print(f"[DEBUG] Statement {i} not in position_map (by deep_equals), skipping")
                            continue
                        print(f"[DEBUG] Appending CSTStatement for line {pos.start.line}, step={solution_step}, code={code_str}")
                        statements.append(CSTStatement(
                            line_number=pos.start.line,
                            solution_step=solution_step,
                            code_str=code_str,
                            comment_str=comment_text if comment_text else pending_step_comment,
                            cst_node=stmt_line.body[0]
                        ))
                    # Clear the pending step comment after using it
                    pending_step_comment = None
            break
    if not found_solve:
        print(f"[DEBUG] Did not find 'solve' function in file.")
    print(f"[DEBUG] Returning {len(statements)} statements")
    return statements

In [35]:
def batch_generate_cst_data(
    target_indices: list[int],
    target_models: list[str],
    base_output_dir: Path = BASE_OUTPUT_DIR,
) -> dict[int, dict[str, list[CSTStatement] | None]]:
    """
    Runs generate_cst_data_for_model for each (index, model) pair and returns a nested dictionary.

    Args:
        target_indices: List of problem indices to process.
        target_models: List of model names to process.
        base_output_dir: The root directory for the cleaned code outputs.

    Returns:
        A dictionary of the form {index: {model: [CSTStatement, ...] or None}}
    """
    results = {}
    for idx in target_indices:
        model_results = {}
        for model in target_models:
            print(f"\n[INFO] Processing index={idx}, model={model}")
            cst_data = generate_cst_data_for_model(idx, model, base_output_dir=base_output_dir)
            model_results[model] = cst_data

            if cst_data:
                print("Successfully parsed file and extracted statements with comments.\n")
                print("--- Extracted Statement Data ---")
                for statement in cst_data:
                    print(
                        f"Line {statement.line_number:2d}: "
                        f"Step='{statement.solution_step or 'N/A'}', "
                        f"Code='{statement.code_str}', "
                        f"Comment='{statement.comment_str or ''}'"
                    )
                
                # You can also inspect the cst_node directly for a specific step
                print("\n--- Inspecting the CST node for step 'L2' ---")
                step_l2_node = next((s.cst_node for s in cst_data if s.solution_step == 'L2'), None)
                
                if step_l2_node:
                    # For CST nodes, `step_l2_node.visit(YourVisitor())` is how you'd traverse
                    # or `step_l2_node.with_changes(...)` is how you'd modify it.
                    # For now, just printing the type and code is illustrative.
                    print(f"Node Type: {type(step_l2_node)}")
                    print(f"Code for node: {cst.parse_module('').code_for_node(step_l2_node)}")
            else:
                print("Failed to retrieve CST data.")
        results[idx] = model_results
    return results

In [37]:
results = batch_generate_cst_data([10], MODELS)


[INFO] Processing index=10, model=anthropic_claude-3-5-haiku-20241022
[DEBUG] Entered generate_cst_data_for_model with problem_index=10, model_name=anthropic_claude-3-5-haiku-20241022
[DEBUG] Looking for file at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/10/anthropic_claude-3-5-haiku-20241022.py
[DEBUG] Successfully read source code (875 chars)
[DEBUG] Parsed module_cst
[DEBUG] Created MetadataWrapper
[DEBUG] Resolved position_map with 172 entries
[DEBUG] Top-level node: <class 'libcst._nodes.statement.FunctionDef'>
[DEBUG] Found 'solve' function
[DEBUG] Statement 0:
[DEBUG] Checking stmt_line 0
[DEBUG] Appending CSTStatement for line 5, step=None, code="""Index: 10.
    Returns: the number of people on the first ship the monster consumed."""
[DEBUG] Statement 1:
[DEBUG] Checking stmt_line 1
[DEBUG] Appending CSTStatement for line 19, step=L4, code=total_ship_multiplier = 1 + 2 + 4
[DEBUG] Statement 2:
[DEBUG] Checking stmt_line 2
[DEBUG] 

In [33]:
# # ---------------------------------------------------------------------- #
# #  Example Usage
# # ---------------------------------------------------------------------- #

# # Select a target to explore. Let's use the same one as before.
# TARGET_INDEX = 8
# TARGET_MODEL = "google_gemini-2.5-flash-lite-preview-06-17"

# print(f"--- Parsing with libcst for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

# # Call the new function
# cst_data = generate_cst_data_for_model(TARGET_INDEX, TARGET_MODEL)

# if cst_data:
#     print("Successfully parsed file and extracted statements with comments.\n")
#     print("--- Extracted Statement Data ---")
#     for statement in cst_data:
#         print(
#             f"Line {statement.line_number:2d}: "
#             f"Step='{statement.solution_step or 'N/A'}', "
#             f"Code='{statement.code_str}', "
#             f"Comment='{statement.comment_str or ''}'"
#         )
    
#     # You can also inspect the cst_node directly for a specific step
#     print("\n--- Inspecting the CST node for step 'L2' ---")
#     step_l2_node = next((s.cst_node for s in cst_data if s.solution_step == 'L2'), None)
    
#     if step_l2_node:
#         # For CST nodes, `step_l2_node.visit(YourVisitor())` is how you'd traverse
#         # or `step_l2_node.with_changes(...)` is how you'd modify it.
#         # For now, just printing the type and code is illustrative.
#         print(f"Node Type: {type(step_l2_node)}")
#         print(f"Code for node: {cst.parse_module('').code_for_node(step_l2_node)}")
# else:
#     print("Failed to retrieve CST data.")


[INFO] Processing index=8, model=google_gemini-2.5-flash-lite-preview-06-17
[DEBUG] Entered generate_cst_data_for_model with problem_index=8, model_name=google_gemini-2.5-flash-lite-preview-06-17
[DEBUG] Looking for file at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/8/google_gemini-2.5-flash-lite-preview-06-17.py
[DEBUG] Successfully read source code (787 chars)
[DEBUG] Parsed module_cst
[DEBUG] Created MetadataWrapper
[DEBUG] Resolved position_map with 270 entries
[DEBUG] Top-level node: <class 'libcst._nodes.statement.FunctionDef'>
[DEBUG] Found 'solve' function
[DEBUG] Statement 0:
[DEBUG] Checking stmt_line 0
[DEBUG] Appending CSTStatement for line 10, step=None, code="""Index: 8.
    Returns: the amount Alexis paid for the shoes.
    """
[DEBUG] Statement 1:
[DEBUG] Checking stmt_line 1
[DEBUG] Appending CSTStatement for line 17, step=L2, code=spent_on_clothes_without_shoes = shirt_cost + pants_cost + coat_cost + socks_cost + belt_cos

In [38]:
def get_code_lines_dict(problem_index: int, model_name: str, base_output_dir: Path = BASE_OUTPUT_DIR) -> dict[int, str] | None:
    """
    Returns a dict mapping line numbers (0-based) to the verbatim lines of code
    for the given problem index and model name.
    """
    import sys
    problem_dir = base_output_dir / str(problem_index)
    file_path = problem_dir / f"{model_name}.py"
    if not file_path.exists():
        print(f"[Error] File not found: {file_path}", file=sys.stderr)
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return {i: line.rstrip("\n") for i, line in enumerate(lines)}

In [40]:
lines_dict = get_code_lines_dict(10, "google_gemini-2.0-flash-thinking-exp")
lines_dict

{0: 'def solve(',
 1: '    every_hundred_years: int = 100, # rises from the waters once every hundred years',
 2: '    over_three_hundred_years: int = 300, # Over three hundred years',
 3: '    total_people_consumed: int = 847, # it has consumed 847 people',
 4: '    twice_as_many: int = 2 # each new ship has twice as many people as the last ship',
 5: '):',
 6: '    """Index: 10.',
 7: '    Returns: the number of people on the ship the monster ate in the first hundred years.',
 8: '    """',
 9: '    # L1: Let S be the number of people on the first hundred years’ ship. (Implicitly defined as the variable we solve for)',
 10: '    #: L2',
 11: '    # The second hundred years’ ship had twice as many as the first, so it had 2S people.',
 12: "    # This is represented by the factor 'twice_as_many'",
 13: '    #: L3',
 14: '    # The third hundred years’ ship had twice as many as the second, so it had 2 * 2S = 4S people.',
 15: '    multiplier_third_ship = twice_as_many * twice_as_many',


In [52]:
import re
from collections import defaultdict

_TRACE_RE = re.compile(r"#:\s*L(\d+)")

def normalize_code_lines(raw_lines_dict: dict[int, str]) -> dict[str, list[str]]:
    """
    Normalizes raw code lines into a canonical format. (Version 2)
    
    This version correctly separates the final 'answer' and 'return'
    statements from the logical solution steps.
    """
    normalized_dict = defaultdict(list)
    sorted_lines = sorted(raw_lines_dict.items())
    last_seen_step = None

    for i, (line_num, line_content) in enumerate(sorted_lines):
        line = line_content.strip()
        if not line:
            continue

        step_match = _TRACE_RE.match(line)
        if step_match:
            last_seen_step = f"L{step_match.group(1)}"
            continue

        if line.startswith('#'):
            continue
            
        code_part = line
        trailing_match = _TRACE_RE.search(line)
        if trailing_match:
            last_seen_step = f"L{trailing_match.group(1)}"
            code_part = line[:trailing_match.start()].strip()
        
        # --- MODIFIED LOGIC ---
        # Prioritize checking for final lines BEFORE associating with a step
        if 'answer =' in line:
            normalized_dict['answer_line'].append(code_part)
        elif line.startswith('return'):
            normalized_dict['return_line'].append(code_part)
        elif last_seen_step:
            # Only associate with a step if it's not a final line
            normalized_dict[last_seen_step].append(code_part)
            
    return dict(normalized_dict)

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #
print("--- Normalizing with v2 function ---")

# Run on the first example
raw_lines_1 = get_code_lines_dict(10, "google_gemini-2.0-flash-thinking-exp")
if raw_lines_1:
    normalized_code_1 = normalize_code_lines(raw_lines_1)
    print("\n--- Normalized Code Dictionary (Problem 10) ---")
    print(json.dumps(normalized_code_1, indent=4))

# Run on the second example
raw_lines_2 = get_code_lines_dict(8, "google_gemini-2.5-flash-lite-preview-06-17")
if raw_lines_2:
    normalized_code_2 = normalize_code_lines(raw_lines_2)
    print("\n--- Normalized Code Dictionary (Problem 8) ---")
    print(json.dumps(normalized_code_2, indent=4))

--- Normalizing with v2 function ---

--- Normalized Code Dictionary (Problem 10) ---
{
    "L3": [
        "multiplier_third_ship = twice_as_many * twice_as_many"
    ],
    "L4": [
        "total_factor = 1 + twice_as_many + multiplier_third_ship"
    ],
    "L5": [
        "people_first_ship = total_people_consumed / total_factor"
    ],
    "answer_line": [
        "answer = people_first_ship # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

--- Normalized Code Dictionary (Problem 8) ---
{
    "L2": [
        "spent_on_clothes_without_shoes = shirt_cost + pants_cost + coat_cost + socks_cost + belt_cost"
    ],
    "L3": [
        "total_spent = budget - money_left"
    ],
    "L4": [
        "shoes_cost = total_spent - spent_on_clothes_without_shoes"
    ],
    "answer_line": [
        "answer = shoes_cost # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}


In [55]:
import json

def generate_normalized_code_dict(
    indices: range, 
    models: list[str],
    base_output_dir: Path = BASE_OUTPUT_DIR
) -> dict[int, dict[str, dict[str, list[str]]]]:
    """
    Generates a nested dictionary containing the normalized code for each model
    across a range of problem indices.

    The structure of the returned dictionary is:
    {
        problem_index_1: {
            "model_name_A": { "L1": ["code"], "L2": ["code"], ... },
            "model_name_B": { "L1": ["code"], ... },
            ...
        },
        problem_index_2: { ... },
        ...
    }

    Args:
        indices: A range of problem indices to process (e.g., range(100)).
        models: A list of model name strings.
        base_output_dir: The root directory for the cleaned code outputs.

    Returns:
        A nested dictionary containing all the normalized code.
    """
    master_dict = {}

    for index in indices:
        index_dict = {}
        for model_name in models:
            # 1. Get the raw lines for the current model and index
            raw_lines = get_code_lines_dict(index, model_name, base_output_dir)
            
            if raw_lines:
                # 2. Normalize the code if raw lines were found
                normalized_code = normalize_code_lines(raw_lines)
                if normalized_code:  # Ensure the dictionary is not empty
                    index_dict[model_name] = normalized_code

        if index_dict:  # Only add the index if at least one model had valid code
            master_dict[index] = index_dict
            
    return master_dict

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Use the constants defined in your first cell
# We'll run it on a smaller range (0-20) for a quick demonstration
INDICES_TO_PROCESS = range(100) 

# Generate the master dictionary
normalized_code_master_dict = generate_normalized_code_dict(
    indices=INDICES_TO_PROCESS,
    models=MODELS
)

print(f"\nGenerated master dictionary with {len(normalized_code_master_dict)} problem indices.")

# --- Inspect the output for a specific problem ---
# Let's look at the data for problem 10 again to verify
if 10 in normalized_code_master_dict:
    print("\n--- Verifying data for Problem 10 ---")
    problem_10_data = normalized_code_master_dict[10]
    
    # Print the normalized code for one of the models
    for model in MODELS:
        if model in problem_10_data:
            print(f"\nNormalized code for '{model}':")
            print(json.dumps(problem_10_data[model], indent=4))
    else:
        print(f"No data found for model '{model}' in problem 10.")
else:
    print("\nData for problem 10 not found in the generated dictionary.")


Generated master dictionary with 100 problem indices.

--- Verifying data for Problem 10 ---

Normalized code for 'anthropic_claude-3-5-haiku-20241022':
{
    "L4": [
        "total_ship_multiplier = 1 + 2 + 4  # S + 2S + 4S"
    ],
    "L5": [
        "people_on_first_ship = total_people_consumed / total_ship_multiplier"
    ],
    "answer_line": [
        "answer = people_on_first_ship  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

Normalized code for 'openai_gpt-4.1-mini':
{
    "L4": [
        "total_ships_factor = 7  # S + 2S + 4S = 7S"
    ],
    "L5": [
        "first_ship_people = total_people / total_ships_factor"
    ],
    "answer_line": [
        "answer = first_ship_people  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

Normalized code for 'google_gemini-2.0-flash-thinking-exp':
{
    "L3": [
        "multiplier_third_ship = twice_as_many * twice_as_many"
    ],
    "L4": [
        "total_factor = 1 + twice_as_m

[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/52/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/59/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/71/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/76/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/82/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_cleaned/85/anthropic_claude-3-5-haiku-20241022.py
[Error] File not found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25

In [56]:
normalized_code_master_dict[10]

{'anthropic_claude-3-5-haiku-20241022': {'L4': ['total_ship_multiplier = 1 + 2 + 4  # S + 2S + 4S'],
  'L5': ['people_on_first_ship = total_people_consumed / total_ship_multiplier'],
  'answer_line': ['answer = people_on_first_ship  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'openai_gpt-4.1-mini': {'L4': ['total_ships_factor = 7  # S + 2S + 4S = 7S'],
  'L5': ['first_ship_people = total_people / total_ships_factor'],
  'answer_line': ['answer = first_ship_people  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.0-flash-thinking-exp': {'L3': ['multiplier_third_ship = twice_as_many * twice_as_many'],
  'L4': ['total_factor = 1 + twice_as_many + multiplier_third_ship'],
  'L5': ['people_first_ship = total_people_consumed / total_factor'],
  'answer_line': ['answer = people_first_ship # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.5-flash-lite-preview-06-17': {'L4': ['total_proportions = 1 + 2 + 4'],
  'L5': ['people_on_first_s

In [57]:
normalized_code_master_dict[14]

{'anthropic_claude-3-5-haiku-20241022': {'L1': ['sets_of_20_mins_per_hour = minutes_per_hour // 20'],
  'L2': ['pages_read_per_hour = pages_read_in_20_mins * sets_of_20_mins_per_hour'],
  'L3': ['hours_to_read = total_pages / pages_read_per_hour'],
  'answer_line': ['answer = hours_to_read  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'openai_gpt-4.1-mini': {'L1': ['sets_per_hour = 60 / 20'],
  'L2': ['pages_per_hour = pages_per_20_min * sets_per_hour'],
  'L3': ['hours_needed = total_pages / pages_per_hour'],
  'answer_line': ['answer = hours_needed  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.0-flash-thinking-exp': {'L1': ['minutes_in_hour = 60',
   'intervals_in_hour = minutes_in_hour / time_in_minutes'],
  'L2': ['pages_per_hour = pages_read_in_time * intervals_in_hour'],
  'L3': ['time_needed_hours = total_pages_to_read / pages_per_hour'],
  'answer_line': ['answer = time_needed_hours # FINAL ANSWER'],
  'return_line': ['return answer']},
 '

In [58]:
normalized_code_master_dict[16]

{'anthropic_claude-3-5-haiku-20241022': {'L1': ['part_value = johnson_amount / johnson_share_ratio'],
  'L3': ['mike_total_share = mike_share_ratio * part_value'],
  'L4': ['mike_remaining_amount = mike_total_share - shirt_cost'],
  'answer_line': ['answer = mike_remaining_amount  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'openai_gpt-4.1-mini': {'L2': ['value_per_part = johnson_share / 5'],
  'L3': ['mike_share = 2 * value_per_part'],
  'L4': ['mike_after_shirt = mike_share - shirt_cost'],
  'answer_line': ['answer = mike_after_shirt  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.0-flash-thinking-exp': {'L2': ['value_per_part = johnson_share / johnson_ratio_part'],
  'L3': ['mike_share = mike_ratio_part * value_per_part'],
  'L4': ['mike_remaining = mike_share - shirt_cost'],
  'answer_line': ['answer = mike_remaining # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.5-flash-lite-preview-06-17': {'L2': ['value_per_part = j

In [60]:
from datasets import load_dataset
# Load the GSM8K dataset
gsm8k_dataset = load_dataset("gsm8k", "main", split="train")

In [65]:
sample = gsm8k_dataset[16]
print(f"Index: 16")
print(f"Question: {sample['question']}")
print(f"Answer: {sample['answer']}")

Index: 16
Question: The profit from a business transaction is shared among 2 business partners, Mike and Johnson in the ratio 2:5 respectively. If Johnson got $2500, how much will Mike have after spending some of his share on a shirt that costs $200?
Answer: According to the ratio, for every 5 parts that Johnson gets, Mike gets 2 parts
Since Johnson got $2500, each part is therefore $2500/5 = $<<2500/5=500>>500
Mike will get 2*$500 = $<<2*500=1000>>1000
After buying the shirt he will have $1000-$200 = $<<1000-200=800>>800 left
#### 800


In [66]:
normalized_code_master_dict[25]

{'anthropic_claude-3-5-haiku-20241022': {'L1': ['first_batch_missed_balls = (1 - first_batch_hit_fraction) * first_batch_balls'],
  'L2': ['next_batch_missed_balls = (1 - next_batch_hit_fraction) * next_batch_balls'],
  'L3': ['total_missed_balls = first_batch_missed_balls + next_batch_missed_balls'],
  'answer_line': ['answer = total_missed_balls  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'openai_gpt-4.1-mini': {'L1': ['not_hit_first = (1 - hit_fraction_first) * first_batch'],
  'L2': ['not_hit_second = (1 - hit_fraction_second) * second_batch'],
  'L3': ['total_not_hit = not_hit_first + not_hit_second'],
  'answer_line': ['answer = total_not_hit  # FINAL ANSWER'],
  'return_line': ['return answer']},
 'google_gemini-2.0-flash-thinking-exp': {'L1': ['fraction_not_hit_first_batch = 1 - fraction_hit_first_batch',
   'not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size'],
  'L2': ['fraction_not_hit_second_batch = 1 - fraction_hit_second_batch',
   'not_hit

In [69]:
import ast
import json
import re
from collections import defaultdict
from pathlib import Path

# Assumes that BASE_OUTPUT_DIR, get_code_lines_dict, and normalize_code_lines
# are defined in previous cells.

# Define a new output directory for the final, fully normalized code
FINAL_NORMALIZED_DIR = PROJECT_ROOT / "data" / "code_gen_outputs_final"

# ---------------------------------------------------------------------- #
#  1. Function to generate the "Pre-Normalized" function string
# ---------------------------------------------------------------------- #

def generate_pre_normalized_function(
    problem_index: int, model_name: str
) -> tuple[str, dict] | None:
    """
    Reads a raw model output, normalizes its step-wise logic, and
    reconstructs it into a single, canonical Python function string.

    Returns:
        A tuple containing:
        - The pre-normalized function code as a single string.
        - The normalized_dict mapping steps to code lines.
        Returns None if the source file cannot be processed.
    """
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines:
        return None

    normalized_dict = normalize_code_lines(raw_lines)
    if not normalized_dict:
        return None

    # --- Reconstruct the function ---
    # a. Extract the original function signature
    signature_lines = []
    in_signature = False
    for line in raw_lines.values():
        if line.strip().startswith("def solve("):
            in_signature = True
        if in_signature:
            signature_lines.append(line)
        if in_signature and line.strip().endswith("):"):
            break
    
    # b. Reconstruct the body from the normalized dict
    body_lines = []
    # Ensure a consistent order for reconstruction
    sorted_steps = sorted([key for key in normalized_dict if key.startswith('L')])
    
    for step in sorted_steps:
        for code_line in normalized_dict[step]:
            body_lines.append(f"    {code_line}  # {step}")

    # Add the answer and return lines
    if 'answer_line' in normalized_dict:
        body_lines.append(f"    {normalized_dict['answer_line'][0]}")
    if 'return_line' in normalized_dict:
        body_lines.append(f"    {normalized_dict['return_line'][0]}")
        
    # c. Combine into a full function string
    full_code = "\n".join(signature_lines) + "\n" + "\n".join(body_lines)
    
    return full_code, normalized_dict


# ---------------------------------------------------------------------- #
#  2. AST Transformers for Final Normalization
# ---------------------------------------------------------------------- #

class ConstantFolder(ast.NodeTransformer):
    """
    An AST transformer that finds and evaluates binary operations on constants.
    Example: a tree for `3 + 5` becomes a single Constant node with value `8`.
    """
    def visit_BinOp(self, node):
        self.generic_visit(node) # Ensure children are visited first
        if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
            # Both operands are constants, we can evaluate them.
            left_val = node.left.value
            right_val = node.right.value
            op = node.op
            
            # Use a safe evaluation for the specific operation
            if isinstance(op, ast.Add): result = left_val + right_val
            elif isinstance(op, ast.Sub): result = left_val - right_val
            elif isinstance(op, ast.Mult): result = left_val * right_val
            elif isinstance(op, ast.Div): result = left_val / right_val
            # Add other operations as needed (Pow, etc.)
            else: return node # Cannot evaluate, return original node

            return ast.Constant(value=result)
        return node


def _substitute_step_vars(statements: list[ast.stmt]) -> list[ast.stmt]:
    """
    A helper function to perform intermediate variable substitution on a
    list of AST assignment statements for a single logical step.
    """
    # 1. Identify variables defined within this block
    defined_vars = {
        node.targets[0].id for node in statements if isinstance(node, ast.Assign)
    }

    # 2. Identify all variables used within this block
    used_vars = {
        node.id for stmt in statements for node in ast.walk(stmt) if isinstance(node, ast.Name)
    }

    # 3. An intermediate variable is one that is defined and used, but not used elsewhere
    #    (For simplicity here, we assume any var defined in the block is intermediate)
    substitution_map = {
        node.targets[0].id: node.value 
        for node in statements if isinstance(node, ast.Assign)
    }
    
    # Transformer to perform the substitution
    class Substituter(ast.NodeTransformer):
        def visit_Name(self, node):
            if node.id in substitution_map and isinstance(node.ctx, ast.Load):
                # Replace the variable name with its defining expression
                return substitution_map[node.id]
            return node

    # 4. Apply substitution to all statements
    transformer = Substituter()
    substituted_statements = [transformer.visit(s) for s in statements]

    # 5. Filter out the original assignments for the variables we just substituted
    final_statements = [
        s for s in substituted_statements if not (
            isinstance(s, ast.Assign) and s.targets[0].id in defined_vars and s.targets[0].id != 'answer'
        )
    ]
    # This logic is simplified; a perfect implementation would track usage counts.
    # For GSM8K, this should be sufficient.
    return final_statements if final_statements else substituted_statements

def _execute_and_trace(code_str: str, signature: dict) -> dict:
    # Prepare the namespace and define the function
    namespace = {}
    exec(code_str, namespace)
    # Call the function with the default arguments
    result = namespace['solve'](**signature)
    # Optionally, return just the result or modify the function to return locals()
    return {'result': result}

# def _execute_and_trace(code_str: str, signature: dict) -> dict:
#     """Executes a function string and returns a dict of all local variables."""
#     # We must provide the function with its default args to execute it
#     default_args = {name: val for name, val in signature.items()}
    
#     # A namespace to execute the code in and capture the results
#     namespace = {}
#     namespace.update(default_args)
    
#     # Parse the code and find the function body
#     tree = ast.parse(code_str)
#     function_body_nodes = []
#     for node in ast.walk(tree):
#         if isinstance(node, ast.FunctionDef) and node.name == 'solve':
#             function_body_nodes = node.body
#             break
            
#     # Execute each statement of the function body in the namespace
#     for stmt in function_body_nodes:
#         # Wrap statement in a module to make it compilable
#         module_wrapper = ast.Module(body=[stmt], type_ignores=[])
#         code_obj = compile(module_wrapper, '<string>', 'exec')
#         exec(code_obj, namespace)
        
#     # Return all variables defined during execution, excluding builtins
#     return {k: v for k, v in namespace.items() if not k.startswith('__')}


# ---------------------------------------------------------------------- #
#  3. Main Orchestrator Function
# ---------------------------------------------------------------------- #

def normalize_and_save_function(
    pre_normalized_code: str,
    original_signature: dict, # Needed for the trace
    problem_index: int,
    model_name: str,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Takes a pre-normalized function, applies final AST transformations,
    generates a value trace, and saves the results.
    """
    # 1. Parse the pre-normalized code
    try:
        print(repr(pre_normalized_code))
        tree = ast.parse(pre_normalized_code)
    except SyntaxError:
        return # Cannot process
        
    # 2. Apply constant folding across the entire tree
    folder = ConstantFolder()
    tree = folder.visit(tree)
    
    # 3. Apply intermediate variable substitution (This logic is complex and simplified here)
    # A full implementation would be more involved, but this demonstrates the principle.
    # For now, we will skip this complex step and focus on the structure.
    # A more robust version would apply `_substitute_step_vars` to each step's nodes.

    # 4. Unparse to get the final normalized code
    final_code_str = ast.unparse(tree)

    # 5. Generate the value_dict by executing the final code
    value_dict = _execute_and_trace(final_code_str, original_signature)
    
    # 6. Save the results
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Prepare the value_dict for pretty printing
    value_dict_str = json.dumps(value_dict, indent=4)
    
    final_output_content = (
        f"{final_code_str}\n\n"
        f"# --- EXECUTION TRACE (value_dict) ---\n"
        f"# The following is a dictionary of all variable values\n"
        f"# when the function is run with its default arguments.\n\n"
        f'"""\n{value_dict_str}\n"""\n'
    )
    
    output_path.write_text(final_output_content, encoding="utf-8")
    

# ---------------------------------------------------------------------- #
#  Example End-to-End Usage
# ---------------------------------------------------------------------- #

# Use a model that has both single and multi-line steps
TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full normalization pipeline for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

# 1. Generate the pre-normalized function
result = generate_pre_normalized_function(TARGET_INDEX, TARGET_MODEL)

if result:
    pre_norm_code, norm_dict = result
    print("\n--- Step 1: Pre-Normalized Code ---")
    print(pre_norm_code)
    
    # We need the original signature to run the trace
    from inspect import signature, Parameter
    # This is a bit of a hack to get the signature; a more robust way
    # would be to parse it properly.
    temp_namespace = {}
    exec(pre_norm_code.split('"""')[0], temp_namespace)
    sig = signature(temp_namespace['solve'])
    original_signature_defaults = {
        p.name: p.default for p in sig.parameters.values() 
        if p.default is not Parameter.empty
    }

    # 2. Run the final normalization and save
    normalize_and_save_function(
        pre_normalized_code=pre_norm_code,
        original_signature=original_signature_defaults,
        problem_index=TARGET_INDEX,
        model_name=TARGET_MODEL
    )
    
    # 3. Verify the saved output
    saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
    if saved_file.exists():
        print(f"\n--- Step 2: Final Normalized File Saved ---")
        print(f"File created at: {saved_file}")
        print("\n--- File Content ---")
        print(saved_file.read_text())

--- Running full normalization pipeline for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Step 1: Pre-Normalized Code ---
def solve(
    initial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with
    first_batch_size: int = 100, # Out of the first 100 balls
    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them
    second_batch_size: int = 75, # Of the next 75 tennis balls
    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them
):
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch  # L1
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size  # L1
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch  # L2
    not_hit_second_batch = fraction_not_hit_second_batch * second_batch_size  # L2
    total_not_hit = not_hit_first_batch + not_hit_second_batch  # L3
    answer = total_not_hit # FINAL ANSWER
    return answer
'def solve(\n    initial_balls: int = 175, 

In [70]:
import ast
import json
from pathlib import Path

# Assumes previous cells have defined:
# - PROJECT_ROOT
# - BASE_OUTPUT_DIR
# - get_code_lines_dict
# - normalize_code_lines

FINAL_NORMALIZED_DIR = PROJECT_ROOT / "data" / "code_gen_outputs_final"

# ---------------------------------------------------------------------- #
#  1. AST Transformers for Normalization (Unchanged)
# ---------------------------------------------------------------------- #

class ConstantFolder(ast.NodeTransformer):
    """
    An AST transformer that finds and evaluates binary operations on constants.
    """
    def visit_BinOp(self, node):
        self.generic_visit(node)
        if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
            left_val = node.left.value
            right_val = node.right.value
            op = node.op
            if isinstance(op, ast.Add): result = left_val + right_val
            elif isinstance(op, ast.Sub): result = left_val - right_val
            elif isinstance(op, ast.Mult): result = left_val * right_val
            elif isinstance(op, ast.Div): result = left_val / right_val
            else: return node
            return ast.Constant(value=result)
        return node

# Note: A robust intermediate variable substituter is highly complex.
# We will omit it for now in favor of a simpler, more reliable pipeline.
# The primary normalization will come from constant folding.

# ---------------------------------------------------------------------- #
#  2. Corrected Tracing and Normalization Logic
# ---------------------------------------------------------------------- #

def _instrument_and_trace(tree: ast.AST) -> tuple[ast.AST, str]:
    """
    Instruments a function's AST to trace all variable assignments and
    generates the code to execute this trace.
    
    Returns:
        A tuple containing:
        - The instrumented AST of the function.
        - The name of the dictionary that will hold the trace results.
    """
    trace_dict_name = "__execution_trace__"
    
    class AssignmentTracer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Inject the trace dictionary initialization at the start of the function
            trace_init_expr = ast.parse(f"{trace_dict_name} = {{}}").body[0]
            # Also copy function args into the trace dict
            arg_names = [arg.arg for arg in node.args.args]
            arg_trace_expr = ast.parse(
                f"{trace_dict_name}.update({{name: value for name, value in locals().items() if name in {arg_names}}})"
            ).body[0]
            
            # Instrument the rest of the body
            new_body = [trace_init_expr, arg_trace_expr]
            for stmt in node.body:
                new_body.append(stmt)
                if isinstance(stmt, ast.Assign):
                    # For each assignment, add a line to log the result
                    var_name = stmt.targets[0].id
                    log_expr = ast.parse(f"{trace_dict_name}['{var_name}'] = {var_name}").body[0]
                    new_body.append(log_expr)
            
            # Inject a final return of the trace dictionary
            return_trace_expr = ast.parse(f"return {trace_dict_name}").body[0]
            # Replace the original return statement with this one
            final_body = [s for s in new_body if not isinstance(s, ast.Return)]
            final_body.append(return_trace_expr)
            
            node.body = final_body
            return node

    tracer = AssignmentTracer()
    instrumented_tree = tracer.visit(tree)
    return instrumented_tree, trace_dict_name


def generate_final_normalized_code(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Orchestrates the full normalization pipeline for a single model file.
    - Reads the original code.
    - Normalizes step comments into a canonical body.
    - Applies AST transformations (constant folding) to the body.
    - Instruments the final AST to produce an execution trace.
    - Saves the final code and its trace to a new file.
    """
    # 1. Get original source code
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines: return
    source_code = "\n".join(raw_lines.values())
    
    # 2. Get normalized step-wise dictionary
    normalized_steps = normalize_code_lines(raw_lines)
    if not normalized_steps: return

    # 3. Parse the original code to get the original function definition
    try:
        original_tree = ast.parse(source_code)
        original_func_def = next(
            (n for n in original_tree.body if isinstance(n, ast.FunctionDef)), None
        )
        if not original_func_def: return
    except SyntaxError:
        return
        
    # 4. Reconstruct the body from the normalized steps
    body_str_parts = []
    sorted_steps = sorted([key for key in normalized_steps if key.startswith('L')])
    for step in sorted_steps:
        body_str_parts.extend(normalized_steps[step])
    body_str_parts.extend(normalized_steps.get('answer_line', []))
    body_str_parts.extend(normalized_steps.get('return_line', []))
    
    body_str = "\n".join(body_str_parts)
    
    # 5. Parse the reconstructed body and apply AST transformations
    try:
        # We need to wrap it in a dummy function to parse it
        body_tree = ast.parse(f"def dummy():\n    {body_str.replace('\n', '\n    ')}")
        body_nodes = body_tree.body[0].body
        
        # Apply constant folding
        folder = ConstantFolder()
        # We must visit each node individually
        transformed_body_nodes = [folder.visit(node) for node in body_nodes]
        
    except SyntaxError:
        # If the reconstructed body is invalid, we can't proceed
        return

    # 6. Replace the original function's body with the new transformed body
    original_func_def.body = transformed_body_nodes
    final_tree = ast.Module(body=[original_func_def], type_ignores=[])
    final_code_str = ast.unparse(final_tree)
    
    # 7. Instrument the final tree to get the execution trace
    instrumented_tree, trace_dict_name = _instrument_and_trace(ast.parse(final_code_str))
    instrumented_code_str = ast.unparse(instrumented_tree)
    
    # 8. Execute the instrumented code to get the value_dict
    try:
        namespace = {}
        exec(instrumented_code_str, namespace)
        solve_func = namespace['solve']
        value_dict = solve_func() # No args needed as defaults are in the code
    except Exception as e:
        print(f"Error executing trace for {model_name}: {e!r}", file=sys.stderr)
        value_dict = {"__error__": f"Execution failed: {e!r}"}
        
    # 9. Save the results
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    value_dict_str = json.dumps(value_dict, indent=4)
    final_output_content = (
        f"{final_code_str}\n\n"
        f'# --- EXECUTION TRACE (value_dict) ---\n'
        f'"""\n{value_dict_str}\n"""\n'
    )
    output_path.write_text(final_output_content, encoding="utf-8")


# ---------------------------------------------------------------------- #
#  Example End-to-End Usage
# ---------------------------------------------------------------------- #

TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full normalization pipeline for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

# This single function now handles the entire process.
generate_final_normalized_code(TARGET_INDEX, TARGET_MODEL)

# Verify the saved output
saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
if saved_file.exists():
    print(f"\n--- Final Normalized File Saved ---")
    print(f"File created at: {saved_file}")
    print("\n--- File Content ---")
    print(saved_file.read_text())

--- Running full normalization pipeline for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Final Normalized File Saved ---
File created at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_final/25/google_gemini-2.0-flash-thinking-exp.py

--- File Content ---
def solve(initial_balls: int=175, first_batch_size: int=100, fraction_hit_first_batch: float=2 / 5, second_batch_size: int=75, fraction_hit_second_batch: float=1 / 3):
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch
    not_hit_second_batch = fraction_not_hit_second_batch * second_batch_size
    total_not_hit = not_hit_first_batch + not_hit_second_batch
    answer = total_not_hit
    return answer

# --- EXECUTION TRACE (value_dict) ---
"""
{
    "initial_balls": 175,
    "first_batch_size": 100,
    "fraction_hit_first_batch": 

In [71]:
import ast
import json
import re
from collections import defaultdict
from pathlib import Path
import libcst as cst
from libcst.tool import dump

# Assumes previous cells have defined:
# - PROJECT_ROOT
# - BASE_OUTPUT_DIR
# - get_code_lines_dict
# - normalize_code_lines

FINAL_NORMALIZED_DIR = PROJECT_ROOT / "data" / "code_gen_outputs_final"

# --- ConstantFolder Transformer (now for libcst) ---

class CSTConstantFolder(cst.CSTTransformer):
    """
    A libcst transformer that finds and evaluates binary operations on constants.
    """
    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        # Check if both left and right are simple numbers (Integer or Float)
        if isinstance(updated_node.left, (cst.Integer, cst.Float)) and \
           isinstance(updated_node.right, (cst.Integer, cst.Float)):
            
            # Use Python's eval on the node's code for safety and simplicity
            try:
                result = eval(updated_node.code)
                if isinstance(result, int):
                    return cst.Integer(value=str(result))
                elif isinstance(result, float):
                    return cst.Float(value=str(result))
            except Exception:
                return updated_node # Fallback on any eval error
        return updated_node

# --- Main Orchestrator Function (v2) ---

def generate_final_normalized_code_v2(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Orchestrates the full normalization pipeline using libcst for lossless
    transformation, preserving all comments in the function signature.
    """
    # 1. Get original source code
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines: return
    source_code = "\n".join(raw_lines.values())
    
    # 2. Get normalized step-wise dictionary
    normalized_steps = normalize_code_lines(raw_lines)
    if not normalized_steps: return

    # 3. Parse the original code using libcst to preserve everything
    try:
        original_tree = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return

    # 4. Reconstruct the body from the normalized steps as a string
    body_str_parts = []
    sorted_steps = sorted([key for key in normalized_steps if key.startswith('L')])
    for step in sorted_steps:
        body_str_parts.extend(normalized_steps[step])
    body_str_parts.extend(normalized_steps.get('answer_line', []))
    body_str_parts.extend(normalized_steps.get('return_line', []))
    
    # Indent each line of the new body correctly
    indented_body_str = "\n".join([f"    {line}" for line in body_str_parts])

    # 5. Parse the new body string into a block of CST nodes
    try:
        # We parse a whole dummy function to get a valid IndentedBlock
        new_body_cst = cst.parse_statement(f"def dummy():\n{indented_body_str}")
        new_body_block = new_body_cst.body
    except cst.ParserSyntaxError:
        return

    # 6. Apply AST transformations (Constant Folding) to the new body
    folder = CSTConstantFolder()
    transformed_body_block = new_body_block.visit(folder)
    
    # 7. Create a transformer to replace the old body with the new one
    class BodyReplacer(cst.CSTTransformer):
        def __init__(self, new_body):
            self.new_body = new_body
        def leave_FunctionDef(
            self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
        ) -> cst.FunctionDef:
            # If this is the 'solve' function, replace its body
            if updated_node.name.value == "solve":
                return updated_node.with_changes(body=self.new_body)
            return updated_node

    # 8. Apply the replacement
    replacer = BodyReplacer(transformed_body_block)
    final_tree = original_tree.visit(replacer)
    final_code_str = final_tree.code
    
    # 9. Execute and get the value_dict (using the same trace method as before)
    # The _instrument_and_trace function from the previous answer can be reused here
    # as it operates on standard ast, which we can get from the final code string.
    try:
        final_ast_for_trace = ast.parse(final_code_str)
        instrumented_tree, _ = _instrument_and_trace(final_ast_for_trace)
        instrumented_code_str = ast.unparse(instrumented_tree)
        
        namespace = {}
        exec(instrumented_code_str, namespace)
        value_dict = namespace['solve']()
    except Exception as e:
        value_dict = {"__error__": f"Execution failed: {e!r}"}
        
    # 10. Save the results
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    value_dict_str = json.dumps(value_dict, indent=4)
    final_output_content = (
        f"{final_code_str}\n\n"
        f'# --- EXECUTION TRACE (value_dict) ---\n'
        f'"""\n{value_dict_str}\n"""\n'
    )
    output_path.write_text(final_output_content, encoding="utf-8")


# ---------------------------------------------------------------------- #
#  Example End-to-End Usage
# ---------------------------------------------------------------------- #

TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

# We need the `_instrument_and_trace` function from the previous response.
# Make sure it is defined in your notebook environment.
# (I'm assuming it is for this example to be runnable)
# For clarity, here it is again:
def _instrument_and_trace(tree: ast.AST) -> tuple[ast.AST, str]:
    trace_dict_name = "__execution_trace__"
    class AssignmentTracer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            trace_init_expr = ast.parse(f"{trace_dict_name} = {{}}").body[0]
            arg_names = [arg.arg for arg in node.args.args]
            arg_trace_expr = ast.parse(
                f"{trace_dict_name}.update({{name: value for name, value in locals().items() if name in {arg_names}}})"
            ).body[0]
            new_body = [trace_init_expr, arg_trace_expr]
            for stmt in node.body:
                new_body.append(stmt)
                if isinstance(stmt, ast.Assign):
                    var_name = stmt.targets[0].id
                    log_expr = ast.parse(f"{trace_dict_name}['{var_name}'] = {var_name}").body[0]
                    new_body.append(log_expr)
            return_trace_expr = ast.parse(f"return {trace_dict_name}").body[0]
            final_body = [s for s in new_body if not isinstance(s, ast.Return)]
            final_body.append(return_trace_expr)
            node.body = final_body
            return node
    tracer = AssignmentTracer()
    instrumented_tree = tracer.visit(tree)
    return instrumented_tree, trace_dict_name

print(f"--- Running full normalization pipeline (v2) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")
generate_final_normalized_code_v2(TARGET_INDEX, TARGET_MODEL)

saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
if saved_file.exists():
    print(f"\n--- Final Normalized File Saved ---")
    print(f"File created at: {saved_file}")
    print("\n--- File Content ---")
    print(saved_file.read_text())

--- Running full normalization pipeline (v2) for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Final Normalized File Saved ---
File created at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_final/25/google_gemini-2.0-flash-thinking-exp.py

--- File Content ---
def solve(
    initial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with
    first_batch_size: int = 100, # Out of the first 100 balls
    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them
    second_batch_size: int = 75, # Of the next 75 tennis balls
    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them
):
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch
    not_hit_second_batch = fraction_not_hit_second_batch * second_batch_size
    total_not_hit = not

In [72]:
import ast
import json
import re
from collections import defaultdict
from pathlib import Path
import libcst as cst
from libcst.tool import dump

# Ensure all previously defined functions and constants are available in the notebook:
# - PROJECT_ROOT, BASE_OUTPUT_DIR, FINAL_NORMALIZED_DIR
# - get_code_lines_dict, normalize_code_lines
# - The CSTConstantFolder class and the _instrument_and_trace function

def generate_final_normalized_code_v3(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Orchestrates the full normalization pipeline. (Version 3)
    
    This version correctly reconstructs the final output file to include
    the #: L<n> step marker comments for readability.
    """
    # Steps 1-2: Get raw lines and normalize into step dictionary
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines: return
    source_code = "\n".join(raw_lines.values())
    
    normalized_steps = normalize_code_lines(raw_lines)
    if not normalized_steps: return

    # Step 3: Parse original code with libcst to get the original function node
    try:
        original_tree = cst.parse_module(source_code)
        original_func_def = next(
            (n for n in original_tree.body if isinstance(n, cst.FunctionDef)), None
        )
        if not original_func_def: return
    except cst.ParserSyntaxError:
        return

    # --- THE LOGIC UP TO THIS POINT IS FOR ANALYSIS AND BODY RECONSTRUCTION ---
    # --- WE WILL NOW BUILD TWO THINGS:
    # --- 1. A functional but "flat" code string for execution tracing.
    # --- 2. A formatted, readable code string for saving.

    # 4. Reconstruct the FLAT body string for functional processing
    flat_body_parts = []
    sorted_steps = sorted([key for key in normalized_steps if key.startswith('L')])
    for step in sorted_steps:
        flat_body_parts.extend(normalized_steps[step])
    flat_body_parts.extend(normalized_steps.get('answer_line', []))
    flat_body_parts.extend(normalized_steps.get('return_line', []))
    flat_body_str = "\n".join(flat_body_parts)

    # 5. Create the functional (but unformatted) AST for tracing
    try:
        # This process is unchanged and produces a correct AST for execution
        body_tree = ast.parse(f"def dummy():\n    {flat_body_str.replace('\n', '\n    ')}")
        body_nodes = body_tree.body[0].body
        temp_func_def = ast.parse(source_code).body[0]
        temp_func_def.body = body_nodes
        final_functional_code_str = ast.unparse(temp_func_def)
    except (SyntaxError, IndexError):
        return

    # 6. Execute trace on the functional version
    try:
        instrumented_tree, _ = _instrument_and_trace(ast.parse(final_functional_code_str))
        instrumented_code_str = ast.unparse(instrumented_tree)
        namespace = {}
        exec(instrumented_code_str, namespace)
        value_dict = namespace['solve']()
    except Exception as e:
        value_dict = {"__error__": f"Execution failed: {e!r}"}
        
    # --- NEW LOGIC FOR FORMATTED OUTPUT ---
    
    # 7. Reconstruct the FORMATTED code string for saving
    # Get the original signature string using libcst's unparser
    signature_str = original_tree.code_for_node(original_func_def.params)
    signature_docstring_str = f"def solve{signature_str}:\n"
    if original_func_def.get_docstring():
        signature_docstring_str += f'    """{original_func_def.get_docstring()}"""\n'

    formatted_body_parts = []
    for step in sorted_steps:
        # Add the step marker comment
        formatted_body_parts.append(f"    #: {step}")
        # Add the code lines for that step
        for code_line in normalized_steps[step]:
            formatted_body_parts.append(f"    {code_line}")
        formatted_body_parts.append("") # Add a blank line for readability

    # Add the final answer and return lines
    if 'answer_line' in normalized_steps:
        formatted_body_parts.append(f"    {normalized_steps['answer_line'][0]}")
    if 'return_line' in normalized_steps:
         formatted_body_parts.append(f"    {normalized_steps['return_line'][0]}")

    formatted_code_str = signature_docstring_str + "\n".join(formatted_body_parts)

    # 8. Save the formatted code and the trace
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    value_dict_str = json.dumps(value_dict, indent=4)
    final_output_content = (
        f"{formatted_code_str}\n\n"
        f'# --- EXECUTION TRACE (value_dict) ---\n'
        f'"""\n{value_dict_str}\n"""\n'
    )
    output_path.write_text(final_output_content, encoding="utf-8")


# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Make sure _instrument_and_trace is defined from the previous cell
TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full normalization pipeline (v3) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")
generate_final_normalized_code_v3(TARGET_INDEX, TARGET_MODEL)

saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
if saved_file.exists():
    print(f"\n--- Final Normalized File Saved ---")
    print(f"File created at: {saved_file}")
    print("\n--- File Content ---")
    print(saved_file.read_text())

--- Running full normalization pipeline (v3) for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Final Normalized File Saved ---
File created at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_final/25/google_gemini-2.0-flash-thinking-exp.py

--- File Content ---
def solveinitial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with
    first_batch_size: int = 100, # Out of the first 100 balls
    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them
    second_batch_size: int = 75, # Of the next 75 tennis balls
    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them
:
    """Index: 25.
Returns: the total number of tennis balls Ralph did not hit."""
    #: L1
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size

    #: L2
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch
    

In [73]:
import ast
import json
import re
from collections import defaultdict
from pathlib import Path
import libcst as cst
from fractions import Fraction # Import the Fraction type

# Assumes previous cells have defined:
# - PROJECT_ROOT, BASE_OUTPUT_DIR, FINAL_NORMALIZED_DIR
# - get_code_lines_dict, normalize_code_lines
# - The _instrument_and_trace function

# --- NEW: Custom JSON Encoder for Fractions ---
class FractionEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, Fraction):
            # Represent Fraction(2, 5) as the string "2/5"
            return f"{obj.numerator}/{obj.denominator}"
        return json.JSONEncoder.default(self, obj)

# --- NEW: AST Transformer to convert numbers to Fractions ---
class NumberToFractionTransformer(ast.NodeTransformer):
    """
    Transforms integer and float constants into Fraction objects in an AST.
    Example: `2` becomes `Fraction(2, 1)`, `0.5` becomes `Fraction(1, 2)`.
    """
    def visit_Constant(self, node):
        if isinstance(node.value, int):
            return ast.Call(
                func=ast.Name(id='Fraction', ctx=ast.Load()),
                args=[ast.Constant(value=node.value), ast.Constant(value=1)],
                keywords=[]
            )
        if isinstance(node.value, float):
            # Convert float to fraction to preserve precision
            frac = Fraction(node.value).limit_denominator()
            return ast.Call(
                func=ast.Name(id='Fraction', ctx=ast.Load()),
                args=[ast.Constant(value=frac.numerator), ast.Constant(value=frac.denominator)],
                keywords=[]
            )
        return node

# --- Final Orchestrator Function (v4) ---
def generate_final_normalized_code_v4(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Orchestrates the full normalization pipeline. (Version 4)
    - Fixes function signature reconstruction.
    - Uses the `fractions` module to preserve exact precision.
    """
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines: return
    source_code = "\n".join(raw_lines.values())
    normalized_steps = normalize_code_lines(raw_lines)
    if not normalized_steps: return

    try:
        original_tree = cst.parse_module(source_code)
        original_func_def = next((n for n in original_tree.body if isinstance(n, cst.FunctionDef)), None)
        if not original_func_def: return
    except cst.ParserSyntaxError:
        return

    # --- 1. Reconstruct the signature CORRECTLY ---
    params_str = original_tree.code_for_node(original_func_def.params)
    signature_header = f"def solve({params_str}):"
    docstring = original_func_def.get_docstring()
    docstring_str = f'\n    """{docstring}"""' if docstring else ""

    # --- 2. Reconstruct the FORMATTED body with step markers for saving ---
    formatted_body_parts = []
    sorted_steps = sorted([key for key in normalized_steps if key.startswith('L')])
    for step in sorted_steps:
        formatted_body_parts.append(f"    #: {step}")
        for code_line in normalized_steps[step]:
            formatted_body_parts.append(f"    {code_line}")
        formatted_body_parts.append("")
    if 'answer_line' in normalized_steps:
        formatted_body_parts.append(f"    {normalized_steps['answer_line'][0]}")
    if 'return_line' in normalized_steps:
         formatted_body_parts.append(f"    {normalized_steps['return_line'][0]}")
    formatted_code_str = f"{signature_header}{docstring_str}\n" + "\n".join(formatted_body_parts)

    # --- 3. Reconstruct the FLAT body for functional processing ---
    flat_body_parts = []
    for step in sorted_steps:
        flat_body_parts.extend(normalized_steps[step])
    flat_body_parts.extend(normalized_steps.get('answer_line', []))
    flat_body_parts.extend(normalized_steps.get('return_line', []))
    flat_body_str = "\n".join(flat_body_parts)

    # --- 4. Create the AST to be instrumented, now with Fractions ---
    try:
        # Create a full functional code string to parse
        functional_code_str = f"{signature_header}{docstring_str}\n" + "\n".join([f"    {line}" for line in flat_body_parts])
        # Add the Fraction import
        functional_code_str = "from fractions import Fraction\n" + functional_code_str
        
        tree_for_instrument = ast.parse(functional_code_str)
        
        # Transform all numbers to Fraction objects
        transformer = NumberToFractionTransformer()
        tree_for_instrument = transformer.visit(tree_for_instrument)
        
        # Instrument the fraction-aware tree to get the execution trace
        instrumented_tree, _ = _instrument_and_trace(tree_for_instrument)
        instrumented_code_str = ast.unparse(instrumented_tree)
        
        namespace = {}
        exec(instrumented_code_str, namespace)
        value_dict = namespace['solve']()
    except Exception as e:
        value_dict = {"__error__": f"Execution failed: {e!r}"}
        
    # --- 5. Save the formatted code and the trace ---
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Use the custom encoder to serialize the value_dict
    value_dict_str = json.dumps(value_dict, indent=4, cls=FractionEncoder)
    final_output_content = (
        f"{formatted_code_str}\n\n"
        f'# --- EXECUTION TRACE (value_dict) ---\n'
        f'"""\n{value_dict_str}\n"""\n'
    )
    output_path.write_text(final_output_content, encoding="utf-8")

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Make sure _instrument_and_trace is defined from the previous cell
TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full normalization pipeline (v3) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")
generate_final_normalized_code_v4(TARGET_INDEX, TARGET_MODEL)

saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
if saved_file.exists():
    print(f"\n--- Final Normalized File Saved ---")
    print(f"File created at: {saved_file}")
    print("\n--- File Content ---")
    print(saved_file.read_text())

--- Running full normalization pipeline (v3) for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Final Normalized File Saved ---
File created at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_final/25/google_gemini-2.0-flash-thinking-exp.py

--- File Content ---
def solve(initial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with
    first_batch_size: int = 100, # Out of the first 100 balls
    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them
    second_batch_size: int = 75, # Of the next 75 tennis balls
    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them
):
    """Index: 25.
Returns: the total number of tennis balls Ralph did not hit."""
    #: L1
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size

    #: L2
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch
  

In [74]:
import ast
import json
import re
from collections import defaultdict
from pathlib import Path
from fractions import Fraction

# (Assuming other setup code like PROJECT_ROOT, get_code_lines_dict, etc., is present)

# --- NumberToFractionTransformer (Unchanged) ---
class NumberToFractionTransformer(ast.NodeTransformer):
    def visit_Constant(self, node):
        if isinstance(node.value, int):
            return ast.Call(func=ast.Name(id='Fraction', ctx=ast.Load()), args=[ast.Constant(value=node.value), ast.Constant(value=1)], keywords=[])
        if isinstance(node.value, float):
            frac = Fraction(node.value).limit_denominator()
            return ast.Call(func=ast.Name(id='Fraction', ctx=ast.Load()), args=[ast.Constant(value=frac.numerator), ast.Constant(value=frac.denominator)], keywords=[])
        return node

# --- NEW: Helper function for manual serialization ---
def _convert_fractions_to_strings(data: dict) -> dict:
    """
    Recursively iterates through a dictionary and converts all Fraction
    objects into their string representation 'numerator/denominator'.
    """
    processed_dict = {}
    for key, value in data.items():
        if isinstance(value, Fraction):
            processed_dict[key] = f"{value.numerator}/{value.denominator}"
        elif isinstance(value, dict): # For potential future nesting
            processed_dict[key] = _convert_fractions_to_strings(value)
        else:
            processed_dict[key] = value
    return processed_dict

# --- Final Orchestrator Function (v5 - with manual serialization) ---
def generate_final_normalized_code_v5(
    problem_index: int,
    model_name: str,
    base_output_dir: Path = BASE_OUTPUT_DIR,
    output_dir: Path = FINAL_NORMALIZED_DIR,
):
    """
    Orchestrates the full normalization pipeline. (Version 5)
    - Uses manual pre-processing of the value_dict to serialize Fractions
      without a custom JSONEncoder.
    """
    # (Steps 1-3: Reading files, normalizing steps, parsing CST... are the same)
    # ... [Code from v4 for steps 1-3] ...
    # This is a simplified placeholder for the existing logic.
    raw_lines = get_code_lines_dict(problem_index, model_name)
    if not raw_lines: return
    source_code = "\n".join(raw_lines.values())
    normalized_steps = normalize_code_lines(raw_lines)
    original_tree = cst.parse_module(source_code)
    original_func_def = next((n for n in original_tree.body if isinstance(n, cst.FunctionDef)), None)
    if not original_func_def: return
    
    params_str = original_tree.code_for_node(original_func_def.params)
    signature_header = f"def solve({params_str}):"
    docstring = original_func_def.get_docstring()
    docstring_str = f'\n    """{docstring}"""' if docstring else ""

    # (Step 4: Create functional code string... is the same)
    # ... [Code from v4 for step 4] ...
    flat_body_parts = []
    sorted_steps = sorted([key for key in normalized_steps if key.startswith('L')])
    for step in sorted_steps: flat_body_parts.extend(normalized_steps[step])
    flat_body_parts.extend(normalized_steps.get('answer_line', []))
    flat_body_parts.extend(normalized_steps.get('return_line', []))
    functional_code_str = f"from fractions import Fraction\n{signature_header}{docstring_str}\n" + "\n".join([f"    {line}" for line in flat_body_parts])

    # (Step 5: Instrument and trace... is the same)
    # ... [Code from v4 for step 5, leading to `value_dict` with Fraction objects] ...
    try:
        tree_for_instrument = ast.parse(functional_code_str)
        transformer = NumberToFractionTransformer()
        tree_for_instrument = transformer.visit(tree_for_instrument)
        instrumented_tree, _ = _instrument_and_trace(tree_for_instrument)
        instrumented_code_str = ast.unparse(instrumented_tree)
        namespace = {}
        exec(instrumented_code_str, namespace)
        value_dict = namespace['solve']()
    except Exception as e:
        value_dict = {"__error__": f"Execution failed: {e!r}"}
        
    # --- 6. NEW: Manually convert Fractions to strings before JSON dump ---
    serializable_value_dict = _convert_fractions_to_strings(value_dict)
    value_dict_str = json.dumps(serializable_value_dict, indent=4)

    # --- 7. Reconstruct formatted code for saving (Unchanged) ---
    formatted_body_parts = []
    for step in sorted_steps:
        formatted_body_parts.append(f"    #: {step}")
        for code_line in normalized_steps[step]: formatted_body_parts.append(f"    {code_line}")
        formatted_body_parts.append("")
    if 'answer_line' in normalized_steps: formatted_body_parts.append(f"    {normalized_steps['answer_line'][0]}")
    if 'return_line' in normalized_steps: formatted_body_parts.append(f"    {normalized_steps['return_line'][0]}")
    formatted_code_str = f"{signature_header}{docstring_str}\n" + "\n".join(formatted_body_parts)

    # --- 8. Save the final output (Unchanged) ---
    output_path = output_dir / str(problem_index) / f"{model_name}.py"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    final_output_content = (
        f"{formatted_code_str}\n\n"
        f'# --- EXECUTION TRACE (value_dict) ---\n'
        f'"""\n{value_dict_str}\n"""\n'
    )
    output_path.write_text(final_output_content, encoding="utf-8")

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Make sure _instrument_and_trace is defined from the previous cell
TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full normalization pipeline (v3) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")
generate_final_normalized_code_v5(TARGET_INDEX, TARGET_MODEL)

saved_file = FINAL_NORMALIZED_DIR / str(TARGET_INDEX) / f"{TARGET_MODEL}.py"
if saved_file.exists():
    print(f"\n--- Final Normalized File Saved ---")
    print(f"File created at: {saved_file}")
    print("\n--- File Content ---")
    print(saved_file.read_text())

--- Running full normalization pipeline (v3) for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Final Normalized File Saved ---
File created at: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_final/25/google_gemini-2.0-flash-thinking-exp.py

--- File Content ---
def solve(initial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with
    first_batch_size: int = 100, # Out of the first 100 balls
    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them
    second_batch_size: int = 75, # Of the next 75 tennis balls
    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them
):
    """Index: 25.
Returns: the total number of tennis balls Ralph did not hit."""
    #: L1
    fraction_not_hit_first_batch = 1 - fraction_hit_first_batch
    not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size

    #: L2
    fraction_not_hit_second_batch = 1 - fraction_hit_second_batch
  

In [84]:
x = Fraction(1,3)
str(x)

'1/3'

In [82]:
float(x)

0.25

In [97]:
import re
from collections import defaultdict
import json

# ---------------------------------------------------------------------- #
# 0. Utility Function to Get Code Lines
# ---------------------------------------------------------------------- #

def get_code_lines_dict(problem_index: int, model_name: str, base_output_dir: Path = BASE_OUTPUT_DIR) -> dict[int, str] | None:
    """
    Returns a dict mapping line numbers (0-based) to the verbatim lines of code
    for the given problem index and model name.
    """
    import sys
    problem_dir = base_output_dir / str(problem_index)
    file_path = problem_dir / f"{model_name}.py"
    if not file_path.exists():
        print(f"[Error] File not found: {file_path}", file=sys.stderr)
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return {i: line.rstrip("\n") for i, line in enumerate(lines)}

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

# ---------------------------------------------------------------------- #
#  1. Modified Function to Extract All Parts of the Code
# ---------------------------------------------------------------------- #

import re
from collections import defaultdict
import json

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

# ---------------------------------------------------------------------- #
#  1. Corrected Function to Extract All Parts of the Code
# ---------------------------------------------------------------------- #

def extract_code_parts_v2(raw_lines_dict: dict[int, str]) -> dict[str, list[str]]:
    """
    Normalizes raw code lines into a canonical dictionary. (Version 2)
    
    This version correctly handles multi-line docstrings.
    """
    normalized_dict = defaultdict(list)
    sorted_lines = sorted(raw_lines_dict.items())
    
    # State machine variables
    state = "SEARCHING"  # SEARCHING, SIGNATURE, BODY
    signature_lines = []
    body_lines_dict = {}
    
    # --- First pass: Separate signature from body ---
    for i, (line_num, line_content) in enumerate(sorted_lines):
        stripped_line = line_content.strip()
        
        if state == "SEARCHING" and stripped_line.startswith("def solve("):
            state = "SIGNATURE"
        
        if state == "SIGNATURE":
            signature_lines.append(line_content)
            if stripped_line.endswith("):"):
                state = "BODY"
                # The rest of the lines belong to the body
                body_lines_dict = dict(sorted_lines[i+1:])
                break
    
    normalized_dict['function_signature'] = ["\n".join(signature_lines)]
    
    # --- Second pass: Separate docstring from the rest of the body ---
    docstring_lines = []
    rest_of_body_dict = {}
    
    body_lines_sorted = sorted(body_lines_dict.items())
    in_docstring = False
    
    for i, (line_num, line_content) in enumerate(body_lines_sorted):
        stripped_line = line_content.strip()
        
        if not in_docstring and ('"""' in stripped_line or "'''" in stripped_line):
            in_docstring = True
            docstring_lines.append(line_content)
            if (stripped_line.count('"""') == 2 or stripped_line.count("'''") == 2) and len(stripped_line) > 3:
                # Handle single-line docstring
                in_docstring = False
                # The rest of the body starts from the next line
                rest_of_body_dict = dict(body_lines_sorted[i+1:])
                break
            continue

        if in_docstring:
            docstring_lines.append(line_content)
            if '"""' in stripped_line or "'''" in stripped_line:
                in_docstring = False
                # The rest of the body starts from the next line
                rest_of_body_dict = dict(body_lines_sorted[i+1:])
                break
    
    if docstring_lines:
        normalized_dict['docstring'] = ["\n".join(docstring_lines).strip()]
    else:
        # No docstring found, the whole body is code
        rest_of_body_dict = body_lines_dict

    # --- Third pass: Process the remaining body lines ---
    body_dict = normalize_code_lines(rest_of_body_dict)
    normalized_dict.update(body_dict)
            
    return dict(normalized_dict)


# ---------------------------------------------------------------------- #
#  2. New Function to Collapse Intermediate Variables
# ---------------------------------------------------------------------- #

def collapse_intermediate_vars(normalized_dict: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Collapses intermediate variables in a normalized code dictionary.
    
    This uses a simple string search-and-replace logic. It iterates through
    the steps and if the variable defined in step L(n) is used in step L(n+1),
    it substitutes the expression from L(n) into L(n+1) and removes L(n).
    
    Args:
        normalized_dict: A dictionary from the `extract_code_parts` function.

    Returns:
        A new dictionary with intermediate steps collapsed.
    """
    # Work on a copy to avoid modifying the original
    collapsed_dict = {k: v[:] for k, v in normalized_dict.items()}
    
    # Get a sorted list of the logical steps to process in order
    step_keys = sorted([key for key in collapsed_dict if key.startswith('L')])
    
    # A set to keep track of steps that have been collapsed and should be removed
    steps_to_remove = set()

    for i in range(len(step_keys) - 1):
        current_step_key = step_keys[i]
        next_step_key = step_keys[i+1]
        
        # This simple logic assumes one line of code per step after normalization.
        # This holds true for the vast majority of our normalized outputs.
        if len(collapsed_dict.get(current_step_key, [])) == 1 and \
           len(collapsed_dict.get(next_step_key, [])) == 1:
            
            current_line = collapsed_dict[current_step_key][0]
            next_line = collapsed_dict[next_step_key][0]

            # Ensure both lines are assignments
            if '=' not in current_line or '=' not in next_line:
                continue

            # Get the variable (LHS) defined in the current step
            current_lhs = current_line.split('=')[0].strip()
            # Get the expression (RHS) of the current step
            current_rhs = current_line.split('=', 1)[1].strip()

            next_line_rhs = next_line.split('=', 1)[1]

            # Check if the variable is used in the next step's expression
            # Use word boundaries (\b) to avoid matching substrings (e.g., 'x' in 'x_y')
            if re.search(r'\b' + re.escape(current_lhs) + r'\b', next_line_rhs):
                # Perform the substitution, enclosing the expression in parentheses
                substituted_rhs = re.sub(
                    r'\b' + re.escape(current_lhs) + r'\b',
                    f"({current_rhs})",
                    next_line_rhs
                )
                
                # Update the next line in the dictionary
                next_line_lhs = next_line.split('=')[0].strip()
                collapsed_dict[next_step_key][0] = f"{next_line_lhs} = {substituted_rhs.strip()}"
                
                # Mark the current step for removal
                steps_to_remove.add(current_step_key)

    # Remove the collapsed steps from the final dictionary
    for step in steps_to_remove:
        del collapsed_dict[step]
        
    return collapsed_dict

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #
# Use a model where steps can be collapsed
TARGET_INDEX = 52
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Running full TEXT-BASED normalization for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

# 1. Get raw lines
raw_lines = get_code_lines_dict(TARGET_INDEX, TARGET_MODEL)

if raw_lines:
    # 2. Extract parts (Signature, Docstring, Body)
    extracted_parts = extract_code_parts_v2(raw_lines)
    print("\n--- Step 1: Extracted Code Parts ---")
    print(json.dumps(extracted_parts, indent=4))
    
    # 3. Collapse intermediate variables
    collapsed_code = collapse_intermediate_vars(extracted_parts)
    print("\n--- Step 2: Code After Collapsing Intermediate Variables ---")
    print(json.dumps(collapsed_code, indent=4))

--- Running full TEXT-BASED normalization for Problem 52, Model 'google_gemini-2.0-flash-thinking-exp' ---

--- Step 1: Extracted Code Parts ---
{
    "function_signature": [
        "def solve(\n    num_boxes: int = 10, # ten boxes of pencils\n    pencils_kept: int = 10, # He kept ten pencils\n    num_friends: int = 5, # shared the remaining pencils equally with his five friends\n    pencils_per_friend: int = 8 # his friends got eight pencils each\n):"
    ],
    "docstring": [
        "\"\"\"Index: 52.\n    Returns: the number of pencils in each box.\n    \"\"\""
    ],
    "L1": [
        "pencils_shared = num_friends * pencils_per_friend"
    ],
    "L2": [
        "total_pencils = pencils_kept + pencils_shared"
    ],
    "L3": [
        "pencils_in_each_box = total_pencils / num_boxes"
    ],
    "answer_line": [
        "answer = pencils_in_each_box # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

--- Step 2: Code After Collapsing Intermediate Variabl

In [100]:
display(collapsed_code)

{'function_signature': ['def solve(\n    num_boxes: int = 10, # ten boxes of pencils\n    pencils_kept: int = 10, # He kept ten pencils\n    num_friends: int = 5, # shared the remaining pencils equally with his five friends\n    pencils_per_friend: int = 8 # his friends got eight pencils each\n):'],
 'docstring': ['"""Index: 52.\n    Returns: the number of pencils in each box.\n    """'],
 'L3': ['pencils_in_each_box = (pencils_kept + (num_friends * pencils_per_friend)) / num_boxes'],
 'answer_line': ['answer = pencils_in_each_box # FINAL ANSWER'],
 'return_line': ['return answer']}

In [102]:
import re
from collections import defaultdict
import json

# ---------------------------------------------------------------------- #
# 0. Utility Function to Get Code Lines
# ---------------------------------------------------------------------- #

def get_code_lines_dict(problem_index: int, model_name: str, base_output_dir: Path = BASE_OUTPUT_DIR) -> dict[int, str] | None:
    """
    Returns a dict mapping line numbers (0-based) to the verbatim lines of code
    for the given problem index and model name.
    """
    import sys
    problem_dir = base_output_dir / str(problem_index)
    file_path = problem_dir / f"{model_name}.py"
    if not file_path.exists():
        print(f"[Error] File not found: {file_path}", file=sys.stderr)
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return {i: line.rstrip("\n") for i, line in enumerate(lines)}

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

# ---------------------------------------------------------------------- #
#  1. Modified Function to Extract All Parts of the Code
# ---------------------------------------------------------------------- #

import re
from collections import defaultdict
import json

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

# ---------------------------------------------------------------------- #
#  1. Corrected Function to Extract All Parts of the Code
# ---------------------------------------------------------------------- #

def extract_code_parts_v2(raw_lines_dict: dict[int, str]) -> dict[str, list[str]]:
    """
    Normalizes raw code lines into a canonical dictionary. (Version 2)
    
    This version correctly handles multi-line docstrings.
    """
    normalized_dict = defaultdict(list)
    sorted_lines = sorted(raw_lines_dict.items())
    
    # State machine variables
    state = "SEARCHING"  # SEARCHING, SIGNATURE, BODY
    signature_lines = []
    body_lines_dict = {}
    
    # --- First pass: Separate signature from body ---
    for i, (line_num, line_content) in enumerate(sorted_lines):
        stripped_line = line_content.strip()
        
        if state == "SEARCHING" and stripped_line.startswith("def solve("):
            state = "SIGNATURE"
        
        if state == "SIGNATURE":
            signature_lines.append(line_content)
            if stripped_line.endswith("):"):
                state = "BODY"
                # The rest of the lines belong to the body
                body_lines_dict = dict(sorted_lines[i+1:])
                break
    
    normalized_dict['function_signature'] = ["\n".join(signature_lines)]
    
    # --- Second pass: Separate docstring from the rest of the body ---
    docstring_lines = []
    rest_of_body_dict = {}
    
    body_lines_sorted = sorted(body_lines_dict.items())
    in_docstring = False
    
    for i, (line_num, line_content) in enumerate(body_lines_sorted):
        stripped_line = line_content.strip()
        
        if not in_docstring and ('"""' in stripped_line or "'''" in stripped_line):
            in_docstring = True
            docstring_lines.append(line_content)
            if (stripped_line.count('"""') == 2 or stripped_line.count("'''") == 2) and len(stripped_line) > 3:
                # Handle single-line docstring
                in_docstring = False
                # The rest of the body starts from the next line
                rest_of_body_dict = dict(body_lines_sorted[i+1:])
                break
            continue

        if in_docstring:
            docstring_lines.append(line_content)
            if '"""' in stripped_line or "'''" in stripped_line:
                in_docstring = False
                # The rest of the body starts from the next line
                rest_of_body_dict = dict(body_lines_sorted[i+1:])
                break
    
    if docstring_lines:
        normalized_dict['docstring'] = ["\n".join(docstring_lines).strip()]
    else:
        # No docstring found, the whole body is code
        rest_of_body_dict = body_lines_dict

    # --- Third pass: Process the remaining body lines ---
    body_dict = normalize_code_lines(rest_of_body_dict)
    normalized_dict.update(body_dict)
            
    return dict(normalized_dict)


# ---------------------------------------------------------------------- #
#  2. New Function to Collapse Intermediate Variables
# ---------------------------------------------------------------------- #

def collapse_intermediate_vars(normalized_dict: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Collapses intermediate variables in a normalized code dictionary.
    
    This uses a simple string search-and-replace logic. It iterates through
    the steps and if the variable defined in step L(n) is used in step L(n+1),
    it substitutes the expression from L(n) into L(n+1) and removes L(n).
    
    Args:
        normalized_dict: A dictionary from the `extract_code_parts` function.

    Returns:
        A new dictionary with intermediate steps collapsed.
    """
    # Work on a copy to avoid modifying the original
    collapsed_dict = {k: v[:] for k, v in normalized_dict.items()}
    
    # Get a sorted list of the logical steps to process in order
    step_keys = sorted([key for key in collapsed_dict if key.startswith('L')])
    
    # A set to keep track of steps that have been collapsed and should be removed
    steps_to_remove = set()

    for i in range(len(step_keys) - 1):
        current_step_key = step_keys[i]
        next_step_key = step_keys[i+1]
        
        # This simple logic assumes one line of code per step after normalization.
        # This holds true for the vast majority of our normalized outputs.
        if len(collapsed_dict.get(current_step_key, [])) == 1 and \
           len(collapsed_dict.get(next_step_key, [])) == 1:
            
            current_line = collapsed_dict[current_step_key][0]
            next_line = collapsed_dict[next_step_key][0]

            # Ensure both lines are assignments
            if '=' not in current_line or '=' not in next_line:
                continue

            # Get the variable (LHS) defined in the current step
            current_lhs = current_line.split('=')[0].strip()
            # Get the expression (RHS) of the current step
            current_rhs = current_line.split('=', 1)[1].strip()

            next_line_rhs = next_line.split('=', 1)[1]

            # Check if the variable is used in the next step's expression
            # Use word boundaries (\b) to avoid matching substrings (e.g., 'x' in 'x_y')
            if re.search(r'\b' + re.escape(current_lhs) + r'\b', next_line_rhs):
                # Perform the substitution, enclosing the expression in parentheses
                substituted_rhs = re.sub(
                    r'\b' + re.escape(current_lhs) + r'\b',
                    f"({current_rhs})",
                    next_line_rhs
                )
                
                # Update the next line in the dictionary
                next_line_lhs = next_line.split('=')[0].strip()
                collapsed_dict[next_step_key][0] = f"{next_line_lhs} = {substituted_rhs.strip()}"
                
                # Mark the current step for removal
                steps_to_remove.add(current_step_key)

    # Remove the collapsed steps from the final dictionary
    for step in steps_to_remove:
        del collapsed_dict[step]
        
    return collapsed_dict

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

TARGET_INDEX = 10
for model in MODELS:
    TARGET_MODEL = model
    
    print(f"--- Running full TEXT-BASED normalization for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

    # 1. Get raw lines
    raw_lines = get_code_lines_dict(TARGET_INDEX, TARGET_MODEL)

    if raw_lines:
        # 2. Extract parts (Signature, Docstring, Body)
        extracted_parts = extract_code_parts_v2(raw_lines)
        print("\n--- Step 1: Extracted Code Parts ---")
        print(json.dumps(extracted_parts, indent=4))
        
        # 3. Collapse intermediate variables
        collapsed_code = collapse_intermediate_vars(extracted_parts)
        print("\n--- Step 2: Code After Collapsing Intermediate Variables ---")
        print(json.dumps(collapsed_code, indent=4))

--- Running full TEXT-BASED normalization for Problem 10, Model 'anthropic_claude-3-5-haiku-20241022' ---

--- Step 1: Extracted Code Parts ---
{
    "function_signature": [
        "def solve(\n    total_people_consumed: int = 847,  # Over three hundred years, it has consumed 847 people\n    num_periods: int = 3  # Monster rises once every hundred years over three hundred years\n):"
    ],
    "docstring": [
        "\"\"\"Index: 10.\n    Returns: the number of people on the first ship the monster consumed.\"\"\""
    ],
    "L4": [
        "total_ship_multiplier = 1 + 2 + 4  # S + 2S + 4S"
    ],
    "L5": [
        "people_on_first_ship = total_people_consumed / total_ship_multiplier"
    ],
    "answer_line": [
        "answer = people_on_first_ship  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

--- Step 2: Code After Collapsing Intermediate Variables ---
{
    "function_signature": [
        "def solve(\n    total_people_consumed: int = 847,  # Over 

In [103]:
import re
from collections import defaultdict
import json

# This function is now more robust and will replace the previous version.
def collapse_intermediate_vars_v2(normalized_dict: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Collapses intermediate variables in a normalized code dictionary. (Version 2)
    - Handles chained substitutions by running iteratively.
    - Handles multi-line steps.
    - Strips comments before substitution to prevent syntax errors.
    """
    collapsed_dict = {k: v[:] for k, v in normalized_dict.items()}
    
    # Use a loop to handle chained substitutions (e.g., L3 -> L4 -> L5)
    # The loop continues as long as a substitution was made in the previous pass.
    while True:
        substitution_made_in_pass = False
        steps_to_remove = set()
        
        # Get a sorted list of the logical steps to process in order
        step_keys = sorted([key for key in collapsed_dict if key.startswith('L')])

        for i in range(len(step_keys) - 1):
            current_step_key = step_keys[i]
            next_step_key = step_keys[i+1]

            # Get all variables defined in the current step (the LHS of all assignments)
            # This now handles multi-line steps correctly.
            definitions_in_current_step = {}
            for line in collapsed_dict.get(current_step_key, []):
                if '=' in line:
                    lhs, rhs = line.split('=', 1)
                    # --- FIX 1: Strip comments from RHS before storing ---
                    rhs_no_comment = rhs.split('#')[0].strip()
                    definitions_in_current_step[lhs.strip()] = rhs_no_comment

            if not definitions_in_current_step:
                continue

            # Check for substitutions in the next step
            next_step_lines = collapsed_dict.get(next_step_key, [])
            updated_next_step_lines = []
            made_change_in_next_step = False

            for next_line in next_step_lines:
                original_next_line = next_line
                for var_name, expression in definitions_in_current_step.items():
                    # Use word boundaries (\b) to avoid replacing 'var' in 'other_var'
                    pattern = r'\b' + re.escape(var_name) + r'\b'
                    if re.search(pattern, next_line):
                        # Perform substitution
                        next_line = re.sub(pattern, f"({expression})", next_line)
                        made_change_in_next_step = True
                
                updated_next_step_lines.append(next_line)

            if made_change_in_next_step:
                # Update the dictionary with the new lines and mark step for removal
                collapsed_dict[next_step_key] = updated_next_step_lines
                steps_to_remove.add(current_step_key)
                substitution_made_in_pass = True
        
        # Remove the collapsed steps after the pass is complete
        for step in steps_to_remove:
            if step in collapsed_dict:
                del collapsed_dict[step]
        
        # If a full pass made no substitutions, we are done.
        if not substitution_made_in_pass:
            break
            
    return collapsed_dict

# ---------------------------------------------------------------------- #
#  Updated Example Usage (to be run in a new cell)
# ---------------------------------------------------------------------- #

# Re-run the test for problem 10 using the NEW v2 collapse function.
# This requires the `extract_code_parts_v2` function from the previous response.
# Make sure it is defined in your notebook environment.

TARGET_INDEX = 10
for model in MODELS:
    TARGET_MODEL = model
    
    print(f"\n--- Running full TEXT-BASED normalization (v2 collapse) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

    raw_lines = get_code_lines_dict(TARGET_INDEX, TARGET_MODEL)
    if raw_lines:
        extracted_parts = extract_code_parts_v2(raw_lines)
        
        # Use the new collapse function
        collapsed_code = collapse_intermediate_vars_v2(extracted_parts)
        
        print("\n--- Final Collapsed Code ---")
        print(json.dumps(collapsed_code, indent=4))


--- Running full TEXT-BASED normalization (v2 collapse) for Problem 10, Model 'anthropic_claude-3-5-haiku-20241022' ---

--- Final Collapsed Code ---
{
    "function_signature": [
        "def solve(\n    total_people_consumed: int = 847,  # Over three hundred years, it has consumed 847 people\n    num_periods: int = 3  # Monster rises once every hundred years over three hundred years\n):"
    ],
    "docstring": [
        "\"\"\"Index: 10.\n    Returns: the number of people on the first ship the monster consumed.\"\"\""
    ],
    "L5": [
        "people_on_first_ship = total_people_consumed / (1 + 2 + 4)"
    ],
    "answer_line": [
        "answer = people_on_first_ship  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

--- Running full TEXT-BASED normalization (v2 collapse) for Problem 10, Model 'openai_gpt-4.1-mini' ---

--- Final Collapsed Code ---
{
    "function_signature": [
        "def solve(\n    total_people: int = 847,  # it has consumed 847 pe

In [104]:
TARGET_INDEX = 25
for model in MODELS:
    TARGET_MODEL = model
    
    print(f"\n--- Running full TEXT-BASED normalization (v2 collapse) for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

    raw_lines = get_code_lines_dict(TARGET_INDEX, TARGET_MODEL)
    if raw_lines:
        extracted_parts = extract_code_parts_v2(raw_lines)
        
        # Use the new collapse function
        collapsed_code = collapse_intermediate_vars_v2(extracted_parts)
        
        print("\n--- Final Collapsed Code ---")
        print(json.dumps(collapsed_code, indent=4))


--- Running full TEXT-BASED normalization (v2 collapse) for Problem 25, Model 'anthropic_claude-3-5-haiku-20241022' ---

--- Final Collapsed Code ---
{
    "function_signature": [
        "def solve(\n    total_tennis_balls: int = 175,  # Ralph loads up the machine with 175 tennis balls\n    first_batch_balls: int = 100,  # Out of the first 100 balls\n    first_batch_hit_fraction: float = 2/5,  # he manages to hit 2/5 of them\n    next_batch_balls: int = 75,  # Of the next 75 tennis balls\n    next_batch_hit_fraction: float = 1/3  # he manages to hit 1/3 of them\n):"
    ],
    "docstring": [
        "\"\"\"Index: 25.\n    Returns: the total number of tennis balls Ralph did not hit.\"\"\""
    ],
    "L3": [
        "total_missed_balls = ((1 - first_batch_hit_fraction) * first_batch_balls) + ((1 - next_batch_hit_fraction) * next_batch_balls)"
    ],
    "answer_line": [
        "answer = total_missed_balls  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

-

In [105]:
import re
from collections import defaultdict
import json
from pathlib import Path

# This block should be self-contained and assumes only BASE_OUTPUT_DIR is defined.
# If running in a new notebook, you'll need to define PROJECT_ROOT and BASE_OUTPUT_DIR.

# ---------------------------------------------------------------------- #
#  Module 1: Raw Code -> Normalized Dictionary
# ---------------------------------------------------------------------- #

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

def get_code_lines_dict(problem_index: int, model_name: str, base_output_dir: Path) -> dict[int, str] | None:
    """
    Returns a dict mapping line numbers (0-based) to the verbatim lines of code
    for the given problem index and model name.
    """
    file_path = base_output_dir / str(problem_index) / f"{model_name}.py"
    if not file_path.exists():
        # Using return None for file not found, no need to print error here.
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return {i: line.rstrip("\n") for i, line in enumerate(lines)}

def _normalize_body_lines(body_lines_dict: dict[int, str]) -> dict[str, list[str]]:
    """
    Helper function to process only the body of a function into a
    step-wise dictionary.
    """
    normalized_dict = defaultdict(list)
    sorted_lines = sorted(body_lines_dict.items())
    last_seen_step = None

    for i, (line_num, line_content) in enumerate(sorted_lines):
        line = line_content.strip()
        if not line:
            continue

        step_match = _TRACE_RE.match(line)
        if step_match:
            last_seen_step = f"L{step_match.group(1)}"
            continue

        if line.startswith('#'):
            continue
            
        code_part = line
        trailing_match = _TRACE_RE.search(line)
        if trailing_match:
            last_seen_step = f"L{trailing_match.group(1)}"
            code_part = line[:trailing_match.start()].strip()
        
        if 'answer =' in line:
            normalized_dict['answer_line'].append(code_part)
        elif line.startswith('return'):
            normalized_dict['return_line'].append(code_part)
        elif last_seen_step:
            normalized_dict[last_seen_step].append(code_part)
            
    return dict(normalized_dict)

def get_normalized_code_dict(
    problem_index: int, model_name: str, base_output_dir: Path = BASE_OUTPUT_DIR
) -> dict[str, list[str]] | None:
    """
    Orchestrates the creation of the normalized code dictionary.
    
    This function reads a raw model output file and produces a dictionary
    containing the signature, docstring, and a step-wise breakdown of the
    function body.
    """
    raw_lines_dict = get_code_lines_dict(problem_index, model_name, base_output_dir)
    if not raw_lines_dict:
        return None

    final_dict = defaultdict(list)
    sorted_lines = sorted(raw_lines_dict.items())
    
    # --- State machine variables ---
    state = "SEARCHING"
    signature_lines = []
    body_lines_dict = {}
    
    # --- 1. Separate signature from body ---
    for i, (line_num, line_content) in enumerate(sorted_lines):
        stripped_line = line_content.strip()
        if state == "SEARCHING" and stripped_line.startswith("def solve("):
            state = "SIGNATURE"
        
        if state == "SIGNATURE":
            signature_lines.append(line_content)
            if stripped_line.endswith("):"):
                state = "BODY"
                body_lines_dict = dict(sorted_lines[i+1:])
                break
    
    final_dict['function_signature'] = ["\n".join(signature_lines)]
    
    # --- 2. Separate docstring from the rest of the body ---
    docstring_lines = []
    rest_of_body_dict = {}
    body_lines_sorted = sorted(body_lines_dict.items())
    in_docstring = False
    
    docstring_end_index = -1
    for i, (line_num, line_content) in enumerate(body_lines_sorted):
        stripped_line = line_content.strip()
        
        if not in_docstring and ('"""' in stripped_line or "'''" in stripped_line):
            in_docstring = True
        
        if in_docstring:
            docstring_lines.append(line_content)
            # Check if this line closes the docstring
            if stripped_line.count('"""') % 2 == 1 or stripped_line.count("'''") % 2 == 1:
                if len(docstring_lines) > 1 or stripped_line.count('"""') == 2 or stripped_line.count("'''") == 2:
                    in_docstring = False
                    docstring_end_index = i
                    break
    
    if docstring_lines:
        final_dict['docstring'] = ["\n".join(docstring_lines).strip()]
        rest_of_body_dict = dict(body_lines_sorted[docstring_end_index+1:])
    else:
        rest_of_body_dict = body_lines_dict

    # --- 3. Process the remaining code body ---
    body_dict = _normalize_body_lines(rest_of_body_dict)
    final_dict.update(body_dict)
            
    return dict(final_dict)

# ---------------------------------------------------------------------- #
#  Example Usage
# ---------------------------------------------------------------------- #

# Assume MODELS list and BASE_OUTPUT_DIR are defined.
TARGET_INDEX = 25
TARGET_MODEL = "google_gemini-2.0-flash-thinking-exp"

print(f"--- Generating Normalized Dictionary for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

normalized_dict = get_normalized_code_dict(TARGET_INDEX, TARGET_MODEL)

if normalized_dict:
    print(json.dumps(normalized_dict, indent=4))

--- Generating Normalized Dictionary for Problem 25, Model 'google_gemini-2.0-flash-thinking-exp' ---
{
    "function_signature": [
        "def solve(\n    initial_balls: int = 175, # He loads up the machine with 175 tennis balls to start with\n    first_batch_size: int = 100, # Out of the first 100 balls\n    fraction_hit_first_batch: float = 2/5, # he manages to hit 2/5 of them\n    second_batch_size: int = 75, # Of the next 75 tennis balls\n    fraction_hit_second_batch: float = 1/3 # he manages to hit 1/3 of them\n):"
    ],
    "docstring": [
        "\"\"\"Index: 25.\n    Returns: the total number of tennis balls Ralph did not hit.\n    \"\"\""
    ],
    "L1": [
        "fraction_not_hit_first_batch = 1 - fraction_hit_first_batch",
        "not_hit_first_batch = fraction_not_hit_first_batch * first_batch_size"
    ],
    "L2": [
        "fraction_not_hit_second_batch = 1 - fraction_hit_second_batch",
        "not_hit_second_batch = fraction_not_hit_second_batch * second_batch_s

In [None]:
import re
from collections import defaultdict
import json
from pathlib import Path

# This block is self-contained.
# Assumes BASE_OUTPUT_DIR and MODELS are defined in your environment.

# ---------------------------------------------------------------------- #
#  Module 1: Raw Code -> Simplified Normalized Dictionary
# ---------------------------------------------------------------------- #

# Regex to find the solution step comment, e.g., "#: L1"
_TRACE_RE = re.compile(r"#:\s*L(\d+)")

def get_code_lines_dict(problem_index: int, model_name: str, base_output_dir: Path) -> dict[int, str] | None:
    """
    Returns a dict mapping line numbers (0-based) to the verbatim lines of code
    for the given problem index and model name.
    """
    file_path = base_output_dir / str(problem_index) / f"{model_name}.py"
    if not file_path.exists():
        return None
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return {i: line.rstrip("\n") for i, line in enumerate(lines)}

def _normalize_body_lines(body_lines_dict: dict[int, str]) -> dict[str, list[str]]:
    """
    Helper function to process only the body of a function into a
    step-wise dictionary. (This function is reused from before).
    """
    normalized_dict = defaultdict(list)
    sorted_lines = sorted(body_lines_dict.items())
    last_seen_step = None

    for i, (line_num, line_content) in enumerate(sorted_lines):
        line = line_content.strip()
        if not line:
            continue

        step_match = _TRACE_RE.match(line)
        if step_match:
            last_seen_step = f"L{step_match.group(1)}"
            continue

        if line.startswith('#'):
            continue
            
        code_part = line
        trailing_match = _TRACE_RE.search(line)
        if trailing_match:
            last_seen_step = f"L{trailing_match.group(1)}"
            code_part = line[:trailing_match.start()].strip()
        
        if 'answer =' in line:
            normalized_dict['answer_line'].append(code_part)
        elif line.startswith('return'):
            normalized_dict['return_line'].append(code_part)
        elif last_seen_step:
            normalized_dict[last_seen_step].append(code_part)
            
    return dict(normalized_dict)


def get_simplified_code_dict(
    problem_index: int, model_name: str, base_output_dir: Path = BASE_OUTPUT_DIR
) -> dict[str, list[str]] | None:
    """
    Orchestrates the creation of the normalized code dictionary using a
    simplified header/body split.
    
    The output dict will contain:
    - 'function_header': A single string with the signature and docstring.
    - Step-wise breakdown of the body (L1, L2, answer_line, etc.).
    """
    raw_lines_dict = get_code_lines_dict(problem_index, model_name, base_output_dir)
    if not raw_lines_dict:
        return None

    sorted_lines = sorted(raw_lines_dict.items())
    header_lines = []
    split_index = -1

    # Generate function signature, including docstring
    for i, (line_num, line_content) in enumerate(sorted_lines):
        header_lines.append(line_content)
        if line_content.endswith('"""'): # Check for end of docstring
            split_index = i
            break

    # --- Process the two parts ---
    header_str = "\n".join(header_lines).strip()
    body_lines_dict = dict(sorted_lines[split_index + 1:])
    body_dict = _normalize_body_lines(body_lines_dict)
    
    # --- Combine into the final dictionary ---
    final_dict = {"function_header": [header_str]}
    final_dict.update(body_dict)
            
    return final_dict

for TARGET_INDEX in [10, 25]:
    for TARGET_MODEL in MODELS:
        print(f"\n--- Generating Simplified Dictionary for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

        simplified_dict = get_simplified_code_dict(TARGET_INDEX, TARGET_MODEL)

        if simplified_dict:
            print(json.dumps(simplified_dict, indent=4))

In [107]:
for TARGET_INDEX in [10, 25]:
    for TARGET_MODEL in MODELS:
        print(f"\n--- Generating Simplified Dictionary for Problem {TARGET_INDEX}, Model '{TARGET_MODEL}' ---")

        simplified_dict = get_simplified_code_dict(TARGET_INDEX, TARGET_MODEL)

        if simplified_dict:
            print(json.dumps(simplified_dict, indent=4))


--- Generating Simplified Dictionary for Problem 10, Model 'anthropic_claude-3-5-haiku-20241022' ---
{
    "function_header": [
        "def solve(\n    total_people_consumed: int = 847,  # Over three hundred years, it has consumed 847 people\n    num_periods: int = 3  # Monster rises once every hundred years over three hundred years\n):\n    \"\"\"Index: 10.\n    Returns: the number of people on the first ship the monster consumed.\"\"\""
    ],
    "L4": [
        "total_ship_multiplier = 1 + 2 + 4  # S + 2S + 4S"
    ],
    "L5": [
        "people_on_first_ship = total_people_consumed / total_ship_multiplier"
    ],
    "answer_line": [
        "answer = people_on_first_ship  # FINAL ANSWER"
    ],
    "return_line": [
        "return answer"
    ]
}

--- Generating Simplified Dictionary for Problem 10, Model 'openai_gpt-4.1-mini' ---
{
    "function_header": [
        "def solve(\n    total_people: int = 847,  # it has consumed 847 people over three hundred years\n):\n    \"\"\"In