# Conceptual Error Candidate Generation

This notebook contains the end-to-end pipeline for generating conceptual error "candidates" from validated `Formalization Templates`. The process involves programmatically applying logical mutations to the Abstract Syntax Tree (AST) of the correct `function_code`, executing the mutated code to generate a flawed numerical trace, and then packaging all relevant information for a human-in-the-loop validation and annotation stage.

#### **Cell 1: Imports and Setup**
*   **Functionality:** Standard project setup.
*   **Logic:**
    *   Imports necessary libraries (`json`, `ast`, `pathlib`, etc.).
    *   Defines a robust `find_project_root` function to make path definitions portable.
    *   Sets up key `Path` objects for input (`PROCESSED_TEMPLATE_DIR`) and output (`CONCEPTUAL_CANDIDATES_DIR`).
    *   Ensures the output directory exists.

In [143]:
# --- Imports and Path Definitions ---

import json
import re
import ast
import inspect
import importlib.util
from pathlib import Path
from types import ModuleType
from typing import Callable, Any, Dict, List
from fractions import Fraction as BuiltinFraction
import copy

import pandas as pd
from tqdm.notebook import tqdm
from datasets import load_dataset, Dataset

def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by the git repository."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'

# --- Directory Paths ---
PROCESSED_TEMPLATE_DIR = DATA_DIR / "template-generated-processed"
CONCEPTUAL_CANDIDATES_DIR = DATA_DIR / "conceptual-error-candidates"

# --- Models ---
MODELS = ['openai_gpt-4.1', 'google_gemini-2.5-flash']

print(f"Project root: {PROJECT_ROOT}")
print(f"Input (Processed Templates): {PROCESSED_TEMPLATE_DIR}")
print(f"Output (Conceptual Candidates): {CONCEPTUAL_CANDIDATES_DIR}")

# --- Ensure Directories Exist ---
PROCESSED_TEMPLATE_DIR.mkdir(parents=True, exist_ok=True)
CONCEPTUAL_CANDIDATES_DIR.mkdir(parents=True, exist_ok=True)

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Input (Processed Templates): /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/template-generated-processed
Output (Conceptual Candidates): /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/conceptual-error-candidates


#### **Cell 2: Load Dataset and Tiers**
*   **Functionality:** Loads the ground-truth dataset and categorizes problems into tiers.
*   **Logic:**
    *   Loads the `gsm8k` training set.
    *   Defines several helper functions (`has_computational_division`, `has_float`, `is_symbolic`) that use regular expressions to classify the solution text of each problem.
    *   `mutually_disjoint_tiers` uses these helpers to partition all problem indices into distinct tiers (e.g., `tier1` for simple integer arithmetic, `tier5` for symbolic algebra). This organizes the data generation process.

In [144]:
# --- Load GSM8K Dataset ---
GSM8K_TRAIN: Dataset = load_dataset("gsm8k", "main")["train"] #type: ignore

# --- Tier Definition Functions ---
def has_computational_division(solution_text: str) -> bool:
    pattern = re.compile(r'/\s*\d')
    return bool(pattern.search(solution_text))

def has_float(solution_text: str) -> bool:
    pattern = re.compile(r'(?<!\d)\.\d+|\d+\.\d+')
    return bool(pattern.search(solution_text))

def is_symbolic(solution_text: str) -> bool:
    pattern = re.compile(r'^Let [a-zA-Z] ', re.MULTILINE)
    return bool(pattern.search(solution_text))

def mutually_disjoint_tiers(dataset: Dataset) -> dict[str, list[int]]:
    tiers = {}
    symbolic_set = set(idx for idx, sample in enumerate(dataset) if is_symbolic(sample.get("answer", "")))
    non_symbolic_indices = [idx for idx in range(len(dataset)) if idx not in symbolic_set]
    tiers["tier1"] = sorted([idx for idx in non_symbolic_indices if not has_float(dataset[idx].get("answer", "")) and not has_computational_division(dataset[idx].get("answer", ""))])
    tiers["tier2"] = sorted([idx for idx in non_symbolic_indices if has_float(dataset[idx].get("answer", "")) and not has_computational_division(dataset[idx].get("answer", ""))])
    tiers["tier3"] = sorted([idx for idx in non_symbolic_indices if not has_float(dataset[idx].get("answer", "")) and has_computational_division(dataset[idx].get("answer", ""))])
    tiers["tier4"] = sorted([idx for idx in non_symbolic_indices if has_float(dataset[idx].get("answer", "")) and has_computational_division(dataset[idx].get("answer", ""))])
    tiers["tier5"] = sorted(list(symbolic_set))
    return tiers

TIER_LISTS = mutually_disjoint_tiers(GSM8K_TRAIN)
print("Tier definitions loaded.")

Tier definitions loaded.


#### **Cell 3: Template Loading Utilities**
*   **Functionality:** Provides helper functions to load the validated `Formalization Template` components.
*   **Logic:**
    *   `load_function_module`: Takes a tier, index, and model name. It constructs the path to the `.py` file and uses `importlib.util` to dynamically load the file as a Python module, making its `solve` function accessible.
    *   `load_logical_steps`: Takes the same inputs, constructs the path to the `.json` file, and returns the parsed list of dictionaries representing the solution's logical steps.

In [145]:
def load_function_module(tier: str, index: int, model_name: str) -> ModuleType | None:
    """Dynamically loads the '{model_name}.py' file for a given template."""
    py_file_path = PROCESSED_TEMPLATE_DIR / tier / str(index) / f"{model_name}.py"
    if not py_file_path.exists():
        return None
    module_name = f"templates.t{tier}.i{index}.m_{model_name.replace('.', '_')}.solve"
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    if spec and spec.loader:
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    return None

def load_logical_steps(tier: str, index: int, model_name: str) -> list[dict] | None:
    """Loads the '{model_name}.json' file for a given template."""
    json_file_path = PROCESSED_TEMPLATE_DIR / tier / str(index) / f"{model_name}.json"
    try:
        return json.loads(json_file_path.read_text(encoding='utf-8'))
    except (FileNotFoundError, json.JSONDecodeError):
        return None

print("Template loading functions defined.")

Template loading functions defined.


#### **Cell 3.5: Solution mapping construction**

In [146]:
# In a new cell (e.g., Cell 3.5)

def sanitize_text(text: str) -> str:
    """
    Replaces a comprehensive set of problematic Unicode characters with their
    ASCII equivalents to prevent model generation and string parsing errors.
    """
    replacements = {
        "\u2212": "-",  # Minus Sign
        "\u00d7": "*",  # Multiplication Sign
        "\u00f7": "/",  # Division Sign
        "\u22c5": "*",  # Dot Operator
        "\u201c": '"',  # Left Double Quotation Mark
        "\u201d": '"',  # Right Double Quotation Mark
        "\u2018": "'",  # Left Single Quotation Mark
        "\u2019": "'",  # Right Single Quotation Mark
        "\u2014": "-",  # Em Dash
        "\u2013": "-",  # En Dash
        "\u2026": "...", # Horizontal Ellipsis
        "\u00a0": " ",  # Non-breaking Space
    }
    for uni, ascii_char in replacements.items():
        text = text.replace(uni, ascii_char)
    return text

def build_solution_mapping(index: int) -> dict:
    """
    Extracts the original natural language solution, sanitizes it, and
    structures it into a line-numbered dictionary including the 'FA' line.
    """
    try:
        solution_text = GSM8K_TRAIN[index]["answer"]
        sanitized_text = sanitize_text(solution_text)
        lines = [ln.strip() for ln in sanitized_text.splitlines() if ln.strip()]

        solution_mapping = {}
        if lines and re.match(r"^####\s*[\d\.,]+$", lines[-1]):
            solution_mapping["FA"] = lines.pop(-1).strip()
        
        for i, line in enumerate(lines, 1):
            solution_mapping[f"L{i}"] = line
            
        return solution_mapping
    except IndexError:
        return {}

print("Sanitization and Solution Mapping utilities defined.")

Sanitization and Solution Mapping utilities defined.


#### **Cell 4: Core AST and Trace Utilities**
*   **Functionality:** Contains the engine for executing and mutating the abstract syntax tree (AST) of the `function_code`.
*   **Logic:**
    *   `_execute_ast_statements`: A low-level executor that takes a list of AST statements. It iterates through them, compiling and executing each `ast.Assign` node to build a dictionary representing the final variable state (a trace).
    *   `execution_trace`: A wrapper around `_execute_ast_statements`. It takes a Python function, parses its source code into an AST, extracts the function body's statements, and passes them to the executor to get a full, correct execution trace.
    *   `get_scope_at_node`: A critical utility that executes a function's code only *up to* a specified AST node. This is used to determine the set of valid, in-scope variables available at a specific line of code, which is essential for discovering "Wrong Reference" errors.
    *   `is_plausible_trace`: A validation function that compares a `flawed_trace` to the `correct_trace`. It enforces rules to filter out nonsensical errors, such as a variable's sign changing incorrectly or a non-zero value becoming zero. It contains a `FIXME` noting that it could be more robust in handling non-numeric types.
    *   `apply_mutation`: The core mutation engine. It uses an `ast.NodeTransformer` to traverse a copy of the function's AST. When it finds the `target_variable` specified in `mutation_details`, it applies the corresponding logical change (e.g., swapping an operator, replacing a variable `ast.Name` node). It then unparses the mutated AST back into a string and executes it to produce the `flawed_trace`.

In [147]:
class NonSimplifyingFraction(BuiltinFraction):
    """A subclass of fractions.Fraction that preserves original representation for printing."""
    def __new__(cls, numerator=0, denominator=None):
        self = super().__new__(cls, numerator, denominator)
        if denominator is not None:
            self._original_denominator = denominator
            self._original_numerator = numerator
        else:
            self._original_numerator, self._original_denominator = self.as_integer_ratio()
        return self

    def __str__(self):
        return f"{self._original_numerator}/{self._original_denominator}"

def _execute_ast_statements(stmts: List[ast.stmt]) -> Dict[str, Any] | None:
    """Executes a list of AST assignment statements and returns the final local environment."""
    try:
        global_namespace = {'Fraction': NonSimplifyingFraction}
        local_env = {}
        for stmt in stmts:
            if isinstance(stmt, ast.Assign):
                module_node = ast.Module([stmt], type_ignores=[])
                code_obj = compile(module_node, '<string>', 'exec')
                exec(code_obj, global_namespace, local_env)
        return local_env
    except Exception:
        return None

def execution_trace(func: Callable[[], Any]) -> dict[str, Any] | None:
    """Executes a function's source code line by line to build a full variable trace."""
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
        func_def = tree.body[0]
        if not isinstance(func_def, ast.FunctionDef): return None
        return _execute_ast_statements(func_def.body)
    except (FileNotFoundError, TypeError, IndexError):
        return None

def get_scope_at_node(func: Callable, target_assign_node: ast.Assign) -> Dict[str, Any] | None:
    """Executes a function's AST up to a specific node to get the current variable scope."""
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
        func_def = tree.body[0]
        if not isinstance(func_def, ast.FunctionDef): return None

        statements_before = []
        for stmt in func_def.body:
            if isinstance(stmt, ast.Assign) and stmt.lineno == target_assign_node.lineno:
                return _execute_ast_statements(statements_before)
            statements_before.append(stmt)
        return None
    except (FileNotFoundError, TypeError, IndexError):
        return None

def is_plausible_trace(correct_trace: dict, flawed_trace: dict) -> bool:
    """
    Validates a flawed trace against a correct one for type and value integrity.

    A trace is considered implausible if any numeric variable:
    1. Becomes non-numeric.
    2. Was an integer (or integer-like float) and becomes a non-integer float.
    3. Flips its sign (e.g., positive to negative), unless it was zero.
    4. Was non-zero and becomes zero.
    """
    for key, correct_val in correct_trace.items():
        # This function only validates the stability of numeric types.
        if not isinstance(correct_val, (int, float, BuiltinFraction)):
            continue

        flawed_val = flawed_trace.get(key)

        # 1. Type Consistency Check: If correct is numeric, flawed must be too.
        # This is the primary fix for the AttributeError.
        if not isinstance(flawed_val, (int, float, BuiltinFraction)):
            return False  # Invalid: type changed from numeric to non-numeric.

        # 2. Integer-like Preservation Check
        is_correct_int_like = isinstance(correct_val, int) or \
                              (isinstance(correct_val, float) and correct_val.is_integer())
        is_flawed_int_like = isinstance(flawed_val, int) or \
                             (isinstance(flawed_val, float) and flawed_val.is_integer())
        if is_correct_int_like and not is_flawed_int_like:
            return False  # Invalid: integer-like value became a true float.

        # 3. Sign Preservation Check
        correct_sign = (correct_val > 0) - (correct_val < 0)
        flawed_sign = (flawed_val > 0) - (flawed_val < 0)
        if correct_val != 0 and correct_sign != flawed_sign:
            return False  # Invalid: sign flipped for a non-zero value.

        # 4. Zero-Value Preservation Check
        if correct_val != 0 and flawed_val == 0:
            return False  # Invalid: non-zero value became zero.

    return True # All numeric variables passed the plausibility checks.

def apply_mutation(
    func: Callable, mutation_details: Dict[str, Any]
) -> tuple[str | None, Dict[str, Any] | None]:
    """
    Applies a single logical mutation to a function's AST and returns the mutated code and trace.

    This function serves as the central engine for creating flawed versions of the 
    `function_code`. It parses the source, uses a NodeTransformer to apply a specific
    mutation based on the `mutation_details`, and then executes the modified AST
    to generate a `flawed_trace`.

    Args:
        func: The original, correct solve() function from the template.
        mutation_details: A dictionary describing the mutation to apply, including
                          its 'type', 'target_variable', and other relevant info.

    Returns:
        A tuple containing:
        - The mutated function code as a string, or None if mutation failed.
        - The flawed execution trace (a dict of variable states), or None if execution failed.
    """
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)

        class ASTMutator(ast.NodeTransformer):
            def __init__(self, details):
                self.details = details
                self.mutation_applied = False

            def visit_Assign(self, node):
                # Ensure the node is a simple assignment and matches the target variable
                if not (isinstance(node.targets[0], ast.Name) and 
                        node.targets[0].id == self.details['target_variable']):
                    return node

                mutation_type = self.details['type']
                
                # --- Case 1: Incomplete Calculation ---
                # Handles removing an operand from a multi-operand sum or product.
                if mutation_type == 'incomplete_calculation' and isinstance(node.value, ast.BinOp):
                    unrolled = _unroll_binop(node.value)
                    if unrolled:
                        operands, op_type = unrolled
                        operand_to_remove = self.details['operand_to_remove']
                        
                        new_operands = [op for op in operands if not (isinstance(op, ast.Name) and op.id == operand_to_remove)]
                        
                        if len(new_operands) >= 2:
                            node.value = _build_binop_tree_from_list(new_operands, op_type)
                            self.mutation_applied = True
                
                # --- Case 2: Operator Swap ---
                # Handles changing the operator in a binary operation (e.g., + to -).
                elif mutation_type == 'operator_swap' and isinstance(node.value, ast.BinOp):
                    node.value.op = self.details['new_op']
                    self.mutation_applied = True
                
                # --- Case 3: Variable Name Replacement ---
                # Handles multiple error types that all involve swapping one variable for another.
                # This includes incorrect operands, input misrepresentation, and incorrect final answer selection.
                elif mutation_type in ['incorrect_operand', 'input_misrepresentation', 'incorrect_final_answer_selection']:
                    class OperandSwapper(ast.NodeTransformer):
                        def __init__(self, to_replace, new_name):
                            self.to_replace = to_replace
                            self.new_name = new_name
                        def visit_Name(self, name_node):
                            if name_node.id == self.to_replace:
                                return ast.Name(id=self.new_name, ctx=name_node.ctx)
                            return name_node
                    
                    swapper = OperandSwapper(self.details['operand_to_replace'], self.details['replacement_variable'])
                    node.value = swapper.visit(node.value)
                    self.mutation_applied = True
                
                return node

        mutator = ASTMutator(mutation_details)
        mutated_tree = mutator.visit(copy.deepcopy(tree))
        
        # If the mutation was not successfully applied, abort.
        if not mutator.mutation_applied:
            return None, None

        ast.fix_missing_locations(mutated_tree)
        mutated_code_str = ast.unparse(mutated_tree)

        # Execute the mutated code to get the flawed trace.
        func_def = mutated_tree.body[0]
        flawed_trace = _execute_ast_statements(func_def.body) if isinstance(func_def, ast.FunctionDef) else None
        
        return mutated_code_str, flawed_trace
    
    except Exception:
        # Catch any error during parsing, mutation, or execution and fail gracefully.
        return None, None

#### **Cell 5: Mutation Discovery**
*   **Functionality:** Traverses a function's AST to find all valid opportunities for conceptual mutations. This is the "discovery phase."
*   **Logic:**
    *   `discover_mutation_opportunities`: The main orchestrator. It parses the `function_code`, builds a dependency map, and iterates through each `ast.Assign` node. For each node, it calls specialized discovery helpers.
    *   `_discover_operator_swaps`: Identifies `ast.BinOp` nodes and generates potential mutations by swapping the operator (e.g., `+` to `-`, `*` to `/`).
    *   `_discover_plausible_wrong_references`: For a given assignment, it identifies all operand variables. It then uses `get_scope_at_node` to find all other numerically-compatible variables in scope and generates mutations to substitute them. It correctly classifies these as either `input_misrepresentation` or `incorrect_operand`.
    *   `_discover_omitted_final_steps`: Identifies the final `answer = ...` assignment. It proposes mutations that replace the correct final variable with other intermediate variables calculated earlier in the function, simulating a "stopped one step short" error. The `FIXME` comment correctly identifies a necessary edge-case check.

In [148]:
def get_variable_dependencies(node: ast.Assign) -> List[str]:
    """Extracts the names of variables used as operands in an assignment's value."""
    return sorted([n.id for n in ast.walk(node.value) if isinstance(n, ast.Name) and n.id != 'int'])

def _unroll_binop(node: ast.BinOp) -> tuple[list[ast.AST], type[ast.operator]] | None:
    """
    Recursively unrolls a nested ast.BinOp structure into a flat list of operands.
    Returns None if operators are inconsistent.
    """
    op_type = type(node.op)
    
    def recurse(n):
        if isinstance(n, ast.BinOp) and type(n.op) == op_type:
            yield from recurse(n.left)
            yield from recurse(n.right)
        else:
            yield n
            
    operands = list(recurse(node))
    return operands, op_type

def _build_binop_tree_from_list(operands: list[ast.AST], op_constructor: type[ast.operator]) -> ast.BinOp:
    """Reconstructs a left-associative ast.BinOp tree from a list of operand nodes."""
    if len(operands) < 2:
        raise ValueError("Cannot build a BinOp tree with fewer than 2 operands.")
    
    tree = ast.BinOp(left=operands[0], op=op_constructor(), right=operands[1])
    for i in range(2, len(operands)):
        tree = ast.BinOp(left=tree, op=op_constructor(), right=operands[i])

    return tree

def _discover_incomplete_calculations(node: ast.Assign) -> List[Dict[str, Any]]:
    """
    Discovers opportunities for an Incomplete Calculation error.
    
    This error models a student forgetting to include a term in a multi-operand
    calculation (e.g., A + B instead of A + B + C). It targets nested
    binary operations with a consistent operator (+ or *).
    """
    if not (isinstance(node.value, ast.BinOp) and hasattr(node.targets[0], 'id')):
        return []
        
    target_variable = node.targets[0].id
    
    # Unroll the AST to see if it's a multi-operand calculation.
    unrolled = _unroll_binop(node.value)
    if not unrolled:
        return []
    
    operands, op_type = unrolled
    
    # We need at least 3 operands to create an incomplete calculation.
    if len(operands) < 3:
        return []

    # Only apply to associative operations where omission is plausible.
    if op_type not in [ast.Add, ast.Mult]:
        return []

    mutations = []
    operand_names = [op.id for op in operands if isinstance(op, ast.Name)]

    # For each operand, generate a mutation where it is the one omitted.
    for i, operand_node in enumerate(operands):
        if not isinstance(operand_node, ast.Name):
            continue # Skip constants for now
            
        mutations.append({
            "type": "incomplete_calculation",
            "target_variable": target_variable,
            "operand_to_remove": operand_node.id,
            "original_operands": tuple(sorted(operand_names))
        })
        
    return mutations

def _discover_operator_swaps(node: ast.Assign) -> List[Dict[str, Any]]:
    """Finds all valid OperatorSwap mutations, capturing the operand names."""
    if not (isinstance(node.value, ast.BinOp) and hasattr(node.targets[0], 'id') and
            isinstance(node.value.left, ast.Name) and isinstance(node.value.right, ast.Name)):
        return []

    target_variable, original_op_node = node.targets[0].id, node.value.op
    swap_map = {
        ast.Add: [ast.Sub, ast.Mult], ast.Sub: [ast.Add, ast.Div],
        ast.Mult: [ast.Div, ast.Add], ast.Div: [ast.Mult, ast.Sub]
    }
    original_op_type = type(original_op_node)
    mutations = []
    if original_op_type in swap_map:
        for new_op_constructor in swap_map[original_op_type]:
            mutations.append({"type": "operator_swap", "target_variable": target_variable, "original_op_type": original_op_type,
                              "new_op": new_op_constructor(), "left_operand": node.value.left.id, "right_operand": node.value.right.id})
    return mutations

# In Cell 5
def _discover_plausible_wrong_references(
    func: Callable, node: ast.Assign, all_input_vars: set, debug: bool = False
) -> List[Dict[str, Any]]:
    """
    Discovers and classifies plausible operand substitution mutations.
    This function is now explicitly prohibited from targeting the 'answer' variable.
    """
    mutations = []
    if not hasattr(node.targets[0], 'id'):
        return []
        
    target_variable = node.targets[0].id

    # --- FIX: Prevent this function from targeting the 'answer' variable. ---
    # Mutations on 'answer' are handled exclusively by the final answer selection discoverer.
    if target_variable == 'answer':
        return []
    # --- END FIX ---
    
    scope = get_scope_at_node(func, node)
    if scope is None:
        return []

    operands = [n.id for n in ast.walk(node.value) if isinstance(n, ast.Name) and n.id != 'int']
    
    for operand_to_replace in set(operands):
        if operand_to_replace not in scope:
            continue
        
        for replacement_candidate in scope:
            if replacement_candidate in (operand_to_replace, 'Fraction'):
                continue
            
            # Ensure we are only swapping numeric types.
            if isinstance(scope.get(operand_to_replace), (int, float, BuiltinFraction)) and \
               isinstance(scope.get(replacement_candidate), (int, float, BuiltinFraction)):
                
                # Classify the error based on whether the original operand was a direct input.
                mutation_type = "input_misrepresentation" if operand_to_replace in all_input_vars else "incorrect_operand"
                
                mutations.append({
                    "type": mutation_type, 
                    "target_variable": target_variable,
                    "operand_to_replace": operand_to_replace, 
                    "replacement_variable": replacement_candidate
                })
                
    # Deduplicate the list of found mutations.
    return [dict(t) for t in {tuple(d.items()) for d in mutations}]

def _discover_incorrect_final_answer_selections(
    func: Callable, 
    node: ast.Assign,
    all_output_vars: set, # <-- New argument
    debug: bool = False
) -> List[Dict[str, Any]]:
    """
    Discovers opportunities to select an incorrect final answer.
    This is now constrained to only select from previously computed output variables.
    """
    if not hasattr(node.targets[0], 'id') or node.targets[0].id != 'answer':
        return []

    # Get the correct final variable being assigned to 'answer'.
    correct_operands = [n.id for n in ast.walk(node.value) if isinstance(n, ast.Name)]
    if not correct_operands: return []
    correct_final_var = correct_operands[0]

    mutations = []
    # --- FIX: Only iterate through plausible INTERMEDIATE variables. ---
    for var_name in all_output_vars:
        # A plausible incorrect answer is a previously calculated output that is not the correct final answer.
        if var_name != correct_final_var:
            mutations.append({
                "type": "incorrect_final_answer_selection",
                "target_variable": "answer",
                "operand_to_replace": correct_final_var,
                "replacement_variable": var_name
            })
            
    return mutations

def discover_mutation_opportunities(
    func: Callable, logical_steps: list, correct_trace: dict, debug_mode: bool = False
) -> list:
    """Analyzes a function's AST to find all possible conceptual error mutations."""
    all_mutations = []
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
        func_def = tree.body[0]
        if not isinstance(func_def, ast.FunctionDef): return []
    except (FileNotFoundError, TypeError, IndexError): return []

    assign_nodes = [node for node in func_def.body if isinstance(node, ast.Assign)]
    all_input_vars = set(v for s in logical_steps for k in ('question_inputs', 'WK_inputs') for v in s.get(k, []))
    
    # --- FIX: Collect all output variable names first. ---
    all_output_vars = set(s['output_variable'] for s in logical_steps if 'output_variable' in s)

    for node in assign_nodes:
        all_mutations.extend(_discover_operator_swaps(node))
        all_mutations.extend(_discover_plausible_wrong_references(func, node, all_input_vars, debug=debug_mode))
        all_mutations.extend(_discover_incomplete_calculations(node))
        # --- FIX: Pass the set of output variables to the discovery function. ---
        all_mutations.extend(_discover_incorrect_final_answer_selections(func, node, all_output_vars, debug=debug_mode))
        
    return [dict(t) for t in {tuple(d.items()) for d in all_mutations}]

print("Mutation Discovery functions defined.")

Mutation Discovery functions defined.


#### **Cell 6: Candidate Generation and Reconstruction**
*   **Functionality:** Orchestrates the full candidate generation for a single template and handles the reconstruction of natural language text.
*   **Logic:**
    *   `generate_candidates_for_template`: The primary function. It calls `discover_mutation_opportunities`, then iterates through each potential mutation. For each one, it calls `apply_mutation`, validates the result with `is_plausible_trace`, checks that the final answer actually changed, and packages all relevant data (`traces`, `codes`, `explanations`) into a candidate dictionary. The `FIX` comment on line number assignment is a correct and important detail.
    *   `_cap_candidates`: Applies heuristic rules to limit the number of generated candidates per template, ensuring variety and preventing an unmanageable number of very similar errors.
    *   `reconstruct_solution_from_trace`: Reconstructs the natural language solution by populating the `solution_line_template`s with values from a given trace. It uses `_get_mutated_template` to attempt to alter the template text to match the mutation.
    *   `_generate_explanation` / `_get_mutated_template`: Helper functions for generating human-readable text for the final candidate package.

In [149]:
class FormattingTrace(dict):
    """A dictionary subclass that formats numbers for clean reconstruction."""
    def __getitem__(self, key):
        val = super().__getitem__(key)
        if isinstance(val, float) and val.is_integer():
            return str(int(val))
        if isinstance(val, float):
            return str(round(val, 2))
        return str(val)

def _get_mutated_template(original_template: str, mutation: dict) -> str:
    """Modifies a template string to reflect a given mutation."""
    mutation_type = mutation['type']

    # --- FIX: Add logic for incomplete_calculation ---
    if mutation_type == 'incomplete_calculation':
        placeholder_to_remove = f"{{{mutation['operand_to_remove']}}}"
        
        # Regex to find the placeholder followed by a multiplication/addition operator.
        # This handles symbols like 'x', '*', '×', or '+' surrounded by optional whitespace.
        # It correctly removes an operand from the middle or beginning of an expression.
        pattern_after = re.compile(re.escape(placeholder_to_remove) + r'\s*[x*×+]\s*')
        mutated_template, num_subs = pattern_after.subn('', original_template, count=1)
        
        if num_subs > 0:
            return mutated_template

        # If no substitution was made, the operand was likely at the end.
        # Try finding an operator before the placeholder.
        pattern_before = re.compile(r'\s*[x*×+]\s*' + re.escape(placeholder_to_remove))
        mutated_template, num_subs = pattern_before.subn('', original_template, count=1)
        
        return mutated_template
    # --- END FIX ---

    elif mutation_type == 'operator_swap':
        # ... (existing logic for operator_swap is correct)
        op_map = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/'}
        original_op_str, new_op_str = op_map.get(mutation['original_op_type']), op_map.get(type(mutation['new_op']))
        left_ph, right_ph = f"{{{mutation['left_operand']}}}", f"{{{mutation['right_operand']}}}"
        if not (original_op_str and new_op_str): return original_template

        patterns = [
            re.compile(f"({re.escape(left_ph)})(.*?)({re.escape(original_op_str)})(.*?)({re.escape(right_ph)})"),
            re.compile(f"({re.escape(right_ph)})(.*?)({re.escape(original_op_str)})(.*?)({re.escape(left_ph)})")
        ]
        
        mutated_template = original_template
        for pattern in patterns:
            mutated_template = pattern.sub(f"\\1\\2{new_op_str}\\4\\5", mutated_template)
        return mutated_template

    elif mutation_type in ['incorrect_operand', 'input_misrepresentation', 'incorrect_final_answer_selection']:
        # This logic is also correct and handles multiple types.
        to_replace_ph, new_name_ph = f"{{{mutation['operand_to_replace']}}}", f"{{{mutation['replacement_variable']}}}"
        return original_template.replace(to_replace_ph, new_name_ph)
        
    return original_template

def _reconstruct_incomplete_calc_template(original_template: str, mutation: dict) -> str:
    """
    Reconstructs the solution line template for an 'incomplete_calculation' error.
    This version correctly modifies both the outer text and the inner calculator annotation.
    """
    placeholder_to_remove = f"{{{mutation['operand_to_remove']}}}"
    
    # Define regex patterns to remove the placeholder and its associated operator.
    pattern_after = re.compile(re.escape(placeholder_to_remove) + r'\s*[x*×+]\s*')
    pattern_before = re.compile(r'\s*[x*×+]\s*' + re.escape(placeholder_to_remove))
    
    mutated_template = original_template
    
    # --- FIX: Modify both the calculator annotation and the surrounding text. ---
    annotation_match = re.search(r'<<(.*?)>>', mutated_template)
    if annotation_match:
        full_annotation_str = annotation_match.group(0)
        inner_expr_str = annotation_match.group(1)
        
        # Modify the inner expression by removing the placeholder.
        modified_inner_expr = pattern_after.sub('', inner_expr_str, count=1)
        if modified_inner_expr == inner_expr_str:
            modified_inner_expr = pattern_before.sub('', modified_inner_expr, count=1)
        
        # Rebuild the full annotation with the modified expression.
        new_annotation_str = f"<<{modified_inner_expr}>>"
        
        # Replace the old annotation in the template with the new one.
        mutated_template = mutated_template.replace(full_annotation_str, new_annotation_str)

    # Now, modify the text outside the annotation using the same logic.
    final_template = pattern_after.sub('', mutated_template, count=1)
    if final_template == mutated_template:
        final_template = pattern_before.sub('', final_template, count=1)

    return final_template

# In Cell 6

def reconstruct_solution_from_trace(
    logical_steps: list[dict], eval_trace: dict[str, Any], candidate_mutation: dict | None = None
) -> dict[str, str] | None:
    """
    Reconstructs NL solution lines, truncates if necessary, and adds the FA line.
    """
    formatted_trace = FormattingTrace(eval_trace)
    reconstructed_mapping = {}
    
    # --- FIX: Correctly identify the cutoff line for truncation. ---
    cutoff_line = None
    if candidate_mutation and candidate_mutation['type'] == 'incorrect_final_answer_selection':
        incorrect_answer_var = candidate_mutation['replacement_variable']
        # The cutoff line is the line where the incorrect intermediate answer was defined.
        cutoff_line_info = next((s for s in logical_steps if s.get('output_variable') == incorrect_answer_var), None)
        if cutoff_line_info:
            cutoff_line = cutoff_line_info['line_number']
    # --- END FIX ---

    for step in logical_steps:
        ln, template = step.get("line_number"), step.get("solution_line_template")
        if not (ln and template): continue
        
        # Apply mutations to the template string before formatting.
        if candidate_mutation and candidate_mutation['type'] != 'incorrect_final_answer_selection':
            if step.get("output_variable") == candidate_mutation['target_variable']:
                mutation_type = candidate_mutation['type']
                if mutation_type == 'incomplete_calculation':
                    template = _reconstruct_incomplete_calc_template(template, candidate_mutation)
                else:

                    template = _get_mutated_template(template, candidate_mutation)
        
        try:
            reconstructed_mapping[ln] = template.format_map(formatted_trace)
        except (KeyError, ValueError): return None

        # --- FIX: Truncate the solution immediately after the cutoff line. ---
        if ln == cutoff_line:
            break
            
    # Add the final answer line to the end of the reconstruction.
    flawed_final_answer = eval_trace.get('answer')
    if flawed_final_answer is not None:
        if isinstance(flawed_final_answer, float) and flawed_final_answer.is_integer():
            formatted_answer = int(flawed_final_answer)
        else:
            # Round floats for cleaner output, e.g., 33.333 -> 33.33
            formatted_answer = round(flawed_final_answer, 2) if isinstance(flawed_final_answer, float) else flawed_final_answer
        reconstructed_mapping['FA'] = f"#### {formatted_answer}"
            
    return reconstructed_mapping

def _generate_explanation(mutation: dict, correct_trace: dict) -> str:
    """Generates a human-readable explanation with values for a given mutation."""
    m_type = mutation['type']
    
    def get_val(var_name):
        return correct_trace.get(var_name, 'N/A')

    if m_type == 'incomplete_calculation':
        operand = mutation['operand_to_remove']
        return (f"Incomplete calculation. The term '{operand}' (value: {get_val(operand)}) "
                f"was omitted from the operation.")
    
    elif m_type == 'operator_swap':
        op_map = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/'}
        original_op_str = op_map.get(mutation['original_op_type'], '?')
        new_op_str = op_map.get(type(mutation['new_op']), '?')
        return f"Incorrect operation. The calculation should use '{original_op_str}' but used '{new_op_str}' instead."
    
    elif m_type == 'incorrect_operand':
        original_var, replacement_var = mutation['operand_to_replace'], mutation['replacement_variable']
        return (f"Incorrect operand. The variable '{replacement_var}' (value: {get_val(replacement_var)}) was used "
                f"instead of the correct variable '{original_var}' (value: {get_val(original_var)}).")

    elif m_type == 'input_misrepresentation':
        original_var, replacement_var = mutation['operand_to_replace'], mutation['replacement_variable']
        return (f"Input misrepresentation. The value for '{replacement_var}' ({get_val(replacement_var)}) was used "
                f"instead of the correct input value for '{original_var}' ({get_val(original_var)}).")

    elif m_type == 'incorrect_final_answer_selection':
        original_var, replacement_var = mutation['operand_to_replace'], mutation['replacement_variable']
        return (f"Incorrect final answer. An intermediate value '{replacement_var}' (value: {get_val(replacement_var)}) was reported "
                f"instead of the correct final answer '{original_var}' (value: {get_val(original_var)}).")
    
    return "An unknown conceptual error occurred."

def _cap_candidates(candidates: List[Dict], repro_seed: int) -> List[Dict]:
    """
    Applies constraints to the list of generated candidates to ensure variety.
    This selection process is made deterministic by the provided seed.
    
    Constraints:
    1. No more than 2 of any individual mutation type per sample.
    2. No more than 2 errors targeting any individual line per sample.
    3. No more than 5 errors in total per sample.
    """
    from collections import Counter
    
    # Apply type and line constraints
    type_counts = Counter()
    line_counts = Counter()
    
    filtered_candidates = []
    for cand in candidates:
        m_type = cand['mutation_details']['type']
        err_line = cand['erroneous_line_number']
        
        type_counts[m_type] += 1
        line_counts[err_line] += 1
        
        if type_counts[m_type] <= 2 and line_counts[err_line] <= 2:
            filtered_candidates.append(cand)
            
    # Apply total constraint, using the seed for reproducible shuffling
    if len(filtered_candidates) > 5:
        import random
        # --- FIX: Use the reproducible seed ---
        random.seed(repro_seed)
        # --- END FIX ---
        random.shuffle(filtered_candidates)
        return filtered_candidates[:5]
        
    return filtered_candidates

def generate_candidates_for_template(
    func: Callable, 
    logical_steps: list, 
    correct_trace: dict, 
    original_solution_mapping: dict, # <-- New argument
    index: int, 
    debug_mode: bool = False
) -> list[dict]:
    """Orchestrates the generation of all possible conceptual error candidates."""
    # ... (code inside this function remains the same until the candidate_package creation) ...
    mutations = discover_mutation_opportunities(func, logical_steps, correct_trace, debug_mode=debug_mode)
    candidates = []
    original_code = inspect.getsource(func)
    repro_seed = hash(f"conceptual_{index}")

    for mutation in mutations:
        # ... (inner loop logic is unchanged) ...
        mutated_code, flawed_trace = apply_mutation(func, mutation)
        if not flawed_trace or not is_plausible_trace(correct_trace, flawed_trace): continue
        correct_answer, flawed_answer = correct_trace.get('answer'), flawed_trace.get('answer')
        norm_correct = str(float(correct_answer)) if isinstance(correct_answer, (int, float)) else str(correct_answer)
        norm_flawed = str(float(flawed_answer)) if isinstance(flawed_answer, (int, float)) else str(flawed_answer)
        if correct_answer is None or flawed_answer is None or norm_correct == norm_flawed: continue
        try:
            flawed_nl_reconstruction = reconstruct_solution_from_trace(logical_steps, flawed_trace, candidate_mutation=mutation)
            if not flawed_nl_reconstruction: continue
        except AttributeError as e:
            if "'str' object has no attribute 'numerator'" in str(e): print(f"\n[DEBUG] Caught AttributeError on index {index}. Skipping mutation: {mutation}"); continue
            else: raise e

        # --- FIX: Update erroneous_line_number logic ---
        if mutation['type'] == 'incorrect_final_answer_selection':
            erroneous_line_number = "FA"
        else:
            erroneous_line_number = next((s['line_number'] for s in logical_steps if s.get('output_variable') == mutation['target_variable']), None)
        # --- END FIX ---
            
        target_var, correct_target_val, flawed_target_val = mutation['target_variable'], correct_trace.get(mutation['target_variable']), flawed_trace.get(mutation['target_variable'])

        candidate_package = {
            "index": index,
            # --- FIX: Add original_solution_mapping to the package ---
            "original_solution_mapping": original_solution_mapping,
            # --- END FIX ---
            "original_function_code": original_code,
            "mutation_details": mutation, "correct_trace": correct_trace, "flawed_trace": flawed_trace,
            "correct_value": correct_target_val, "flawed_value": flawed_target_val,
            "logical_steps": logical_steps, "flawed_nl_reconstruction": flawed_nl_reconstruction,
            "erroneous_line_number": erroneous_line_number,
            "explanation": _generate_explanation(mutation, correct_trace)
        }
        candidates.append(candidate_package)
        
    capped_candidates = _cap_candidates(candidates, repro_seed)
    for cand in capped_candidates: cand['repro_seed'] = repro_seed
    return capped_candidates

#### **Cell 7: Main Driver**
*   **Functionality:** Executes the entire pipeline in parallel for a specified range of problem indices.
*   **Logic:**
    *   `run_candidate_generation_pipeline_parallel`: The entry point. It defines a list of tasks (indices to process) and uses the `joblib` library to run them in parallel across all available CPU cores.
    *   `process_single_index`: The worker function executed by each parallel job. It contains the model fallback logic: it tries to load a template from the preferred model (`gemini-2.5-flash`), and if that fails or produces no candidates, it tries the next model (`gpt-4.1`).
    *   `save_candidates_and_get_metadata`: A utility to save the generated candidate packages to disk as individual JSON files and compile their metadata into a list for the final CSV catalog. The logic for generating a unique filename is a good safeguard.

In [150]:
def process_single_index(index_info: tuple, model_preference: list) -> List[Dict]:
    """
    Worker function to process a single problem index for candidate generation.

    This function attempts to load a valid Formalization Template according to the
    model preference order. For the first successful model, it generates a list of
    conceptual error candidates. It includes checks to skip processing for templates
    that are malformed (e.g., brittle templates with attribute access).

    Args:
        index_info: A tuple containing the (tier, index) of the problem.
        model_preference: A list of model names to try in order.

    Returns:
        A list of candidate dictionaries for the first successful model, or an
        empty list if no valid candidates could be generated for the index.
    """
    tier, index = index_info
    
    # Load the original solution text mapping once per index.
    original_solution_mapping = build_solution_mapping(index)
    if not original_solution_mapping:
        return [] # Cannot proceed without the original solution text.

    # Iterate through models, attempting to find a valid template.
    for model_name in model_preference:
        solve_module = load_function_module(tier, index, model_name)
        logical_steps = load_logical_steps(tier, index, model_name)
        
        if solve_module and logical_steps:
            # Check for brittle templates with invalid placeholders (e.g., {var.attr}).
            is_brittle = False
            for step in logical_steps:
                template = step.get('solution_line_template', '')
                placeholders = re.findall(r'\{([^}]+)\}', template)
                if any('.' in p for p in placeholders):
                    # This template is malformed; log it and try the next model.
                    print(f"[INFO] Skipping brittle template for index {index} (model: {model_name}) due to attribute access in placeholder.")
                    is_brittle = True
                    break
            if is_brittle:
                continue # Try the next model for this index.

            # If the template is valid, proceed with trace and candidate generation.
            solve_function = solve_module.solve
            correct_trace = execution_trace(solve_function)
            if not correct_trace:
                continue # If code is not executable, try the next model.
            
            # Generate candidates using the validated template and original solution text.
            candidates = generate_candidates_for_template(
                func=solve_function,
                logical_steps=logical_steps,
                correct_trace=correct_trace,
                original_solution_mapping=original_solution_mapping,
                index=index
            )
            
            # If candidates were successfully generated, tag them and return.
            if candidates:
                for cand in candidates:
                    cand['model_name'] = model_name
                return candidates # Success, no need to try other models for this index.
            
    # Return an empty list if no model produced valid candidates for this index.
    return []

def save_candidates_and_get_metadata(tier: str, index: int, model_name: str, candidates: List[Dict]) -> List[Dict]:
    """Saves candidate packages to JSON files and returns a list of metadata records."""
    output_dir = CONCEPTUAL_CANDIDATES_DIR / tier / str(index)
    output_dir.mkdir(parents=True, exist_ok=True)
    metadata_records = []
    
    for cand in candidates:
        mutation = cand['mutation_details']
        m_type = mutation['type']
        target_var = mutation['target_variable']
        
        # Use a more descriptive filename including the error type
        filename = f"{model_name}_{m_type}_{target_var}.json"
        filepath = output_dir / filename
        
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(cand, f, indent=2, default=str)
        
        record = {
            "index": index, "tier": tier, "model": model_name,
            "mutation_type": m_type, "target_variable": target_var,
            "mutation_details": json.dumps(mutation, default=str),
            "filepath": str(filepath.relative_to(PROJECT_ROOT))
        }
        metadata_records.append(record)
        
    return metadata_records

from joblib import Parallel, delayed

def run_candidate_generation_pipeline_parallel(indices_to_generate: List[int]):
    """Drives the conceptual error candidate generation pipeline in parallel using joblib."""
    print("--- Starting Conceptual Error Candidate Generation (PARALLEL MODE) ---")
    
    model_preference = ['google_gemini-2.5-flash', 'openai_gpt-4.1']
    tasks = []
    for tier, indices in TIER_LISTS.items():
        for index in indices:
            if index in indices_to_generate:
                tasks.append((tier, index))
    
    print(f"Prepared {len(tasks)} indices for processing across all available CPU cores.")

    results_list = Parallel(n_jobs=-1)(
        delayed(process_single_index)(task, model_preference) 
        for task in tqdm(tasks, desc="Processing Indices")
    )

    print("\n--- Parallel processing complete. Saving artifacts... ---")
    
    all_candidates = [candidate for sublist in results_list if sublist for candidate in sublist]
    all_candidate_metadata = []

    if not all_candidates:
        print("No valid conceptual error candidates were generated.")
        return

    # --- FIX: Import datetime for timestamping ---
    import datetime
    # --- END FIX ---

    for cand in tqdm(all_candidates, desc="Saving Candidates"):
        mutation, index, model_name = cand['mutation_details'], cand['index'], cand['model_name']
        m_type, target_var = mutation['type'], mutation['target_variable']
        tier = next((t for t, ids in TIER_LISTS.items() if index in ids), 'unknown')
        repro_seed = cand.get('repro_seed')
        
        # --- FIX: Get the UTC timestamp ---
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        # --- END FIX ---
        
        # --- FIX: Extract correct and flawed values ---
        correct_value = cand.get('correct_value')
        flawed_value = cand.get('flawed_value')
        # --- END FIX ---

        output_dir = CONCEPTUAL_CANDIDATES_DIR / tier / str(index)
        output_dir.mkdir(parents=True, exist_ok=True)
        
        replacement_var = mutation.get('replacement_variable', '')
        filename = f"{model_name}_{m_type}_{target_var}_{replacement_var}.json"
        filepath = output_dir / filename
        
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(cand, f, indent=2, default=str)
        
        # --- FIX: Add new fields to the metadata record ---
        record = {
            "index": index, 
            "tier": tier, 
            "model": model_name, 
            "mutation_type": m_type, 
            "target_variable": target_var,
            "correct_value": str(correct_value), # Convert to string for CSV
            "flawed_value": str(flawed_value),   # Convert to string for CSV
            "repro_seed": repro_seed,
            "date_utc": now_utc.strftime('%Y-%m-%d'),
            "time_utc": now_utc.strftime('%H:%M:%S'),
            "mutation_details": json.dumps(mutation, default=str),
            "filepath": str(filepath.relative_to(PROJECT_ROOT))
        }
        # --- END FIX ---
        all_candidate_metadata.append(record)

    if all_candidate_metadata:
        catalog_df = pd.DataFrame(all_candidate_metadata)
        # --- FIX: Add new columns to the reindex list ---
        cols = [
            'index', 'tier', 'model', 'mutation_type', 'target_variable', 
            'correct_value', 'flawed_value', 'repro_seed', 
            'date_utc', 'time_utc', 'mutation_details', 'filepath'
        ]
        # --- END FIX ---
        catalog_df = catalog_df.reindex(columns=[c for c in cols if c in catalog_df.columns])
        catalog_path = CONCEPTUAL_CANDIDATES_DIR / "conceptual_candidate_catalog.csv"
        catalog_df.to_csv(catalog_path, index=False)
        print(f"\n--- Pipeline Complete ---")
        print(f"Generated and saved {len(all_candidate_metadata)} conceptual error candidates.")
        print(f"Catalog saved to: {catalog_path}")

## Cell 8: Execute Pipeline

This final cell executes the main driver function to run the entire test pipeline.

In [151]:
run_candidate_generation_pipeline_parallel(indices_to_generate=list(range(3000)))

--- Starting Conceptual Error Candidate Generation (PARALLEL MODE) ---
Prepared 3000 indices for processing across all available CPU cores.


Processing Indices:   0%|          | 0/3000 [00:00<?, ?it/s]

[INFO] Skipping brittle template for index 1484 (model: google_gemini-2.5-flash) due to attribute access in placeholder.

--- Parallel processing complete. Saving artifacts... ---


Saving Candidates:   0%|          | 0/13521 [00:00<?, ?it/s]


--- Pipeline Complete ---
Generated and saved 13521 conceptual error candidates.
Catalog saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/conceptual-error-candidates/conceptual_candidate_catalog.csv
