#### Cell 1: Paths, definitions, dataset

In [43]:
import json
from pathlib import Path
import ast
import inspect
from datasets import load_dataset

GSM8K_TRAIN = load_dataset("gsm8k", "main")["train"]

# 1. Define Project Root and Data Directories
def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by `marker`."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'

# 2. Define path to validated manifests 
VALIDATED_MANIFEST_DIR = DATA_DIR / 'tier-manifests-examples-json' / 'tier2'
# Base directory for all processed outputs
PROCESSED_DATA_DIR = DATA_DIR / 'tier-manifests-examples-processed'

print(f"Project root: {PROJECT_ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Validated Tier 2 manifest directory: {VALIDATED_MANIFEST_DIR}")

# 3. Check if the directory exists to prevent downstream errors
if not VALIDATED_MANIFEST_DIR.is_dir():
    print(f"\nWARNING: Directory not found: {VALIDATED_MANIFEST_DIR}")

# Specify the index to process
example_index = 4822
example_tier = "tier2"

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Data directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data
Validated Tier 2 manifest directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-json/tier2


#### Cell 2: Data loading and processing

In [44]:
import importlib.util
from types import ModuleType

def load_manifest_as_string(manifest_dir: Path, index: int) -> str | None:
    """
    Loads a specific manifest file as a string.
    Note: Assumes filename is '{index}.json'. Adjust if format differs.
    """
    file_path = manifest_dir / f"_{index}.json"
    try:
        return file_path.read_text()
    except FileNotFoundError:
        print(f"ERROR: Manifest file not found at {file_path}")
        return None


def save_manifest_components(
    manifest_string: str,
    tier: str,
    index: int,
    base_output_dir: Path
) -> None:
    """
    Parses a manifest string and saves its components to structured directories.
    """
    try:
        manifest_data = json.loads(manifest_string)
    except json.JSONDecodeError:
        print(f"ERROR: Invalid JSON in manifest string for index {index}.")
        return

    function_code = manifest_data.get("function_code")
    logical_steps = manifest_data.get("logical_steps")

    if not function_code or not logical_steps:
        print(f"WARNING: Manifest for index {index} is missing required keys.")
        return

    # 1. Define output directory for the specific index
    output_dir = base_output_dir / tier / str(index)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 2. Save function_code to a .py file
    py_file_path = output_dir / "solve.py"
    py_file_path.write_text(function_code)

    # 3. Save logical_steps to a .json file
    json_file_path = output_dir / "logical_steps.json"
    with open(json_file_path, 'w') as f:
        json.dump(logical_steps, f, indent=2)


def load_function_module(
    tier: str,
    index: int,
    base_dir: Path = PROCESSED_DATA_DIR
) -> ModuleType | None:
    """
    Dynamically loads the 'solve.py' function for a given tier and index.

    Args:
        tier: The tier of the problem (e.g., 'tier1').
        index: The index of the problem.
        base_dir: The base directory for processed files.

    Returns:
        The loaded module object, or None if the file is not found.
    """
    py_file_path = base_dir / tier / str(index) / "solve.py"
    if not py_file_path.exists():
        print(f"ERROR: Python file not found at {py_file_path}")
        return None

    # Create a unique name for the module to avoid conflicts
    module_name = f"manifests.t{tier}.i{index}.solve"

    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    if spec and spec.loader:
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    else:
        print(f"ERROR: Could not create module spec for {py_file_path}")
        return None


def load_logical_steps(
    tier: str,
    index: int,
    base_dir: Path = PROCESSED_DATA_DIR
) -> list[dict] | None:
    """
    Loads the 'logical_steps.json' for a given tier and index.

    Args:
        tier: The tier of the problem (e.g., 'tier1').
        index: The index of the problem.
        base_dir: The base directory for processed files.

    Returns:
        A list of dictionaries representing the logical steps, or None on error.
    """
    json_file_path = base_dir / tier / str(index) / "logical_steps.json"
    try:
        with open(json_file_path, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"ERROR: JSON file not found at {json_file_path}")
        return None
    except json.JSONDecodeError:
        print(f"ERROR: Failed to decode JSON from {json_file_path}")
        return None

for index in [310, 2300, 2401, 2918, 4215, 4822]:
    # Load
    manifest_str = load_manifest_as_string(VALIDATED_MANIFEST_DIR, index)

    # Process and Save
    if manifest_str:
        save_manifest_components(
            manifest_string=manifest_str,
            tier=example_tier,
            index=index,
            base_output_dir=PROCESSED_DATA_DIR
        )
        print(f"\nProcessed files for index {index} saved under: {PROCESSED_DATA_DIR}")


Processed files for index 310 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed

Processed files for index 2300 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed

Processed files for index 2401 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed

Processed files for index 2918 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed

Processed files for index 4215 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed

Processed files for index 4822 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/tier-manifests-examples-processed


#### Cell 3: Core trace and solution generation

In [45]:
from typing import Any, Dict
import re
import copy

def execution_trace(func) -> Dict[str, Any] | None:
    """
    Parses a function's source code and executes it line by line
    to capture the final value of each assigned variable.

    This function is designed for simple, self-contained functions
    like the 'solve()' methods in the manifests.

    Args:
        func: The function object to trace.

    Returns:
        A dictionary mapping variable names to their computed values, or None on error.
    """
    try:
        # 1. Get the source code and parse it into an Abstract Syntax Tree (AST)
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError) as e:
        print(f"ERROR: Could not get or parse source for {func.__name__}: {e}")
        return None

    # 2. Assume the first node in the file is the function definition
    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        print("ERROR: The parsed source code does not start with a function definition.")
        return None

    # 3. `env` will store the execution environment (variable states)
    env = {}

    # 4. Execute each statement in the function body sequentially
    for stmt in func_def.body:
        # We only process simple assignments (e.g., x = y + z)
        if isinstance(stmt, ast.Assign):
            # Compile and execute the single assignment statement.
            # The result is stored in the `env` dictionary, which acts
            # as the memory for subsequent computations.
            try:
                # ast.Module is required to create a valid code block
                code_obj = compile(ast.Module([stmt], type_ignores=[]), '<string>', 'exec')
                exec(code_obj, {}, env)
            except Exception as e:
                # This helps debug issues in the manifest's function_code
                print(f"ERROR: Failed to execute line: {ast.unparse(stmt)}. Error: {e}")
                return None # Halt on error to ensure trace integrity
        
        # We ignore other node types like `ast.Return`

    return env


def generate_flawed_trace(
    func,
    error_details: dict[str, Any]
) -> dict[str, Any] | None:
    """
    Generates a new execution trace by injecting a computational error.

    It modifies the function's AST in memory to replace a calculation
    with a hardcoded flawed value, then re-executes the logic to
    propagate the error.

    Args:
        func: The original function object to modify.
        error_details: Dict containing 'variable', 'correct_value', 'flawed_value'.

    Returns:
        A new dictionary representing the flawed execution trace, or None on error.
    """
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError) as e:
        print(f"ERROR: Could not get or parse source for {func.__name__}: {e}")
        return None

    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        print("ERROR: Parsed source does not start with a function definition.")
        return None

    variable_to_change = error_details["variable"]
    flawed_value = error_details["flawed_value"]
    
    modified_body = copy.deepcopy(func_def.body)

    node_found_and_modified = False
    for i, node in enumerate(modified_body):
        if isinstance(node, ast.Assign):
            if node.targets[0].id == variable_to_change:
                # --- FIX IS HERE ---
                # Create a new constant node for the flawed value and copy the
                # location info (lineno, col_offset) from the original node.
                new_value_node = ast.copy_location(
                    ast.Constant(value=flawed_value),
                    node.value  # The original RHS node we are replacing
                )
                node.value = new_value_node
                node_found_and_modified = True
                break
    
    if not node_found_and_modified:
        print(f"ERROR: Could not find assignment node for '{variable_to_change}' to modify.")
        return None

    env = {}
    for stmt in modified_body:
        if isinstance(stmt, ast.Assign):
            try:
                code_obj = compile(ast.Module([stmt], type_ignores=[]), '<string>', 'exec')
                exec(code_obj, {}, env)
            except Exception as e:
                print(f"ERROR: Failed to execute modified line: {ast.unparse(stmt)}. Error: {e}")
                return None
    
    return env


def build_solution_mapping(
        index: int, 
        dataset: 'datasets.Dataset' = GSM8K_TRAIN,
        exclude_FA: bool = True
    ):
    """
    Extracts the natural language solution for a given problem index,
    cleans it, and structures it into a line-numbered dictionary.
    """
    solution_mapping = {}
    solution_text = dataset[index]["answer"]
    lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

    # Improved regex to handle commas in the final answer
    if lines and re.match(r"^####\s*[\d\.,]+$", lines[-1]):
        solution_mapping["FA"] = lines.pop(-1).strip()

    for i, line in enumerate(lines, 1):
        solution_mapping[f"L{i}"] = line

    if exclude_FA and "FA" in solution_mapping:
        del solution_mapping["FA"]

    return solution_mapping


def reconstruct_solution_lines_enhanced(
    logical_steps: list[dict],
    eval_trace: dict[str, Any]
) -> dict[str, str]:
    """
    Substitutes variable placeholders in solution templates with their values,
    using normalization for clean formatting.
    """
    reconstructed_mapping = {}
    placeholder_pattern = re.compile(r'\{([a-zA-Z0-9_]+)\}')

    for step in logical_steps:
        line_number = step.get("line_number")
        template = step.get("solution_line_template")

        if not line_number or not template:
            continue

        def replacer(match):
            variable_name = match.group(1)
            value = eval_trace.get(variable_name)

            if value is None:
                return f"{{ERROR: {variable_name} not in trace}}"
            
            # Use the normalization utility for correct string rendering
            value_to_render = normalize_value(value)
            
            return str(value_to_render)

        reconstructed_line = placeholder_pattern.sub(replacer, template)
        reconstructed_mapping[line_number] = reconstructed_line

    return reconstructed_mapping



#### Cell 4: Individual Error Generation functions

In [46]:
import random


def generate_off_by_n_error(
    correct_value: int,
    offset_range: tuple[int, int] = (1, 5)
    ):
    """Generates a minor miscalculation error. Assumes input is an integer."""
    offset = random.randint(offset_range[0], offset_range[1])
    if random.random() < 0.5:
        offset = -offset

    flawed_value = correct_value + offset
    if flawed_value == correct_value: # Avoid generating a non-error
        flawed_value += 1

    return {
        "flawed_value": flawed_value,
        "explanation_type": "This appears to be a minor miscalculation."
    }


def generate_off_by_factor_of_10_error(correct_value: int):
    """Generates a dropped/added zero error. Assumes input is a multiple of 100."""
    options = ["divide", "multiply"]
    choice = random.choice(options)
    
    if choice == "divide":
        flawed_value = correct_value // 10
        explanation = "It appears a zero was dropped from the number."
    else: # multiply
        flawed_value = correct_value * 10
        explanation = "It appears an extra zero was added to the number."

    return {
        "flawed_value": flawed_value,
        "explanation_type": explanation
    }


def generate_digit_transposition_error(correct_value: int) -> dict[str, Any] | None:
    """
    Swaps two adjacent digits. Now includes a check to prevent creating
    a leading zero, which would alter the number's magnitude.
    """
    # This check is important as this function should only ever receive integers
    if not isinstance(correct_value, int):
        return None

    s_val = str(abs(correct_value))
    
    # Pre-condition: must have at least 2 digits to swap.
    if len(s_val) < 2:
        return None

    # Find all indices where adjacent digits are different
    possible_indices = [i for i in range(len(s_val) - 1) if s_val[i] != s_val[i+1]]
    
    # --- NEW VALIDATION LOGIC ---
    # Filter out swaps that would create a leading zero.
    # A swap at index i is invalid if i=0 and the digit at i+1 is '0'.
    valid_indices = [
        i for i in possible_indices
        if not (i == 0 and s_val[i+1] == '0')
    ]
    
    # If no valid swaps are possible, we cannot generate this error.
    if not valid_indices:
        return None
    # --- END NEW LOGIC ---

    idx_to_swap = random.choice(valid_indices)
    
    s_list = list(s_val)
    s_list[idx_to_swap], s_list[idx_to_swap+1] = s_list[idx_to_swap+1], s_list[idx_to_swap]
    
    flawed_value = int("".join(s_list))
    if correct_value < 0:
        flawed_value = -flawed_value

    return {
        "flawed_value": flawed_value,
        "explanation_type": "It appears two adjacent digits were swapped."
    }


def generate_stem_off_by_n_error(
    correct_value: int,
    offset_range: tuple[int, int] = (1, 3)
) -> dict[str, Any]:
    """
    Applies a small offset to the 'stem' of a number (the part before the final zero).
    Assumes the input is a non-zero multiple of 10.
    """
    stem = correct_value // 10
    
    offset = random.randint(offset_range[0], offset_range[1])
    if random.random() < 0.5:
        offset = -offset

    flawed_stem = stem + offset
    if flawed_stem == stem:
        flawed_stem += 1 # Ensure the value changes

    flawed_value = flawed_stem * 10

    return {
        "flawed_value": flawed_value,
        "explanation_type": "It appears there was a miscalculation with the digits before the final zero."
    }


def generate_decimal_shift_error(correct_value: float) -> dict[str, Any]:
    """Multiplies or divides a float by 10 to simulate a decimal shift."""
    choice = random.choice(["multiply", "divide"])
    flawed_value = correct_value * 10 if choice == "multiply" else correct_value / 10
    return {
        "flawed_value": round(flawed_value, 10), # Round to avoid precision issues
        "explanation_type": "It appears the decimal point was misplaced."
    }


def generate_monetary_off_by_cents_error(correct_value: float) -> dict[str, Any]:
    """Adds or subtracts a small, realistic monetary amount."""
    offsets = [0.10, 0.25, 0.50, 1.00, 2.00]
    offset = random.choice(offsets)
    if random.random() < 0.5:
        offset = -offset
    
    flawed_value = correct_value + offset
    return {
        "flawed_value": round(flawed_value, 2), # Ensure result is a valid monetary value
        "explanation_type": "This appears to be a minor miscalculation in cents or dollars."
    }


def generate_float_off_by_n_error(correct_value: float) -> dict[str, Any]:
    """Applies a small offset to a general float."""
    # This creates an offset that is roughly 10-20% of the original value's magnitude
    magnitude = abs(correct_value)
    offset = random.uniform(magnitude * 0.1, magnitude * 0.2)
    if random.random() < 0.5:
        offset = -offset
        
    flawed_value = correct_value + offset
    return {
        "flawed_value": round(flawed_value, 10),
        "explanation_type": "This appears to be a minor miscalculation."
    }



#### Cell 5: Error Generation Controller and Helpers

In [47]:
import ast
import inspect

def get_target_variables(logical_steps: list[dict]) -> list[str]:
    """
    Extracts all 'output_variable' names from the logical steps.
    These are the potential locations for injecting a computational error.
    """
    return [step['output_variable'] for step in logical_steps if 'output_variable' in step]


def has_distinct_adjacent_digits(n: int) -> bool:
    """Helper to check if a number is suitable for digit transposition."""
    s = str(abs(n))
    return len(s) >= 2 and any(s[i] != s[i+1] for i in range(len(s) - 1))


def get_operator_for_variable(func, variable_name: str) -> str | None:
    """Inspects the AST to find the operator used to compute a variable."""
    # (implementation from previous response is unchanged)
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError): return None
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign) and node.targets[0].id == variable_name:
            if isinstance(node.value, ast.BinOp):
                op = node.value.op
                if isinstance(op, ast.Add): return "add"
                if isinstance(op, ast.Sub): return "sub"
                if isinstance(op, ast.Mult): return "mult"
                if isinstance(op, ast.Div): return "div"
            return "other"
    return None


def get_operand_names_for_variable(func, variable_name: str) -> list[str]:
    """
    Inspects the AST to find the names of variables used as operands
    in the calculation for the given target variable.
    """
    operand_names = []
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError):
        return []

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign) and node.targets[0].id == variable_name:
            # We walk the right-hand side of the assignment
            for sub_node in ast.walk(node.value):
                # And collect the names of all variables used
                if isinstance(sub_node, ast.Name):
                    operand_names.append(sub_node.id)
            return list(set(operand_names)) # Use set to get unique names
    return []


# This is a helper function to avoid repeating the integer logic.
# The leading underscore indicates it's for internal use by the main controller.
def _get_applicable_integer_errors(correct_value: int, operator: str, operand_values: list) -> list:
    """Returns a list of applicable error generator functions for an integer."""
    applicable_generators = []
    
    # Rule: Handle addition and subtraction with special cases
    if operator in ["add", "sub"]:
        all_end_in_zero = all(isinstance(v, int) and v % 10 == 0 for v in operand_values) if operand_values else False
        if all_end_in_zero and correct_value % 10 == 0 and correct_value != 0:
            applicable_generators.append(generate_stem_off_by_n_error)
        else:
            applicable_generators.append(generate_off_by_n_error)

    # Rule: Factor-of-10 error for multiples of 100.
    if correct_value % 100 == 0 and correct_value != 0:
        applicable_generators.append(generate_off_by_factor_of_10_error)
    
    # Rule: Digit transposition.
    if has_distinct_adjacent_digits(correct_value):
        applicable_generators.append(generate_digit_transposition_error)
        
    return applicable_generators


def generate_computational_error(
    func,
    correct_trace: dict[str, Any],
    target_variables: list[str],
    debug: bool = False
) -> dict[str, Any] | None:
    """
    Applies a valid error type based on a rule-based system for ints and floats.
    Includes the fix for integer-like float handling.
    """
    shuffled_targets = random.sample(target_variables, len(target_variables))

    if debug: print("\n--- DEBUG: Analyzing potential error targets ---")

    for variable_name in shuffled_targets:
        correct_value = correct_trace.get(variable_name)
        operator = get_operator_for_variable(func, variable_name)

        if debug: print(f"\n- Var: '{variable_name}' (Value: {correct_value}, Type: {type(correct_value).__name__}, Op: '{operator}')")

        # --- Integer Branch ---
        if isinstance(correct_value, int):
            operand_names = get_operand_names_for_variable(func, variable_name)
            operand_values = [correct_trace.get(name) for name in operand_names]
            applicable_generators = _get_applicable_integer_errors(correct_value, operator, operand_values)
            
            if debug: print_applicable_generators(applicable_generators)
                
            if applicable_generators:
                generator_func = random.choice(applicable_generators)
                error_result = generator_func(correct_value)
                return { "variable": variable_name, "correct_value": correct_value, **error_result }

        # --- Float Branch ---
        elif isinstance(correct_value, float):
            applicable_generators = []
            
            # Rule #3: Integer-like floats (e.g., 2000.0)
            if correct_value.is_integer():
                int_val = int(correct_value)
                operand_names = get_operand_names_for_variable(func, variable_name)
                operand_values = [correct_trace.get(name) for name in operand_names]
                
                # Get the list of applicable INTEGER generators
                integer_error_gens = _get_applicable_integer_errors(int_val, operator, operand_values)

                if debug: print_applicable_generators(integer_error_gens)

                if integer_error_gens:
                    # --- THIS IS THE FIX ---
                    # 1. Choose an integer-based generator.
                    generator_func = random.choice(integer_error_gens)
                    # 2. Call it with the INTEGER value.
                    error_result = generator_func(int_val)
                    # 3. Convert the resulting flawed integer BACK to a float.
                    error_result['flawed_value'] = float(error_result['flawed_value'])
                    # 4. Return the complete error details.
                    return { "variable": variable_name, "correct_value": correct_value, **error_result }

            # Rules for "true" floats
            else:
                # Rule #2: Monetary floats ending in .00
                if str(correct_value).endswith('.00'):
                    applicable_generators.append(generate_monetary_off_by_cents_error)
                # Rule #4: General floats
                else:
                    applicable_generators.append(generate_float_off_by_n_error)
                
                # Decimal shift is applicable to most true floats
                applicable_generators.append(generate_decimal_shift_error)
                
                if debug: print_applicable_generators(applicable_generators)

                if applicable_generators:
                    generator_func = random.choice(applicable_generators)
                    error_result = generator_func(correct_value)
                    return { "variable": variable_name, "correct_value": correct_value, **error_result }

    if debug: print("--- END DEBUG ---")
    else: print("WARNING: Could not find a suitable variable/error type combination.")
    return None


# --- IN CELL 5 ---
def get_applicable_generators(
    func,
    correct_trace: dict[str, Any],
    variable_name: str
) -> list:
    """
    Final refined logic to provide diverse but non-redundant error types.
    """
    correct_value = correct_trace.get(variable_name)
    if not isinstance(correct_value, (int, float)): return []

    operator = get_operator_for_variable(func, variable_name)
    operand_names = get_operand_names_for_variable(func, variable_name)
    operand_values = [correct_trace.get(name) for name in operand_names]
    
    # --- Integer and Integer-like Float Logic ---
    if isinstance(correct_value, int) or (isinstance(correct_value, float) and correct_value.is_integer()):
        int_val = int(correct_value)
        return _get_applicable_integer_errors(int_val, operator, operand_values)
    
    # --- "True" Float Logic ---
    elif isinstance(correct_value, float):
        applicable_generators = []
        applicable_generators.append(generate_float_off_by_n_error)
        
        # Decimal shift is only for true floats.
        if correct_value != 0:
            applicable_generators.append(generate_decimal_shift_error)
            
        return list(set(applicable_generators))
        
    return []


# Helper for cleaner debug output
def print_applicable_generators(generators: list):
    gen_names = [g.__name__ for g in generators]
    print(f"  Applicable Generators: {gen_names if gen_names else 'NONE'}")


# def generate_computational_error(
#     func,
#     correct_trace: dict[str, Any],
#     target_variables: list[str],
#     debug: bool = False
# ) -> dict[str, Any] | None:
#     """
#     Selects a random variable and applies a valid, realistic error type.
#     If debug=True, it prints its decision-making process for each variable.
#     """
#     shuffled_targets = random.sample(target_variables, len(target_variables))

#     if debug:
#         print("\n--- DEBUG: Analyzing potential error targets ---")

#     for variable_name in shuffled_targets:
#         correct_value = correct_trace.get(variable_name)

#         # Start with a clean list for each variable
#         applicable_generators = []
        
#         # We only work with numeric types for computational errors
#         if not isinstance(correct_value, (int, float)):
#             if debug:
#                 print(f"\n- Var: '{variable_name}' -> SKIPPED (Value: '{correct_value}', Type: {type(correct_value).__name__})")
#             continue

#         operator = get_operator_for_variable(func, variable_name)

#         # --- Applicability Logic ---
#         if isinstance(correct_value, int):
#             # Rule 1: Handle addition and subtraction with special cases
#             if operator in ["add", "sub"]:
#                 operand_names = get_operand_names_for_variable(func, variable_name)
#                 operand_values = [correct_trace.get(name) for name in operand_names]
#                 all_end_in_zero = all(isinstance(v, int) and v % 10 == 0 for v in operand_values) if operand_values else False

#                 if all_end_in_zero and correct_value % 10 == 0 and correct_value != 0:
#                     applicable_generators.append(generate_stem_off_by_n_error)
#                 else:
#                     applicable_generators.append(generate_off_by_n_error)

#             # Rule 2: Factor-of-10 error for multiples of 100.
#             if correct_value % 100 == 0 and correct_value != 0:
#                 applicable_generators.append(generate_off_by_factor_of_10_error)
            
#             # Rule 3: Digit transposition.
#             if has_distinct_adjacent_digits(correct_value):
#                 applicable_generators.append(generate_digit_transposition_error)
        
#         # Add float-specific rules here in the future
#         # elif isinstance(correct_value, float):
#         #     pass

#         if debug:
#             print(f"\n- Var: '{variable_name}' (Value: {correct_value}, Op: '{operator}')")
#             gen_names = [g.__name__ for g in applicable_generators]
#             print(f"  Applicable Generators: {gen_names if gen_names else 'NONE'}")

#         if applicable_generators:
#             generator_func = random.choice(applicable_generators)
#             error_result = generator_func(correct_value)
            
#             return {
#                 "variable": variable_name,
#                 "correct_value": correct_value,
#                 "flawed_value": error_result["flawed_value"],
#                 "explanation_type": error_result["explanation_type"]
#             }
            
#     if debug:
#         print("--- END DEBUG ---")
#     else:
#         # This warning only shows when not in debug mode
#         print("WARNING: Could not find a suitable variable/error type combination.")
#     return None


#### Cell 6: Validation Utilities

In [48]:

def get_sign(n) -> int:
    """Returns the sign of a number (1 for positive, -1 for negative, 0 for zero)."""
    if n > 0: return 1
    if n < 0: return -1
    return 0


def normalize_value(value):
    """
    Normalizes a numeric value for consistent comparison and rendering.
    If a float can be represented as an integer (e.g., 20.0), it is cast to int.
    """
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return value


def is_trace_valid(
    flawed_trace: dict[str, Any],
    correct_trace: dict[str, Any]
) -> bool:
    """
    Validates a flawed trace with a strict but correct type-checking rule:
    An integer-like value must remain integer-like.
    """
    for var_name, correct_val in correct_trace.items():
        if var_name not in flawed_trace:
            continue

        flawed_val = flawed_trace.get(var_name)
        
        # --- NEW, CORRECTED TYPE CHECK ---
        # Normalize both values to see if they represent integers.
        is_correct_int_like = isinstance(normalize_value(correct_val), int)
        is_flawed_int_like = isinstance(normalize_value(flawed_val), int)

        # Rule: If the correct value was integer-like, the flawed one must also be.
        # This prevents errors like "15 students" becoming "15.5 students".
        if is_correct_int_like and not is_flawed_int_like:
            print(f"VALIDATION FAIL (Type): Var '{var_name}' changed from an int-like value to a true float ({correct_val} -> {flawed_val}).")
            return False
        # --- END NEW TYPE CHECK ---

        # Sign Check remains the same
        processed_correct = normalize_value(correct_val)
        processed_flawed = normalize_value(flawed_val)
        if isinstance(processed_correct, (int, float)):
            if get_sign(processed_correct) != get_sign(processed_flawed):
                print(f"VALIDATION FAIL (Sign): Var '{var_name}' changed sign from {correct_val} to {flawed_val}.")
                return False
                
    return True


# # --- Example of how to use and verify the new function ---

# # 1. Load your module, logical steps, and generate the correct trace first
# solve_module = load_function_module(example_tier, example_index)
# logical_steps = load_logical_steps(example_tier, example_index)
# correct_trace = None
# if solve_module:
#     solve_function = solve_module.solve
#     correct_trace = execution_trace(solve_function)

# # 2. Run the enhanced reconstruction
# if logical_steps and correct_trace:
#     print("Reconstructing solution lines with enhanced formatting...")
#     reconstructed_solution = reconstruct_solution_lines_enhanced(logical_steps, correct_trace)

#     print("\n--- RECONSTRUCTED SOLUTION (ENHANCED) ---")
#     print(json.dumps(reconstructed_solution, indent=2))

#     # 3. Build and print the original solution for direct comparison
#     print("\n--- ORIGINAL SOLUTION (FOR COMPARISON) ---")
#     original_solution = build_solution_mapping(index=example_index)
#     print(json.dumps(original_solution, indent=2))

#### Cell 7: Final Artifact assembly

In [49]:

def find_error_line_number(
    variable_name: str,
    logical_steps: list[dict]
) -> str | None:
    """Finds the line number corresponding to a given output variable."""
    for step in logical_steps:
        if step.get("output_variable") == variable_name:
            return step.get("line_number")
    return None


def generate_training_artifacts_enhanced(
    logical_steps: list[dict],
    error_details: dict[str, Any],
    flawed_trace: dict[str, Any]
    ):
    """
    Generates the final training data with an enhanced explanation.
    """
    flawed_solution_map = reconstruct_solution_lines_enhanced(logical_steps, flawed_trace)
    if not flawed_solution_map:
        return None
    
    sorted_lines = sorted(flawed_solution_map.items(), key=lambda item: int(item[0][1:]))
    flawed_nl_solution = "\n".join([line for _, line in sorted_lines])

    erroneous_line = find_error_line_number(error_details["variable"], logical_steps)
    if not erroneous_line:
        return None

    # --- Create the enhanced explanation ---
    base_explanation = (
        f"The result of this computation should be {error_details['correct_value']}, "
        f"not {error_details['flawed_value']}."
    )
    type_explanation = error_details["explanation_type"]
    final_explanation = f"{base_explanation} {type_explanation}"

    target_json = {
        "verdict": "Flawed",
        "error_details": {
            "error_type": "computational_error",
            "erroneous_line_number": erroneous_line,
            "explanation": final_explanation,
        }
    }

    return flawed_nl_solution, target_json



#### Cell 8: Main Orchestrator

In [50]:

def create_single_error_example(
    func,
    logical_steps,
    correct_trace,
    max_attempts=10
):
    """
    Orchestrates the entire error generation process with validation and retries.
    """
    target_variables = get_target_variables(logical_steps)

    for i in range(max_attempts):
        print(f"Attempt {i+1}/{max_attempts}...")
        
        error_details = generate_computational_error(func, correct_trace, target_variables)
        if not error_details:
            print(" -> Failed to generate an error candidate. Retrying.")
            continue

        flawed_trace = generate_flawed_trace(func, error_details)
        if not flawed_trace:
            print(" -> Failed to generate a flawed trace. Retrying.")
            continue

        # The orchestrator calls the new, more robust validator
        if is_trace_valid(flawed_trace, correct_trace):
            flawed_nl_solution, target_json = generate_training_artifacts_enhanced(
                logical_steps, error_details, flawed_trace
            )
            print(f" -> Success! Generated a valid flawed example.")
            return flawed_nl_solution, target_json
        else:
            print(" -> Generated trace was invalid. Retrying.")

    print(f"\nFailed to generate a valid error example after {max_attempts} attempts.")
    return None, None



In [51]:
print(build_solution_mapping(index=4822, dataset=GSM8K_TRAIN))

{'L1': 'He gets 500*.8=$<<500*.8=400>>400 off the cost of lenses', 'L2': 'That means the lenses cost 500-400=$<<500-400=100>>100', 'L3': 'The frames cost 200-50=$<<200-50=150>>150', 'L4': 'So he pays 100+150=$<<100+150=250>>250'}


#### Cell 9: Testing

In [52]:
import json
import random
import inspect

def run_comprehensive_error_test(tier: str, index: int):
    """
    Runs an exhaustive and deterministic test of all possible valid computational
    errors for every variable in a given problem.
    """
    print("="*60)
    print(f"STARTING COMPREHENSIVE TEST FOR: Tier '{tier}', Index '{index}'")
    print("="*60)

    # --- 1. Load Data ---
    solve_module = load_function_module(tier, index)
    logical_steps = load_logical_steps(tier, index)
    if not (solve_module and logical_steps):
        print("ERROR: Failed to load data. Aborting.")
        return
    solve_function = solve_module.solve
    
    # --- 2. Generate Correct Trace ---
    correct_trace = execution_trace(solve_function)
    if not correct_trace:
        print("ERROR: Failed to generate correct trace. Aborting.")
        return

    # --- 3. Iterate Through All Target Variables ---
    target_variables = get_target_variables(logical_steps)
    print(f"\nFound {len(target_variables)} target variables to test: {target_variables}")

    for variable_name in target_variables:
        # --- THIS IS THE FIX ---
        correct_value = correct_trace.get(variable_name)
        # --- END FIX ---
        
        print("\n" + "~"*60)
        print(f"--- TESTING VARIABLE: '{variable_name}' (Correct Value: {correct_value}) ---")
        
        seed = hash(f"{index}-{variable_name}")
        random.seed(seed)
        
        applicable_generators = get_applicable_generators(solve_function, correct_trace, variable_name)
        
        if not applicable_generators:
            print("  -> No applicable error types found for this variable.")
            continue
        
        print(f"  -> Found {len(applicable_generators)} applicable error types: {[g.__name__ for g in applicable_generators]}")

        # --- 4. Test Every Applicable Error Type ---
        for generator_func in applicable_generators:
            print(f"\n  --- Applying Error: {generator_func.__name__} ---")
            
            random.seed(seed + hash(generator_func.__name__))

            # The value passed to the generator depends on whether it's an integer-like float
            value_for_generator = int(correct_value) if isinstance(correct_value, float) and correct_value.is_integer() else correct_value
            error_result = generator_func(value_for_generator)
            
            if not error_result:
                print("    -> Generator returned None (e.g., no valid transposition). Skipping.")
                continue

            error_details = {"variable": variable_name, "correct_value": correct_value, **error_result}
            if isinstance(correct_value, float) and correct_value.is_integer():
                error_details['flawed_value'] = float(error_details['flawed_value'])

            flawed_trace = generate_flawed_trace(solve_function, error_details)
            if not is_trace_valid(flawed_trace, correct_trace):
                print("    -> FAILED: Generated an invalid trace (sign/type change). Skipping.")
                continue

            flawed_solution, target_label = generate_training_artifacts_enhanced(logical_steps, error_details, flawed_trace)
            
            print(f"\n    --- Flawed Solution (using {generator_func.__name__}) ---")
            print(flawed_solution)
            print("\n    --- Target JSON Label ---")
            print(json.dumps(target_label, indent=2))

    print("\n" + "="*60)
    print("COMPREHENSIVE TEST COMPLETE")
    print("="*60)


# --- Run the new comprehensive test ---
for index in [310, 2300, 2401, 2918, 4822]:
    try:
        run_comprehensive_error_test(tier=example_tier, index=index)
    except Exception as e:
        import traceback
        print(f"\nAN UNCAUGHT ERROR OCCURRED: {e}")
        traceback.print_exc()

STARTING COMPREHENSIVE TEST FOR: Tier 'tier2', Index '310'

Found 8 target variables to test: ['hours_per_worker_per_month', 'wage_per_warehouse_worker', 'total_warehouse_wages', 'wage_per_manager', 'total_manager_wages', 'total_wages', 'fica_taxes', 'grand_total']

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--- TESTING VARIABLE: 'hours_per_worker_per_month' (Correct Value: 200) ---
  -> Found 2 applicable error types: ['generate_off_by_factor_of_10_error', 'generate_digit_transposition_error']

  --- Applying Error: generate_off_by_factor_of_10_error ---

    --- Flawed Solution (using generate_off_by_factor_of_10_error) ---
First figure out how many hours each worker works per month by multiplying the number of days they work by the number of hours a day they work: 25 days * 8 hours/day = <<25*8=2000>>2000 hours
Then calculate how much one warehouse worker makes per month by multiplying their hourly rate by the number of hours they work: 2000 hours * $15/hour = $<<2

In [53]:
# import json
# import random
# import inspect

# # It's good practice to use a separate directory for test outputs
# # to avoid polluting your main processed data directories.
# TEST_PROCESSED_DIR = DATA_DIR / 'integration_test_processed'

# def run_integration_test(tier: str, index: int):
#     """
#     Runs a full integration test with verbose failure reporting for each component.
#     """
#     print("="*50)
#     print(f"STARTING INTEGRATION TEST FOR: Tier '{tier}', Index '{index}'")
#     print("="*50)

#     # --- 1. Raw Manifest Loading ---
#     print("\n--- [1/9] Testing Raw Manifest Loading ---")
#     manifest_str = load_manifest_as_string(VALIDATED_MANIFEST_DIR, index)
#     if not manifest_str:
#         print(f" -> FAILED: Could not load raw manifest from path: {VALIDATED_MANIFEST_DIR / f'_{index}.json'}")
#         return
#     print(" -> SUCCESS: Raw manifest loaded as string.")

#     # --- 2. Saving Components ---
#     print("\n--- [2/9] Testing Save Manifest Components ---")
#     save_manifest_components(manifest_str, tier, index, base_output_dir=TEST_PROCESSED_DIR)
#     expected_py_path = TEST_PROCESSED_DIR / tier / str(index) / "solve.py"
#     if not expected_py_path.exists():
#         print(f" -> FAILED: Component file not created at {expected_py_path}")
#         return
#     print(f" -> SUCCESS: Components saved to '{TEST_PROCESSED_DIR / tier / str(index)}'.")

#     # --- 3. Processed Data Loading ---
#     print("\n--- [3/9] Testing Processed Data Loading ---")
#     solve_module = load_function_module(tier, index, base_dir=TEST_PROCESSED_DIR)
#     logical_steps = load_logical_steps(tier, index, base_dir=TEST_PROCESSED_DIR)
#     if not (solve_module and logical_steps):
#         print(" -> FAILED: Could not load processed components from test directory.")
#         return
#     solve_function = solve_module.solve
#     print(" -> SUCCESS: Module and logical steps loaded.")

#     # --- 4. Correct Trace Generation ---
#     print("\n--- [4/9] Testing Correct Trace Generation ---")
#     correct_trace = execution_trace(solve_function)
#     if not correct_trace:
#         print(" -> FAILED: execution_trace() returned None.")
#         print("\n--- FUNCTION SOURCE ---")
#         print(inspect.getsource(solve_function))
#         return
#     print(" -> SUCCESS: Correct trace generated.")

#     # --- 5. Solution Reconstruction ---
#     print("\n--- [5/9] Testing Solution Reconstruction ---")
#     # Note: This test uses your existing `reconstruct_solution_lines_enhanced`
#     reconstructed_solution = reconstruct_solution_lines_enhanced(logical_steps, correct_trace)
#     original_solution = build_solution_mapping(index=index)
#     if reconstructed_solution == original_solution:
#         print(" -> SUCCESS: Solution correctly reconstructed.")
#     else:
#         print(" -> FAILED: Reconstructed solution does not match original. (This may be expected due to formatting differences)")
#         print("\n--- ORIGINAL SOLUTION ---")
#         print(json.dumps(original_solution, indent=2))
#         print("\n--- RECONSTRUCTED SOLUTION ---")
#         print(json.dumps(reconstructed_solution, indent=2))
#         print("\n--- MISMATCH DETAILS ---")
#         all_keys = set(original_solution.keys()) | set(reconstructed_solution.keys())
#         for key in sorted(list(all_keys), key=lambda x: int(x[1:])):
#             if original_solution.get(key) != reconstructed_solution.get(key):
#                 print(f"  Mismatch on line {key}:")
#                 print(f"    Original:      '{original_solution.get(key)}'")
#                 print(f"    Reconstructed: '{reconstructed_solution.get(key)}'")

#     # --- 6. Error Generation ---
#     print("\n--- [6/9] Testing Error Generation Controller ---")
#     error_details = generate_computational_error(solve_function, correct_trace, get_target_variables(logical_steps))
#     if error_details:
#         print(f" -> SUCCESS: Generated error for variable '{error_details['variable']}'.")
#     else:
#         print(" -> FAILED: Could not generate error details. Rerunning in debug mode:")
#         generate_computational_error(solve_function, correct_trace, get_target_variables(logical_steps), debug=True)
#         return

#     # --- 7. Flawed Trace & Validator ---
#     print("\n--- [7/9] Testing Flawed Trace & Validator ---")
#     flawed_trace = generate_flawed_trace(solve_function, error_details)
#     if not flawed_trace:
#         print(" -> FAILED: generate_flawed_trace() returned None.")
#         return
#     if not is_trace_valid(flawed_trace, correct_trace):
#         print(" -> FAILED: Validator incorrectly flagged the generated trace as invalid.")
#         print("\n--- CORRECT TRACE ---")
#         print(json.dumps(correct_trace, indent=2))
#         print("\n--- FLAWED TRACE ---")
#         print(json.dumps(flawed_trace, indent=2))
#         return
#     else:
#         print(" -> SUCCESS: Flawed trace generated and validated.")

#     # --- 8. Final Artifact Assembly ---
#     print("\n--- [8/9] Testing Final Artifact Assembly ---")
#     artifacts = generate_training_artifacts_enhanced(logical_steps, error_details, flawed_trace)
#     if artifacts and len(artifacts) == 2:
#         print(" -> SUCCESS: Final artifacts assembled.")
#     else:
#         print(" -> FAILED: Could not assemble final artifacts.")
#         return

#     # --- 9. Full Orchestrator ---
#     print("\n--- [9/9] Testing Full Orchestrator ---")
#     random.seed() # Reset seed for an independent test
#     final_solution, final_label = create_single_error_example(solve_function, logical_steps, correct_trace)
    
#     if not (final_solution and final_label):
#         print(" -> FAILED: The main orchestrator could not produce a valid example.")
#     else:
#         print(" -> SUCCESS: The main orchestrator produced a valid, complete example.")
        
#         # --- NEW SUCCESS REPORT BLOCK ---
#         print("\n" + "="*20 + " SUCCESS REPORT " + "="*20)
        
#         # 1. Get Question
#         question_text = GSM8K_TRAIN[index].get("question", "Question not found.")
#         print(f"\n--- INDEX: {index} ---")
#         print(f"--- QUESTION ---\n{question_text}")
        
#         # 2. Get Correct Solution
#         original_solution = build_solution_mapping(index=index)
#         print("\n--- CORRECT SOLUTION ---")
#         for line_num in sorted(original_solution.keys(), key=lambda x: int(x[1:])):
#             print(f"{line_num}: {original_solution[line_num]}")
            
#         # 3. Format and Display Flawed Solution
#         flawed_solution_map = {f"L{i+1}": line for i, line in enumerate(final_solution.splitlines())}
#         print("\n--- FLAWED SOLUTION (GENERATED) ---")
#         for line_num in sorted(flawed_solution_map.keys(), key=lambda x: int(x[1:])):
#             print(f"{line_num}: {flawed_solution_map[line_num]}")
            
#         # 4. Display Final Label
#         print("\n--- TARGET JSON LABEL (GENERATED) ---")
#         print(json.dumps(final_label, indent=2))
        
#         print("\n" + "="*56)
#         # --- END OF NEW BLOCK ---
        
#     print("\n" + "="*50)
#     print("INTEGRATION TEST COMPLETE")
#     print("="*50)


# # --- Run the test ---
# try:
#     # Use a seed that is known to work to test the success path
#     random.seed(2025) 
#     run_integration_test(tier="tier2", index=2918)
# except Exception as e:
#     import traceback
#     print(f"\nAN UNCAUGHT ERROR OCCURRED: {e}")
#     traceback.print_exc()

In [54]:
# import json
# import random

# # It's good practice to use a separate directory for test outputs
# # to avoid polluting your main processed data directories.
# TEST_PROCESSED_DIR = DATA_DIR / 'integration_test_processed'

# def run_integration_test(tier: str, index: int):
#     """
#     Runs a full integration test with verbose failure reporting.
#     """
#     print("="*50)
#     print(f"STARTING INTEGRATION TEST FOR: Tier '{tier}', Index '{index}'")
#     print("="*50)

#     # --- 1. Raw Manifest Loading ---
#     print("\n--- [1/9] Testing Raw Manifest Loading ---")
#     manifest_str = load_manifest_as_string(VALIDATED_MANIFEST_DIR, index)
#     if not manifest_str:
#         print(f" -> FAILED: Could not load raw manifest from path: {VALIDATED_MANIFEST_DIR / f'_{index}.json'}")
#         return
#     print(" -> SUCCESS: Raw manifest loaded as string.")

#     # --- 2. Saving Components ---
#     print("\n--- [2/9] Testing Save Manifest Components ---")
#     save_manifest_components(manifest_str, tier, index, base_output_dir=TEST_PROCESSED_DIR)
#     expected_py_path = TEST_PROCESSED_DIR / tier / str(index) / "solve.py"
#     if not expected_py_path.exists():
#         print(f" -> FAILED: Component file not created at {expected_py_path}")
#         return
#     print(f" -> SUCCESS: Components saved to '{TEST_PROCESSED_DIR / tier / str(index)}'.")

#     # --- 3. Processed Data Loading ---
#     print("\n--- [3/9] Testing Processed Data Loading ---")
#     solve_module = load_function_module(tier, index, base_dir=TEST_PROCESSED_DIR)
#     logical_steps = load_logical_steps(tier, index, base_dir=TEST_PROCESSED_DIR)
#     if not (solve_module and logical_steps):
#         print(" -> FAILED: Could not load processed components from test directory.")
#         return
#     solve_function = solve_module.solve
#     print(" -> SUCCESS: Module and logical steps loaded.")

#     # --- 4. Correct Trace Generation ---
#     print("\n--- [4/9] Testing Correct Trace Generation ---")
#     correct_trace = execution_trace(solve_function)
#     if not correct_trace:
#         print(" -> FAILED: execution_trace() returned None.")
#         print("   Function Source Code:")
#         print(inspect.getsource(solve_function))
#         return
#     print(" -> SUCCESS: Correct trace generated.")

#     # --- 5. Solution Reconstruction ---
#     print("\n--- [5/9] Testing Solution Reconstruction ---")
#     reconstructed_solution = reconstruct_solution_lines_enhanced(logical_steps, correct_trace)
#     original_solution = build_solution_mapping(index=index)
#     if reconstructed_solution == original_solution:
#         print(" -> SUCCESS: Solution correctly reconstructed.")
#     else:
#         print(" -> FAILED: Reconstructed solution does not match original.")
#         print("\n--- ORIGINAL SOLUTION ---")
#         print(json.dumps(original_solution, indent=2))
#         print("\n--- RECONSTRUCTED SOLUTION ---")
#         print(json.dumps(reconstructed_solution, indent=2))
#         print("\n--- MISMATCH DETAILS ---")
#         all_keys = set(original_solution.keys()) | set(reconstructed_solution.keys())
#         for key in sorted(list(all_keys)):
#             if original_solution.get(key) != reconstructed_solution.get(key):
#                 print(f"  Mismatch on line {key}:")
#                 print(f"    Original:      '{original_solution.get(key)}'")
#                 print(f"    Reconstructed: '{reconstructed_solution.get(key)}'")

#     # --- 6. Error Generation ---
#     print("\n--- [6/9] Testing Error Generation Controller ---")
#     error_details = generate_computational_error(solve_function, correct_trace, get_target_variables(logical_steps))
#     if error_details:
#         print(f" -> SUCCESS: Generated error for variable '{error_details['variable']}'.")
#     else:
#         print(" -> FAILED: Could not generate error details. Rerunning in debug mode:")
#         generate_computational_error(solve_function, correct_trace, get_target_variables(logical_steps), debug=True)
#         return

#     # --- 7. Flawed Trace & Validator ---
#     print("\n--- [7/9] Testing Flawed Trace & Validator ---")
#     flawed_trace = generate_flawed_trace(solve_function, error_details)
#     if not flawed_trace:
#         print(" -> FAILED: generate_flawed_trace() returned None.")
#         return
#     if not is_trace_valid(flawed_trace, correct_trace):
#         print(" -> FAILED: Validator incorrectly flagged the generated trace as invalid.")
#         print("   Error details:", json.dumps(error_details, indent=2))
#     else:
#         print(" -> SUCCESS: Flawed trace generated and validated.")

#     # --- 8. Final Artifact Assembly ---
#     print("\n--- [8/9] Testing Final Artifact Assembly ---")
#     artifacts = generate_training_artifacts_enhanced(logical_steps, error_details, flawed_trace)
#     if artifacts and len(artifacts) == 2:
#         print(" -> SUCCESS: Final artifacts assembled.")
#     else:
#         print(" -> FAILED: Could not assemble final artifacts.")
#         return

#     # --- 9. Full Orchestrator ---
#     print("\n--- [9/9] Testing Full Orchestrator ---")
#     random.seed()
#     final_solution, final_label = create_single_error_example(solve_function, logical_steps, correct_trace)
#     if not (final_solution and final_label):
#         print(" -> FAILED: The main orchestrator could not produce a valid example.")
#     else:
#         print(" -> SUCCESS: The main orchestrator produced a valid, complete example.")
        
#     print("\n" + "="*50)
#     print("INTEGRATION TEST COMPLETE (up to last successful step)")
#     print("="*50)


# # --- Run the test on the previously failing example ---
# try:
#     random.seed(2024)
#     run_integration_test(tier="tier2", index=2918)
# except Exception as e:
#     import traceback
#     print(f"\nAN UNCAUGHT ERROR OCCURRED: {e}")
#     traceback.print_exc()

In [55]:
# solve_function = solve_module.solve
# targets = get_target_variables(logical_steps)

# new_error_details = generate_computational_error(solve_function, correct_trace, targets)
# if new_error_details:
#     print(json.dumps(new_error_details, indent=2))
    
# flawed_trace = generate_flawed_trace(solve_function, new_error_details) # type: ignore
# flawed_solution, target_json_label = generate_training_artifacts_enhanced(
#     logical_steps, new_error_details, flawed_trace # type: ignore
#     ) # type: ignore

# print("\n--- ORIGINAL SOLUTION ---")
# original_solution = build_solution_mapping(index=example_index, dataset=GSM8K_TRAIN)
# print(json.dumps(original_solution, indent=2))

# print("\n--- FLAWED SOLUTION ---")
# print(flawed_solution)
# print("\n--- TARGET JSON LABEL ---")
# print(json.dumps(target_json_label, indent=2))


In [56]:
def has_computational_division(solution_text: str) -> bool:
    """Returns True if a '/' is followed by optional whitespace and then a digit."""
    pattern = re.compile(r'/\s*\d')
    return bool(pattern.search(solution_text))

def has_float(solution_text: str) -> bool:
    """Returns True if the solution text contains a floating-point number."""
    pattern = re.compile(r'(?<!\d)\.\d+|\d+\.\d+')
    return bool(pattern.search(solution_text))

def is_symbolic(solution_text: str) -> bool:
    """Returns True if the solution contains a symbolic reasoning line (Let @ ...)."""
    pattern = re.compile(r'^Let [a-zA-Z] ', re.MULTILINE)
    return bool(pattern.search(solution_text))

def mutually_disjoint_tiers(dataset):
    tiers = {}
    symbolic_set = set(
        idx for idx, sample in enumerate(dataset)
        if is_symbolic(sample.get("answer", ""))
    )
    non_symbolic_indices = [
        idx for idx in range(len(dataset)) if idx not in symbolic_set
    ]

    # Tier 1: Only integer arithmetic (no floats, no computational division)
    tiers["tier1"] = sorted([
        idx for idx in non_symbolic_indices
        if not has_float(dataset[idx].get("answer", "")) and not has_computational_division(dataset[idx].get("answer", ""))
    ])

    # Tier 2: Float arithmetic, no computational division
    tiers["tier2"] = sorted([
        idx for idx in non_symbolic_indices
        if has_float(dataset[idx].get("answer", "")) and not has_computational_division(dataset[idx].get("answer", ""))
    ])

    # Tier 3: Computational division, no floats
    tiers["tier3"] = sorted([
        idx for idx in non_symbolic_indices
        if not has_float(dataset[idx].get("answer", "")) and has_computational_division(dataset[idx].get("answer", ""))
    ])

    # Tier 4: Both floats and computational division
    tiers["tier4"] = sorted([
        idx for idx in non_symbolic_indices
        if has_float(dataset[idx].get("answer", "")) and has_computational_division(dataset[idx].get("answer", ""))
    ])

    # Tier 5: Symbolic reasoning (Let @ ...)
    tiers["tier5"] = sorted(symbolic_set)

    return tiers

TIER_LISTS = mutually_disjoint_tiers(GSM8K_TRAIN)

# Display the number of samples in each tier
for tier, indices in TIER_LISTS.items():
    print(f"{tier:<10}: {len(indices)} samples")
print(f"{'Total':<10}: {len(GSM8K_TRAIN)} samples")

def indices_tier2_with_percent(dataset, tier2_indices):
    """
    Returns indices in tier2_indices where the question or answer contains a '%' sign.
    """
    result = []
    for idx in tier2_indices:
        sample = dataset[idx]
        if '%' in sample.get("question", "") or '%' in sample.get("answer", ""):
            result.append(idx)
    return result

# Example usage:
# Assuming you have GSM8K_TRAIN loaded and tier2_indices already computed:
tier2_indices = mutually_disjoint_tiers(GSM8K_TRAIN)["tier2"]
percent_indices = indices_tier2_with_percent(GSM8K_TRAIN, tier2_indices)

len(percent_indices)

def indices_with_float_dot_zero(dataset, indices):
    """
    Returns indices in the provided list where the answer contains a float ending in '.0', '.00', '.000', etc.
    Matches both '12.0', '12.00', '12.000', ... and '.00', '.000', etc.
    """
    pattern = re.compile(r'(?<!\d)(\d+)?\.0{1,}\b')
    result = []
    for idx in indices:
        answer = dataset[idx].get("answer", "")
        if pattern.search(answer):
            result.append(idx)
    return result

# Example usage (rest unchanged):
tier2_float_dot_zero = indices_with_float_dot_zero(GSM8K_TRAIN, tier2_indices)
tier4_float_dot_zero = indices_with_float_dot_zero(GSM8K_TRAIN, tier4_indices)

all_float_dot_zero = sorted(set(tier2_float_dot_zero + tier4_float_dot_zero))

print(f"Tier 2 with float ending in .0...: {len(tier2_float_dot_zero)}")
print(f"Tier 4 with float ending in .0...: {len(tier4_float_dot_zero)}")
print(f"Total (tiers 2+4) with float ending in .0...: {len(all_float_dot_zero)}")

for idx in all_float_dot_zero:
    print(f"Index: {idx},\nAnswer: {GSM8K_TRAIN[idx]['answer']}")
    print("-" * 40)

tier1     : 2767 samples
tier2     : 837 samples
tier3     : 3113 samples
tier4     : 544 samples
tier5     : 212 samples
Total     : 7473 samples


NameError: name 'tier4_indices' is not defined

In [None]:
import re

def find_non_money_dot_double_zero_lines(dataset, indices):
    """
    For each answer, loops through lines, removes calculator annotations,
    and finds floats ending in '.00' where the first preceding non-whitespace character is not '$'
    and the number is not part of a comma-formatted money value (e.g., $2,000.00).
    Returns a list of (idx, line_num, match) for each such occurrence.
    """
    # Match floats ending in .00, not preceded by a comma and not part of $X,XXX.00
    float_pattern = re.compile(r'(?<![\d,])(\d+)\.00\b')
    calc_ann_pattern = re.compile(r'<<.*?>>')
    results = []
    for idx in indices:
        answer = dataset[idx].get("answer", "")
        for line_num, line in enumerate(answer.splitlines(), 1):
            # Remove calculator annotations
            clean_line = calc_ann_pattern.sub('', line)
            for match in float_pattern.finditer(clean_line):
                start = match.start()
                # Look backwards for first non-whitespace character before the match
                prefix = clean_line[:start].rstrip()
                if not prefix or prefix[-1] != '$':
                    results.append((idx, line_num, match.group()))
    return results

# Example usage:
non_money_dot_double_zero_lines = find_non_money_dot_double_zero_lines(GSM8K_TRAIN, all_float_dot_zero)
for idx, line_num, val in non_money_dot_double_zero_lines:
    print(f"Index: {idx}, Line: {line_num}, Float: {val}, LineText: {GSM8K_TRAIN[idx]['answer'].splitlines()[line_num-1]}")

Index: 178, Line: 3, Float: 5.00, LineText: All total he spends 12.50+12.50+5.00 = $<<12.50+12.50+5.00=30.00>>30.00
Index: 2289, Line: 1, Float: 13.00, LineText: Book 1 was $12.50 and book 2 was $15.00 so 13.00 +15.00 = $28.00
Index: 2289, Line: 1, Float: 15.00, LineText: Book 1 was $12.50 and book 2 was $15.00 so 13.00 +15.00 = $28.00
Index: 2289, Line: 2, Float: 28.00, LineText: He was buying 4 books, so these 2 were 25% off so 28.00 * .25 = $7.00
Index: 2317, Line: 3, Float: 35.00, LineText: When you add all the meals together they spent 25.50+35.00+27.00+8.50+24.00 = $<<25.50+35.00+27.00+8.50+24.00=120.00>>120.00
Index: 2317, Line: 3, Float: 27.00, LineText: When you add all the meals together they spent 25.50+35.00+27.00+8.50+24.00 = $<<25.50+35.00+27.00+8.50+24.00=120.00>>120.00
Index: 2317, Line: 3, Float: 24.00, LineText: When you add all the meals together they spent 25.50+35.00+27.00+8.50+24.00 = $<<25.50+35.00+27.00+8.50+24.00=120.00>>120.00
Index: 2317, Line: 4, Float: 120.

In [None]:
print(build_solution_mapping(index=2567, dataset=GSM8K_TRAIN))

{'L1': 'From the carwash, Hank donates $100 * 0.90 = $<<100*0.90=90>>90', 'L2': 'From the bake sale, Hank donates $80 * 0.75 = $<<80*0.75=60>>60', 'L3': 'From mowing lawns, Hank donates $50 * 1.00 = $<<50*1=50>>50', 'L4': 'Hank donates a total of $90 + $60 + $50 = $<<90+60+50=200>>200'}


In [None]:
import re

def indices_with_float_between_0_and_1(dataset, indices=None):
    """
    Returns indices of samples where the answer contains a float strictly between 0 and 1 (e.g., 0.25, .5, 0.999).
    If indices is None, checks all samples in the dataset.
    """
    # Matches floats like 0.25, .5, 0.999, but not 0.0 or 1.0
    float_pattern = re.compile(r'(?<!\d)(0?\.\d+)\b')
    result = []
    search_indices = indices if indices is not None else range(len(dataset))
    for idx in search_indices:
        answer = dataset[idx].get("answer", "")
        for match in float_pattern.finditer(answer):
            try:
                val = float(match.group(1))
                if 0 < val < 1:
                    result.append(idx)
                    break  # Only need one match per sample
            except ValueError:
                continue
    return result

# Example usage:
floats_between_0_and_1 = indices_with_float_between_0_and_1(GSM8K_TRAIN)
print(f"Number of samples with float in (0,1): {len(floats_between_0_and_1)}")
for idx in floats_between_0_and_1:
    print(f"Index: {idx}")
    print()
    print(f"Question: {GSM8K_TRAIN[idx]['question']}")
    print()
    print(f"Answer: {GSM8K_TRAIN[idx]['answer']}")
    print("-" * 40)

Number of samples with float in (0,1): 924
Index: 1

Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?

Answer: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
----------------------------------------
Index: 9

Question: Tina makes $18.00 an hour.  If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage.  If she works 10 hours every day for 5 days, how much money does she make?

Answer: She works 8 hours a day for $18 per hour so she makes 8*18 = $<<8*18=144.00>>144.00 per 8-hour shift
She works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = <<10-8=2>>2 hours of overtime
Overtime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $<<18*.5=9.00>>9.00
Her overtime pay is 18+9 = $<<18+9=27.00>>27.00


In [None]:
import re

def indices_with_float_at_least_3_decimal(dataset, indices=None):
    """
    Returns indices of samples where the answer contains a float with at least 3 decimal places (e.g., 0.123, 12.3456).
    If indices is None, checks all samples in the dataset.
    """
    # Matches floats like 0.123, 12.345, .789, etc. (at least 3 digits after the decimal)
    float_pattern = re.compile(r'(?<!\d)(\d*\.\d{3,})\b')
    result = []
    search_indices = indices if indices is not None else range(len(dataset))
    for idx in search_indices:
        answer = dataset[idx].get("answer", "")
        if float_pattern.search(answer):
            result.append(idx)
    return result

# Example usage:
float_3dp_indices = indices_with_float_at_least_3_decimal(GSM8K_TRAIN)
print(f"Number of samples with float of at least 3 decimal places: {len(float_3dp_indices)}")
for idx in float_3dp_indices:
    print(f"Index: {idx}")
    print()
    print(f"Question: {GSM8K_TRAIN[idx]['question']}")
    print()
    print(f"Answer: {GSM8K_TRAIN[idx]['answer']}")
    print("-" * 40)

Number of samples with float of at least 3 decimal places: 15
Index: 1136

Question: Emily loves to have pets and for that reason, she has 4 dogs in her home. Each one eats 250 grams of food per day. She has to go on vacation for 14 days. How many kilograms of food should she buy for her 4 dogs so they don't starve while she is out?

Answer: Each dog would eat 250 grams so 4 will would eat 4 x 250 grams = <<4*250=1000>>1000 grams of food per day.
1.000 grams is equal a 1 kilogram.
Emily is going on vacation for 14 days and with the 4 dogs together eating 1 kilogram of food per day, 14 days x 1 kg of food/day = <<14*1=14>>14 kg of food would be enough for two weeks.
#### 14
----------------------------------------
Index: 1706

Question: There are 400 students in the senior class at East High School. 52% of the students play sports. Of the students that play sports, 12.5% play soccer. How many students play soccer?

Answer: The number of students that play sports is 400 * 0.52 = <<400*0.

In [None]:
import re

def indices_with_dot_zero_one(dataset, indices=None):
    """
    Returns indices of samples where the answer contains '.01' or '0.01'.
    If indices is None, checks all samples in the dataset.
    """
    pattern = re.compile(r'(?<!\d)0?\.01\b')
    result = []
    search_indices = indices if indices is not None else range(len(dataset))
    for idx in search_indices:
        answer = dataset[idx].get("answer", "")
        if pattern.search(answer):
            result.append(idx)
    return result

# Example usage:
dot_zero_one_indices = indices_with_dot_zero_one(GSM8K_TRAIN)
print(f"Number of samples with '.01' or '0.01': {len(dot_zero_one_indices)}")
for idx in dot_zero_one_indices:
    print(f"Index: {idx}")
    print()
    print(f"Question: {GSM8K_TRAIN[idx]['question']}")
    print()
    print(f"Answer: {GSM8K_TRAIN[idx]['answer']}")
    print("-" * 40)

Number of samples with '.01' or '0.01': 143
Index: 26

Question: Jack is stranded on a desert island. He wants some salt to season his fish. He collects 2 liters of seawater in an old bucket. If the water is 20% salt, how many ml of salt will Jack get when all the water evaporates?

Answer: First find how many liters of the seawater are salt: 2 liters * 20% = <<2*20*.01=.4>>.4 liters
Then multiply that amount by 1000 ml/liter to find the number of ml of salt Jack gets: .4 liters * 1000 ml/liter = <<.4*1000=400>>400 ml
#### 400
----------------------------------------
Index: 30

Question: Ann, Bill, Cate, and Dale each buy personal pan pizzas cut into 4 pieces. If Bill and Dale eat 50% of their pizzas and Ann and Cate eat 75% of the pizzas, how many pizza pieces are left uneaten?

Answer: In total, there are 4 x 4 = <<4*4=16>>16 pizza pieces.
Bill and Dale eat 2 x 4 x 50% = <<2*4*50*.01=4>>4 pieces.
Ann and Cate eat 2 x 4 x 75% = <<2*4*75*.01=6>>6 pieces.
The four of them eat 4 + 6 = <<

In [57]:
import re

def indices_with_fraction_multiplication(dataset, indices=None):
    """
    Returns indices of samples where the answer contains a multiplication of the form:
    num * fraction or fraction * num, allowing for optional parentheses and spaces.
    Example matches: 5 * 2/3, (5) * (2/3), 2/3 * 5, (2/3) * (5), etc.
    """
    # Pattern for a number (integer or decimal) with optional parentheses
    num = r'\(?\s*\d+(?:\.\d+)?\s*\)?'
    # Pattern for a fraction with optional parentheses
    frac = r'\(?\s*\d+\s*/\s*\d+\s*\)?'
    # Allow for spaces and * in between
    pattern = re.compile(
        rf'{num}\s*\*\s*{frac}|{frac}\s*\*\s*{num}'
    )
    result = []
    search_indices = indices if indices is not None else range(len(dataset))
    for idx in search_indices:
        answer = dataset[idx].get("answer", "")
        if pattern.search(answer):
            result.append(idx)
    return result

# Example usage:
fraction_mult_indices = indices_with_fraction_multiplication(GSM8K_TRAIN)
print(f"Number of samples with num * fraction or fraction * num: {len(fraction_mult_indices)}")
for idx in fraction_mult_indices:
    print(f"Index: {idx}")
    print()
    print(f"Question: {GSM8K_TRAIN[idx]['question']}")
    print()
    print(f"Answer: {GSM8K_TRAIN[idx]['answer']}")
    print("-"* 40)

Number of samples with num * fraction or fraction * num: 892
Index: 5

Question: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?

Answer: There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers.
So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers.
Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers.
That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers.
So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden.
#### 35
----------------------------------------
Index: 27

Question: Brennan was researching his school project and had to download files from the internet to his computer to use for reference. After downloading 800 files, he deleted 70% of them because they 

In [59]:
import re

def indices_with_fraction_addition(dataset, indices=None):
    """
    Returns indices of samples where the answer contains a sum of two fractions,
    allowing for optional parentheses and spaces.
    Example matches: 1/2 + 3/4, (1/2) + (3/4), ( 1/2 ) + 3/4, etc.
    """
    # Pattern for a fraction with optional parentheses and spaces
    frac = r'\(?\s*\d+\s*/\s*\d+\s*\)?'
    # Allow for spaces and + in between
    pattern = re.compile(
        rf'{frac}\s*\+\s*{frac}'
    )
    result = []
    search_indices = indices if indices is not None else range(len(dataset))
    for idx in search_indices:
        answer = dataset[idx].get("answer", "")
        if pattern.search(answer):
            result.append(idx)
    return result

# Example usage:
fraction_add_indices = indices_with_fraction_addition(GSM8K_TRAIN)
print(f"Number of samples with fraction + fraction: {len(fraction_add_indices)}")
for idx in fraction_add_indices:
    print(f"Index: {idx}")
    print()
    print(f"Question: {GSM8K_TRAIN[idx]['question']}")
    print()
    print(f"Answer: {GSM8K_TRAIN[idx]['answer']}")
    print("-" * 40)

Number of samples with fraction + fraction: 11
Index: 90

Question: Herman likes to feed the birds in December, January and February.  He feeds them 1/2 cup in the morning and 1/2 cup in the afternoon.  How many cups of food will he need for all three months?

Answer: December has 31 days, January has 31 days and February has 28 days for a total of 31+31+28 = <<31+31+28=90>>90 days
He feeds them 1/2 cup in the morning and 1/2 cup in the afternoon for a total of 1/2+1/2 = <<1/2+1/2=1>>1 cup per day
If he feeds them 1 cup per day for 90 days then he will need 1*90 = <<1*90=90>>90 cups of birdseed
#### 90
----------------------------------------
Index: 299

Question: Basil gets 1/2 of a  dog cookie in the morning and before bed.  She gets 2 whole cookies during the day.  Basil’s cookies are packaged with 45 cookies per box.  How many boxes will she need to last her for 30 days?

Answer: Basil gets 1/2 in the morning and 1/2 at bedtime so she gets 1/2+1/2 = <<1/2+1/2=1>>1 whole cookie
Basi