In [179]:
import ast
import inspect
import json
import textwrap
from pathlib import Path

# 1. Define Project Root and Data Directories
def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by `marker`."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# Resolve the repository root once; every other path in the notebook is derived from it.
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'

# 2. Define path to the validated Tier 1 manifests
VALIDATED_MANIFEST_DIR = DATA_DIR / 'july-7-sample-manifests' / 'tier1'

# Echo the resolved locations so a wrong working directory is obvious in the output.
print(f"Project root: {PROJECT_ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Validated Tier 1 manifest directory: {VALIDATED_MANIFEST_DIR}")

# 3. Check if the directory exists to prevent downstream errors.
#    This only warns; execution continues even when the directory is missing.
if not VALIDATED_MANIFEST_DIR.is_dir():
    print(f"\nWARNING: Directory not found: {VALIDATED_MANIFEST_DIR}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Data directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data
Validated Tier 1 manifest directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/july-7-sample-manifests/tier1


In [180]:
# Specify the index to process
example_index = 6126
example_tier = "tier1"

In [181]:
# 1. Specify the manifest file to load (files are named '_<index>.json')
manifest_filename = f"_{example_index}.json"
manifest_path = VALIDATED_MANIFEST_DIR / manifest_filename

# 2. Load the manifest content into a dictionary.
#    Every name this cell exports is bound to None up front, so a failed load
#    leaves well-defined values behind instead of unbound (or stale) variables.
manifest_data = None
function_code_string = None
logical_steps_data = None
try:
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(manifest_path, 'r', encoding='utf-8') as f:
        manifest_data = json.load(f)
    print(f"Successfully loaded manifest from: {manifest_path}")
except FileNotFoundError:
    print(f"ERROR: Manifest file not found at {manifest_path}")
except json.JSONDecodeError:
    print(f"ERROR: Could not decode JSON from {manifest_path}")

# 3. Extract main components if load was successful
if manifest_data:
    function_code_string = manifest_data.get("function_code")
    logical_steps_data = manifest_data.get("logical_steps")
    if not function_code_string:
        print("WARNING: 'function_code' key not found in manifest.")

Successfully loaded manifest from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/july-7-sample-manifests/tier1/_6126.json


In [182]:
# Base directory for all processed outputs
PROCESSED_DATA_DIR = DATA_DIR / 'processed'

def load_manifest_as_string(manifest_dir: Path, index: int) -> str | None:
    """
    Loads a specific manifest file as a string.
    Note: Assumes filename is '{index}.json'. Adjust if format differs.
    """
    file_path = manifest_dir / f"_{index}.json"
    try:
        return file_path.read_text()
    except FileNotFoundError:
        print(f"ERROR: Manifest file not found at {file_path}")
        return None

def save_manifest_components(
    manifest_string: str,
    tier: str,
    index: int,
    base_output_dir: Path
) -> None:
    """
    Parses a manifest string and saves its components to structured directories.

    On success two files are written under `base_output_dir / tier / str(index)`:
    'solve.py' (the manifest's "function_code") and 'logical_steps.json'
    (the manifest's "logical_steps").

    Args:
        manifest_string: Raw JSON text of the manifest.
        tier: Tier name used as the first directory level (e.g. 'tier1').
        index: Problem index used as the second directory level.
        base_output_dir: Root directory for processed outputs.

    Returns:
        None. Problems (invalid JSON, a manifest that is not a JSON object,
        missing or empty required keys) are reported with a printed message
        and nothing is written.
    """
    try:
        manifest_data = json.loads(manifest_string)
    except json.JSONDecodeError:
        print(f"ERROR: Invalid JSON in manifest string for index {index}.")
        return

    # Valid JSON is not necessarily an object: a list or a bare string has no
    # `.get`, so reject it explicitly instead of crashing with AttributeError.
    if not isinstance(manifest_data, dict):
        print(f"ERROR: Manifest for index {index} is not a JSON object.")
        return

    function_code = manifest_data.get("function_code")
    logical_steps = manifest_data.get("logical_steps")

    if not function_code or not logical_steps:
        print(f"WARNING: Manifest for index {index} is missing required keys.")
        return

    # 1. Define output directory for the specific index
    output_dir = base_output_dir / tier / str(index)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 2. Save function_code to a .py file (explicit UTF-8 so the result does
    #    not depend on the platform default encoding)
    py_file_path = output_dir / "solve.py"
    py_file_path.write_text(function_code, encoding="utf-8")

    # 3. Save logical_steps to a .json file
    json_file_path = output_dir / "logical_steps.json"
    with open(json_file_path, 'w', encoding="utf-8") as f:
        json.dump(logical_steps, f, indent=2)

In [183]:
# --- Example of how to use the functions ---

# Load the raw manifest text (None when the file is missing)
manifest_str = load_manifest_as_string(VALIDATED_MANIFEST_DIR, example_index)

# Process and Save
if manifest_str:
    save_manifest_components(
        manifest_string=manifest_str,
        tier=example_tier,
        index=example_index,
        base_output_dir=PROCESSED_DATA_DIR
    )
    # NOTE: save_manifest_components returns None both on success and when it
    # bails out early (invalid JSON / missing keys), so this message is printed
    # in either case; check for an ERROR/WARNING line above it.
    print(f"\nProcessed files for index {example_index} saved under: {PROCESSED_DATA_DIR}")


Processed files for index 6126 saved under: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/processed


In [184]:
import importlib.util
from types import ModuleType

def load_function_module(
    tier: str,
    index: int,
    base_dir: Path = PROCESSED_DATA_DIR
) -> ModuleType | None:
    """
    Imports the processed 'solve.py' of one problem as a standalone module.

    The file is looked up at `base_dir / tier / str(index) / "solve.py"` and
    executed through importlib, so its top-level code runs as part of loading.

    Args:
        tier: The tier of the problem (e.g., 'tier1').
        index: The index of the problem.
        base_dir: The base directory for processed files.

    Returns:
        The loaded module object, or None (after printing an error) when the
        file is missing or no import spec could be created for it.
    """
    source_file = base_dir / tier / str(index) / "solve.py"
    if not source_file.exists():
        print(f"ERROR: Python file not found at {source_file}")
        return None

    # Tier and index are baked into the module name so that modules loaded
    # for different problems never share a name.
    qualified_name = f"manifests.t{tier}.i{index}.solve"
    module_spec = importlib.util.spec_from_file_location(qualified_name, source_file)

    if not (module_spec and module_spec.loader):
        print(f"ERROR: Could not create module spec for {source_file}")
        return None

    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return loaded_module


def load_logical_steps(
    tier: str,
    index: int,
    base_dir: Path = PROCESSED_DATA_DIR
) -> list[dict] | None:
    """
    Reads the processed 'logical_steps.json' of one problem.

    Args:
        tier: The tier of the problem (e.g., 'tier1').
        index: The index of the problem.
        base_dir: The base directory for processed files.

    Returns:
        The parsed JSON content (a list of step dictionaries), or None after
        printing an error when the file is missing or is not valid JSON.
    """
    steps_file = base_dir / tier / str(index) / "logical_steps.json"

    parsed_steps = None
    try:
        with open(steps_file, 'r') as handle:
            parsed_steps = json.load(handle)
    except FileNotFoundError:
        print(f"ERROR: JSON file not found at {steps_file}")
    except json.JSONDecodeError:
        print(f"ERROR: Failed to decode JSON from {steps_file}")
    return parsed_steps


In [185]:
# --- Example of how to use the functions ---

# Load the module (None when solve.py is missing for this tier/index)
solve_module = load_function_module(example_tier, example_index)

if solve_module:
    print(f"Successfully loaded module for {example_tier}, index {example_index}")
    # You can now access the function like this:
    # solve_function = solve_module.solve
    # print(solve_function())

# Load the logical steps; later cells read this `logical_steps` variable directly
logical_steps = load_logical_steps(example_tier, example_index)

if logical_steps:
    print(f"Successfully loaded {len(logical_steps)} logical steps.")
    # print(logical_steps[0])

Successfully loaded module for tier1, index 6126
Successfully loaded 3 logical steps.


In [186]:
logical_steps

[{'line_number': 'L1',
  'question_inputs': ['pencils_per_day', 'days_per_week'],
  'WK_inputs': [],
  'output_variable': 'pencils_made',
  'solution_line_template': 'Making {pencils_per_day} pencils a day, he made {pencils_per_day} pencils/day * {days_per_week} days/week = <<{pencils_per_day}*{days_per_week}={pencils_made}>>{pencils_made} pencils during the week.'},
 {'line_number': 'L2',
  'question_inputs': ['initial_stock'],
  'WK_inputs': [],
  'output_variable': 'total_available',
  'solution_line_template': 'With the pencils from the stock, he got {pencils_made} pencils + {initial_stock} pencils = <<{pencils_made}+{initial_stock}={total_available}>>{total_available} pencils.'},
 {'line_number': 'L3',
  'question_inputs': ['pencils_sold'],
  'WK_inputs': [],
  'output_variable': 'final_stock',
  'solution_line_template': 'Subtracting the ones he sold, he has {total_available} pencils - {pencils_sold} pencils = <<{total_available}-{pencils_sold}={final_stock}>>{final_stock} pencil

In [187]:
from typing import Any, Dict

def execution_trace(func) -> Dict[str, Any] | None:
    """
    Parses a function's source code and executes it line by line
    to capture the final value of each assigned variable.

    This function is designed for simple, self-contained functions
    like the 'solve()' methods in the manifests.

    Args:
        func: The function object to trace.

    Returns:
        A dictionary mapping variable names to their computed values, or None on error.
    """
    try:
        # 1. Get the source code and parse it into an Abstract Syntax Tree (AST)
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError) as e:
        print(f"ERROR: Could not get or parse source for {func.__name__}: {e}")
        return None

    # 2. Assume the first node in the file is the function definition
    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        print("ERROR: The parsed source code does not start with a function definition.")
        return None

    # 3. `env` will store the execution environment (variable states)
    env = {}

    # 4. Execute each statement in the function body sequentially
    for stmt in func_def.body:
        # We only process simple assignments (e.g., x = y + z)
        if isinstance(stmt, ast.Assign):
            # Compile and execute the single assignment statement.
            # The result is stored in the `env` dictionary, which acts
            # as the memory for subsequent computations.
            try:
                # ast.Module is required to create a valid code block
                code_obj = compile(ast.Module([stmt], type_ignores=[]), '<string>', 'exec')
                exec(code_obj, {}, env)
            except Exception as e:
                # This helps debug issues in the manifest's function_code
                print(f"ERROR: Failed to execute line: {ast.unparse(stmt)}. Error: {e}")
                return None # Halt on error to ensure trace integrity
        
        # We ignore other node types like `ast.Return`

    return env

In [188]:
# --- Example of how to use the function ---

# 1. Load the module containing the solve function
solve_module = load_function_module(example_tier, example_index)

# Bound up front so the checks below (and later cells) can test it even when
# loading or tracing fails.
correct_trace = None
if solve_module:
    # 2. Extract the 'solve' function from the module
    if hasattr(solve_module, 'solve'):
        solve_function = solve_module.solve
        
        # 3. Generate the trace
        print(f"Generating execution trace for index {example_index}...")
        correct_trace = execution_trace(solve_function)
    else:
        print(f"ERROR: 'solve' function not found in module for index {example_index}")

# An empty trace is falsy too, so nothing is printed for a function without assignments.
if correct_trace:
    print("\nSuccessfully generated correct execution trace:")
    print(json.dumps(correct_trace, indent=2))

Generating execution trace for index 6126...

Successfully generated correct execution trace:
{
  "pencils_per_day": 100,
  "days_per_week": 5,
  "pencils_made": 500,
  "initial_stock": 80,
  "total_available": 580,
  "pencils_sold": 350,
  "final_stock": 230,
  "answer": 230
}


In [189]:
import re
from datasets import load_dataset

GSM8K_TRAIN = load_dataset("gsm8k", "main")["train"]

def build_solution_mapping(
        index: int,
        dataset: 'datasets.Dataset' = GSM8K_TRAIN,
        exclude_FA: bool = True
    ):
    """
    Extracts the natural language solution for a given problem index,
    cleans it, and structures it into a line-numbered dictionary.

    Args:
        index: Position of the problem in `dataset`.
        dataset: Indexable collection whose items expose the solution text
            under the "answer" key.
        exclude_FA: When True (default) the final-answer line ('#### <number>')
            is dropped; when False it is returned under the key "FA".

    Returns:
        A dictionary mapping 'L1', 'L2', ... to the stripped, non-empty
        solution lines (preceded by "FA" when it is kept).
    """
    solution_text = dataset[index]["answer"]
    lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

    # The trailing '#### <number>' line carries the final answer. The number may
    # contain commas/decimal points and may be negative; without the optional
    # minus sign a negative answer would be kept as an ordinary solution line.
    final_answer = None
    if lines and re.match(r"^####\s*-?[\d\.,]+$", lines[-1]):
        final_answer = lines.pop()

    solution_mapping = {}
    if final_answer is not None and not exclude_FA:
        solution_mapping["FA"] = final_answer

    for i, line in enumerate(lines, 1):
        solution_mapping[f"L{i}"] = line

    return solution_mapping

def reconstruct_solution_lines(
    logical_steps: list[dict],
    eval_trace: dict[str, Any]
) -> dict[str, str]:
    """
    Fills the `{variable}` placeholders of every solution template with values
    taken from an execution trace.

    Comparing the result with the original natural language solution shows
    whether the manifest and the trace represent that solution faithfully.

    Args:
        logical_steps: The list of dictionaries from 'logical_steps.json'.
        eval_trace: A dictionary mapping variable names to their computed values.

    Returns:
        A dictionary mapping line numbers (e.g., 'L1') to the filled-in lines.
        Steps lacking a line number or template are skipped with a warning; a
        placeholder whose variable is absent from the trace (or is None) is
        replaced by an '{ERROR: ...}' marker.
    """
    # Matches one {variable_name} placeholder and captures the name
    placeholder = re.compile(r'\{([a-zA-Z0-9_]+)\}')

    def fill(found):
        # Invoked by `sub` once per placeholder; group 1 is the variable name
        name = found.group(1)
        looked_up = eval_trace.get(name)
        if looked_up is not None:
            return str(looked_up)
        # Template and trace disagree about which variables exist
        return f"{{ERROR: {name} not in trace}}"

    rebuilt = {}
    for entry in logical_steps:
        label = entry.get("line_number")
        template_text = entry.get("solution_line_template")

        if label and template_text:
            rebuilt[label] = placeholder.sub(fill, template_text)
        else:
            print(f"WARNING: Skipping malformed logical step: {entry}")

    return rebuilt

In [190]:
# --- Example of how to use the function ---

# Relies on `logical_steps` and `correct_trace` left in the kernel by earlier cells
if logical_steps and correct_trace:
    print("Reconstructing solution lines from correct_trace...")
    reconstructed_solution = reconstruct_solution_lines(logical_steps, correct_trace)

    print("\n--- RECONSTRUCTED SOLUTION ---")
    print(json.dumps(reconstructed_solution, indent=2))

# For comparison, build the original solution mapping from the dataset.
# This part runs unconditionally, even when the reconstruction above was skipped.

print("\n--- ORIGINAL SOLUTION ---")
original_solution = build_solution_mapping(index=example_index, dataset=GSM8K_TRAIN)
print(json.dumps(original_solution, indent=2))

Reconstructing solution lines from correct_trace...

--- RECONSTRUCTED SOLUTION ---
{
  "L1": "Making 100 pencils a day, he made 100 pencils/day * 5 days/week = <<100*5=500>>500 pencils during the week.",
  "L2": "With the pencils from the stock, he got 500 pencils + 80 pencils = <<500+80=580>>580 pencils.",
  "L3": "Subtracting the ones he sold, he has 580 pencils - 350 pencils = <<580-350=230>>230 pencils."
}

--- ORIGINAL SOLUTION ---
{
  "L1": "Making 100 pencils a day, he made 100 pencils/day * 5 days/week = <<100*5=500>>500 pencils during the week.",
  "L2": "With the pencils from the stock, he got 500 pencils + 80 pencils = <<500+80=580>>580 pencils.",
  "L3": "Subtracting the ones he sold, he has 580 pencils - 350 pencils = <<580-350=230>>230 pencils."
}


In [191]:
import random

def get_target_variables(logical_steps: list[dict]) -> list[str]:
    """
    Collects the 'output_variable' of every logical step that defines one.

    The returned names, in step order, are the candidate locations for
    injecting a computational error.
    """
    targets = []
    for step in logical_steps:
        if 'output_variable' in step:
            targets.append(step['output_variable'])
    return targets


def generate_off_by_n_error(
    correct_trace: dict[str, Any],
    target_variables: list[str],
    offset_range: tuple[int, int] = (1, 5)
) -> dict[str, Any] | None:
    """
    Selects a random target variable and generates a simple computational error.

    Args:
        correct_trace: The dictionary mapping variables to their correct values.
        target_variables: A list of variable names eligible for modification.
        offset_range: The range of integer offsets to add/subtract (inclusive).

    Returns:
        A dictionary containing the error details:
        {'variable': name, 'correct_value': val, 'flawed_value': new_val}
        or None if no valid error could be generated.
    """
    if not target_variables:
        print("ERROR: No target variables provided.")
        return None

    # 1. Select a random variable to modify
    variable_to_change = random.choice(target_variables)
    correct_value = correct_trace.get(variable_to_change)

    if not isinstance(correct_value, (int, float)):
        print(f"WARNING: Selected variable '{variable_to_change}' is not a number. Skipping.")
        return None

    # 2. Generate a non-zero, random offset
    random.seed(42)  # For reproducibility in this example
    offset = random.randint(offset_range[0], offset_range[1])
    if random.random() < 0.5:  # 50% chance to be negative
        offset = -offset
    
    # 3. Calculate the flawed value
    flawed_value = correct_value + offset

    # 4. Basic sanity check: ensure the new value is not the same as the old
    #    (This can happen if offset is 0, which we've prevented, but it's good practice)
    #    Also ensure we don't accidentally create a negative number for a positive one.
    if flawed_value == correct_value or (correct_value > 0 and flawed_value <= 0):
        # A simple retry mechanism for this edge case
        flawed_value = correct_value + offset + 1

    return {
        "variable": variable_to_change,
        "correct_value": correct_value,
        "flawed_value": flawed_value
    }

In [192]:
# --- Example of how to use the functions ---

# NOTE: `targets` and `error_details` are only bound when this condition holds;
# later cells that test `error_details` assume this branch ran.
if logical_steps and correct_trace:
    # 1. Get the list of all possible variables to make flawed
    targets = get_target_variables(logical_steps)
    print(f"Potential error targets: {targets}")
    
    # 2. Generate a single computational error
    error_details = generate_off_by_n_error(correct_trace, targets)
    
    if error_details:
        print("\n--- GENERATED ERROR DETAILS ---")
        print(json.dumps(error_details, indent=2))

Potential error targets: ['pencils_made', 'total_available', 'final_stock']

--- GENERATED ERROR DETAILS ---
{
  "variable": "final_stock",
  "correct_value": 230,
  "flawed_value": 229
}


In [193]:
import copy

def generate_flawed_trace(
    func,
    error_details: dict[str, Any]
) -> dict[str, Any] | None:
    """
    Generates a new execution trace by injecting a computational error.

    It modifies the function's AST in memory to replace a calculation
    with a hardcoded flawed value, then re-executes the logic to
    propagate the error.

    Args:
        func: The original function object to modify.
        error_details: Dict containing 'variable', 'correct_value', 'flawed_value'.

    Returns:
        A new dictionary representing the flawed execution trace, or None on error.
    """
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError) as e:
        print(f"ERROR: Could not get or parse source for {func.__name__}: {e}")
        return None

    func_def = tree.body[0]
    if not isinstance(func_def, ast.FunctionDef):
        print("ERROR: Parsed source does not start with a function definition.")
        return None

    variable_to_change = error_details["variable"]
    flawed_value = error_details["flawed_value"]
    
    modified_body = copy.deepcopy(func_def.body)

    node_found_and_modified = False
    for i, node in enumerate(modified_body):
        if isinstance(node, ast.Assign):
            if node.targets[0].id == variable_to_change:
                # --- FIX IS HERE ---
                # Create a new constant node for the flawed value and copy the
                # location info (lineno, col_offset) from the original node.
                new_value_node = ast.copy_location(
                    ast.Constant(value=flawed_value),
                    node.value  # The original RHS node we are replacing
                )
                node.value = new_value_node
                node_found_and_modified = True
                break
    
    if not node_found_and_modified:
        print(f"ERROR: Could not find assignment node for '{variable_to_change}' to modify.")
        return None

    env = {}
    for stmt in modified_body:
        if isinstance(stmt, ast.Assign):
            try:
                code_obj = compile(ast.Module([stmt], type_ignores=[]), '<string>', 'exec')
                exec(code_obj, {}, env)
            except Exception as e:
                print(f"ERROR: Failed to execute modified line: {ast.unparse(stmt)}. Error: {e}")
                return None
    
    return env

In [194]:
# --- Example of how to use the function (run this again) ---

# NOTE: `flawed_trace` is only bound when this condition holds; the artifact
# generation code further down reads it unconditionally.
if solve_module and error_details:
    solve_function = solve_module.solve
    
    print("Generating flawed trace by propagating the error...")
    flawed_trace = generate_flawed_trace(solve_function, error_details)

    if flawed_trace:
        print("\n--- FLAWED EXECUTION TRACE ---")
        print(json.dumps(flawed_trace, indent=2))

        # The value reported as flawed is read back from the trace itself
        print("\n--- COMPARISON ---")
        print(f"Variable Changed: '{error_details['variable']}'")
        print(f"Correct Value:   {error_details['correct_value']}")
        print(f"Flawed Value:    {flawed_trace.get(error_details['variable'])}")

Generating flawed trace by propagating the error...

--- FLAWED EXECUTION TRACE ---
{
  "pencils_per_day": 100,
  "days_per_week": 5,
  "pencils_made": 500,
  "initial_stock": 80,
  "total_available": 580,
  "pencils_sold": 350,
  "final_stock": 229,
  "answer": 229
}

--- COMPARISON ---
Variable Changed: 'final_stock'
Correct Value:   230
Flawed Value:    229


In [195]:
def find_error_line_number(
    variable_name: str,
    logical_steps: list[dict]
) -> str | None:
    """Finds the line number corresponding to a given output variable."""
    for step in logical_steps:
        if step.get("output_variable") == variable_name:
            return step.get("line_number")
    return None

def generate_training_artifacts(
    logical_steps: list[dict],
    error_details: dict[str, Any],
    flawed_trace: dict[str, Any]
) -> tuple[str, dict] | None:
    """
    Builds one training example: the flawed solution text and its JSON label.

    Args:
        logical_steps: The list of dictionaries from 'logical_steps.json'.
        error_details: Dict from `generate_off_by_n_error`.
        flawed_trace: The flawed execution trace from `generate_flawed_trace`.

    Returns:
        A tuple (flawed natural language solution, target JSON label), or
        None after printing an error when the solution lines cannot be
        reconstructed or the flawed variable maps to no line number.
    """
    # Fill the templates with the flawed values instead of the correct ones
    flawed_lines_by_label = reconstruct_solution_lines(logical_steps, flawed_trace)
    if not flawed_lines_by_label:
        print("ERROR: Failed to reconstruct flawed solution lines.")
        return None

    # Labels look like 'L<number>'; order them numerically before joining
    ordered_labels = sorted(flawed_lines_by_label, key=lambda label: int(label[1:]))
    flawed_nl_solution = "\n".join(flawed_lines_by_label[label] for label in ordered_labels)

    # Locate the line that computes the variable the error was injected into
    flawed_variable = error_details["variable"]
    erroneous_line = find_error_line_number(flawed_variable, logical_steps)
    if not erroneous_line:
        print(f"ERROR: Could not find line number for variable '{flawed_variable}'.")
        return None

    # Explanation text that goes into the label
    expected_value = error_details['correct_value']
    injected_value = error_details['flawed_value']
    explanation = f"The result of this computation should be {expected_value}, not {injected_value}."

    target_json = {
        "verdict": "Flawed",
        "error_details": {
            "error_type": "computational_error",
            "erroneous_line_number": erroneous_line,
            "explanation": explanation,
        }
    }

    return flawed_nl_solution, target_json


# --- Example of how to use the function ---

# Relies on 'logical_steps', 'error_details', and 'flawed_trace' left in the
# kernel by earlier cells; a NameError here means one of those cells did not
# bind its variable.
if logical_steps and error_details and flawed_trace:
    artifacts = generate_training_artifacts(logical_steps, error_details, flawed_trace)
    
    if artifacts:
        # artifacts is the (solution text, label dict) tuple
        flawed_solution, target_json_label = artifacts
        
        print("--- GENERATED FLAWED NL SOLUTION ---")
        print(flawed_solution)
        
        print("\n--- GENERATED TARGET JSON LABEL ---")
        print(json.dumps(target_json_label, indent=2))

--- GENERATED FLAWED NL SOLUTION ---
Making 100 pencils a day, he made 100 pencils/day * 5 days/week = <<100*5=500>>500 pencils during the week.
With the pencils from the stock, he got 500 pencils + 80 pencils = <<500+80=580>>580 pencils.
Subtracting the ones he sold, he has 580 pencils - 350 pencils = <<580-350=229>>229 pencils.

--- GENERATED TARGET JSON LABEL ---
{
  "verdict": "Flawed",
  "error_details": {
    "error_type": "computational_error",
    "erroneous_line_number": "L3",
    "explanation": "The result of this computation should be 230, not 229."
  }
}


In [None]:
import random

# --- Individual Error Generators ---

def generate_off_by_n_error(
    correct_value: int,
    offset_range: tuple[int, int] = (1, 5)
    ):
    """
    Produces a slightly wrong version of `correct_value` (a minor miscalculation).

    A magnitude is drawn from `offset_range` (inclusive) and applied with a
    random sign. Assumes the input is an integer.

    NOTE: an earlier cell defines a function with this same name but a
    trace-based signature; running this cell rebinds the name.

    Returns:
        {'flawed_value': ..., 'explanation_type': ...}
    """
    magnitude = random.randint(offset_range[0], offset_range[1])
    signed_offset = -magnitude if random.random() < 0.5 else magnitude

    candidate = correct_value + signed_offset
    # A zero offset would reproduce the correct value; bump it so the result is a real error
    flawed_value = candidate + 1 if candidate == correct_value else candidate

    return {
        "flawed_value": flawed_value,
        "explanation_type": "This appears to be a minor miscalculation."
    }

def generate_off_by_factor_of_10_error(correct_value: int):
    """
    Simulates a dropped or added trailing zero.

    One of two operations is picked at random: integer division by 10 or
    multiplication by 10. Assumes the input is a multiple of 100.

    Returns:
        {'flawed_value': ..., 'explanation_type': ...}
    """
    operation = random.choice(["divide", "multiply"])

    if operation != "divide":
        flawed_value = correct_value * 10
        explanation = "It appears an extra zero was added to the number."
    else:
        flawed_value = correct_value // 10
        explanation = "It appears a zero was dropped from the number."

    return {"flawed_value": flawed_value, "explanation_type": explanation}

def generate_digit_transposition_error(correct_value: int):
    """
    Swaps two adjacent, distinct digits of `correct_value`.

    The sign of the input is preserved. Swaps that would move a zero into the
    leading position are avoided whenever another swap is available, because
    the leading zero disappears on conversion back to int (105 -> "015" -> 15)
    and the result no longer looks like a transposition. When such a swap is
    the only option (e.g. 10) it is still used.

    Args:
        correct_value: Integer with at least two distinct adjacent digits.

    Returns:
        {'flawed_value': ..., 'explanation_type': ...}

    Raises:
        IndexError: If the value has no pair of distinct adjacent digits
            (single-digit numbers, repeated digits such as 11).
    """
    digits = str(abs(correct_value))

    # Positions i where digits i and i+1 differ, i.e. where a swap changes the number
    swappable = [i for i in range(len(digits) - 1) if digits[i] != digits[i + 1]]
    if not swappable:
        raise IndexError(
            f"Cannot transpose digits of {correct_value}: no two adjacent digits differ."
        )

    # Prefer swaps that keep the number of digits intact; fall back to every
    # candidate when the only possible swap creates a leading zero.
    length_preserving = [i for i in swappable if not (i == 0 and digits[1] == "0")]
    idx_to_swap = random.choice(length_preserving or swappable)

    swapped = (
        digits[:idx_to_swap]
        + digits[idx_to_swap + 1]
        + digits[idx_to_swap]
        + digits[idx_to_swap + 2:]
    )

    flawed_value = int(swapped)
    if correct_value < 0:
        flawed_value = -flawed_value

    return {
        "flawed_value": flawed_value,
        "explanation_type": "It appears two adjacent digits were swapped."
    }

# --- NEW ERROR GENERATOR ---
def generate_stem_off_by_n_error(
    correct_value: int,
    offset_range: tuple[int, int] = (1, 3)
) -> dict[str, Any]:
    """
    Perturbs the 'stem' of a number, i.e. everything before its final digit.

    The stem (`correct_value // 10`) is shifted by a random signed offset drawn
    from `offset_range` (inclusive) and multiplied by 10 again, so the result
    always ends in zero. Assumes the input is a non-zero multiple of 10.

    Returns:
        {'flawed_value': ..., 'explanation_type': ...}
    """
    stem = correct_value // 10

    magnitude = random.randint(offset_range[0], offset_range[1])
    signed_offset = -magnitude if random.random() < 0.5 else magnitude

    shifted_stem = stem + signed_offset
    if shifted_stem == stem:
        # Offset was zero; move by one so the value actually changes
        shifted_stem = stem + 1

    return {
        "flawed_value": shifted_stem * 10,
        "explanation_type": "It appears there was a miscalculation with the digits before the final zero."
    }

In [None]:
# import inspect
# import ast

# def has_distinct_adjacent_digits(n: int) -> bool:
#     """Helper to check if a number is suitable for digit transposition."""
#     s = str(abs(n))
#     return len(s) >= 2 and any(s[i] != s[i+1] for i in range(len(s) - 1))

# def get_operator_for_variable(func, variable_name: str) -> str | None:
#     """
#     Inspects the function's AST to find the operator used to compute a variable.
#     """
#     try:
#         src = inspect.getsource(func)
#         tree = ast.parse(src)
#     except (TypeError, FileNotFoundError, SyntaxError):
#         return None

#     for node in ast.walk(tree):
#         if isinstance(node, ast.Assign) and node.targets[0].id == variable_name:
#             if isinstance(node.value, ast.BinOp):
#                 op = node.value.op
#                 if isinstance(op, ast.Add): return "add"
#                 if isinstance(op, ast.Sub): return "sub"
#                 if isinstance(op, ast.Mult): return "mult"
#                 if isinstance(op, ast.Div): return "div"
#             # Can be extended for ast.UnaryOp etc. if needed
#             return "other"
#     return None

# def generate_computational_error(
#     func, # The 'solve' function object is now required
#     correct_trace: dict[str, Any],
#     target_variables: list[str]
# ) -> dict[str, Any] | None:
#     """
#     Selects a random variable and applies a valid, randomly chosen error type
#     based on the arithmetic operator used.
#     """
#     shuffled_targets = random.sample(target_variables, len(target_variables))

#     for variable_name in shuffled_targets:
#         correct_value = correct_trace.get(variable_name)
#         if not isinstance(correct_value, int):
#             continue

#         # 1. Get the operator context for this variable
#         operator = get_operator_for_variable(func, variable_name)

#         # 2. Determine which error types are applicable based on context
#         applicable_generators = []
        
#         # Rule: Off-by-n is only for addition and subtraction.
#         if operator in ["add", "sub"]:
#             applicable_generators.append(generate_off_by_n_error)
        
#         # Rule: Factor-of-10 error for multiples of 100.
#         if correct_value % 100 == 0 and correct_value != 0:
#             applicable_generators.append(generate_off_by_factor_of_10_error)
        
#         # Rule: Digit transposition is generally applicable for multi-digit numbers.
#         if has_distinct_adjacent_digits(correct_value):
#             applicable_generators.append(generate_digit_transposition_error)

#         # 3. If any are applicable, select one and generate the error
#         if applicable_generators:
#             generator_func = random.choice(applicable_generators)
#             error_result = generator_func(correct_value)
            
#             return {
#                 "variable": variable_name,
#                 "correct_value": correct_value,
#                 "flawed_value": error_result["flawed_value"],
#                 "explanation_type": error_result["explanation_type"]
#             }
            
#     print("WARNING: Could not find a suitable variable/error type combination.")
#     return None

In [198]:
# def has_distinct_adjacent_digits(n: int) -> bool:
#     """Helper to check if a number is suitable for digit transposition."""
#     s = str(abs(n))
#     if len(s) < 2:
#         return False
#     return any(s[i] != s[i+1] for i in range(len(s) - 1))

# def generate_computational_error(
#     correct_trace: dict[str, Any],
#     target_variables: list[str]
#     ):
#     """
#     Selects a random variable and applies a valid, randomly chosen error type.
#     This is the main orchestrator for error generation.
#     """
#     # Create a copy to avoid modifying the original list
#     shuffled_targets = random.sample(target_variables, len(target_variables))

#     for variable_name in shuffled_targets:
#         correct_value = correct_trace.get(variable_name)

#         if not isinstance(correct_value, int):
#             continue

#         # 1. Determine which error types are applicable
#         applicable_generators = []
#         if correct_value != 0:
#             applicable_generators.append(generate_off_by_n_error)

#         # New criterion: must be a multiple of 100
#         if correct_value % 100 == 0 and correct_value != 0:
#             applicable_generators.append(generate_off_by_factor_of_10_error)
        
#         if has_distinct_adjacent_digits(correct_value):
#             applicable_generators.append(generate_digit_transposition_error)

#         # 2. If any are applicable, select one and generate the error
#         if applicable_generators:
#             generator_func = random.choice(applicable_generators)
#             error_result = generator_func(correct_value)
            
#             # 3. Combine results into the final error details object
#             return {
#                 "variable": variable_name,
#                 "correct_value": correct_value,
#                 "flawed_value": error_result["flawed_value"],
#                 "explanation_type": error_result["explanation_type"]
#             }
            
#     # Return None if no suitable error could be generated for any target variable
#     print("WARNING: Could not find a suitable variable/error type combination.")
#     return None

# # --- Example of how to use the new controller ---
# if logical_steps and correct_trace:
#     print("\n--- TESTING REFACTORED ERROR GENERATION CONTROLLER ---")
#     targets = get_target_variables(logical_steps)
    
#     # Generate a single computational error using the new controller
#     new_error_details = generate_computational_error(correct_trace, targets)
    
#     if new_error_details:
#         print("Generated a random error with the new controller:")
#         print(json.dumps(new_error_details, indent=2))

In [None]:
import inspect
import ast

def has_distinct_adjacent_digits(n: int) -> bool:
    """Tell whether ``n`` has at least one pair of neighbouring digits that differ.

    The sign is ignored (the check runs on ``str(abs(n))``). Numbers with
    fewer than two digits, or whose digits are all identical, yield False.
    """
    digits = str(abs(n))
    # Pair every character with its right-hand neighbour; a single-character
    # string produces no pairs, so the result is False.
    return any(left != right for left, right in zip(digits, digits[1:]))

def get_operator_for_variable(func, variable_name: str) -> str | None:
    """Inspects the AST to find the operator used to compute a variable."""
    # (implementation from previous response is unchanged)
    try:
        src = inspect.getsource(func)
        tree = ast.parse(src)
    except (TypeError, FileNotFoundError, SyntaxError): return None
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign) and node.targets[0].id == variable_name:
            if isinstance(node.value, ast.BinOp):
                op = node.value.op
                if isinstance(op, ast.Add): return "add"
                if isinstance(op, ast.Sub): return "sub"
                if isinstance(op, ast.Mult): return "mult"
                if isinstance(op, ast.Div): return "div"
            return "other"
    return None

def get_operand_names_for_variable(func, variable_name: str) -> list[str]:
    """List the names referenced on the right-hand side that computes ``variable_name``.

    The source of ``func`` is parsed and traversed with ``ast.walk``; the first
    plain assignment encountered that binds ``variable_name`` is used.

    Args:
        func: Function whose source is retrieved with ``inspect.getsource``.
        variable_name: Name of the assigned variable to look up.

    Returns:
        The unique identifiers (every ``ast.Name`` node, so called function
        names are included as well) found in the assigned expression, in the
        order the walk first encounters them. An empty list when the source
        cannot be retrieved or parsed, or when no assignment to
        ``variable_name`` is found.
    """
    try:
        tree = ast.parse(inspect.getsource(func))
    except (TypeError, OSError, SyntaxError):
        # inspect.getsource signals unavailable source with a plain OSError
        # (FileNotFoundError is only one subclass of it).
        return []

    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        # Tuple, subscript and attribute targets have no `.id`; compare only
        # plain names, and look at every target of a chained assignment.
        binds_variable = any(
            isinstance(target, ast.Name) and target.id == variable_name
            for target in node.targets
        )
        if not binds_variable:
            continue
        referenced = [
            sub_node.id
            for sub_node in ast.walk(node.value)
            if isinstance(sub_node, ast.Name)
        ]
        # dict.fromkeys removes duplicates while keeping a stable order,
        # unlike set(), whose iteration order varies with string hashing.
        return list(dict.fromkeys(referenced))
    return []

# The AST helpers above (get_operator_for_variable, get_operand_names_for_variable) are used by the controller below.

def generate_computational_error(
    func,
    correct_trace: dict[str, Any],
    target_variables: list[str]
) -> dict[str, Any] | None:
    """
    Selects a random variable and applies a valid, realistic error type.
    """
    shuffled_targets = random.sample(target_variables, len(target_variables))

    for variable_name in shuffled_targets:
        correct_value = correct_trace.get(variable_name)
        if not isinstance(correct_value, int) or correct_value == 0:
            continue

        operator = get_operator_for_variable(func, variable_name)
        applicable_generators = []

        # --- Enhanced Contextual Logic ---
        
        # Rule 1: Handle addition and subtraction with special cases
        if operator in ["add", "sub"]:
            operand_names = get_operand_names_for_variable(func, variable_name)
            operand_values = [correct_trace.get(name) for name in operand_names]
            
            # Check if all operands are integers ending in zero
            all_end_in_zero = all(
                isinstance(v, int) and v % 10 == 0 for v in operand_values
            ) if operand_values else False

            # If they all end in zero, the "stem" error is more realistic.
            if all_end_in_zero and correct_value % 10 == 0:
                applicable_generators.append(generate_stem_off_by_n_error)
            # Otherwise, the standard off-by-n is fine.
            else:
                applicable_generators.append(generate_off_by_n_error)

        # Other rules remain the same
        if correct_value % 100 == 0:
            applicable_generators.append(generate_off_by_factor_of_10_error)
        
        if has_distinct_adjacent_digits(correct_value):
            applicable_generators.append(generate_digit_transposition_error)

        # --- End of Enhanced Logic ---

        if applicable_generators:
            generator_func = random.choice(applicable_generators)
            error_result = generator_func(correct_value)
            
            return {
                "variable": variable_name,
                "correct_value": correct_value,
                "flawed_value": error_result["flawed_value"],
                "explanation_type": error_result["explanation_type"]
            }
            
    print("WARNING: Could not find a suitable variable/error type combination.")
    return None

In [199]:
def generate_training_artifacts_enhanced(
    logical_steps: list[dict],
    error_details: dict[str, Any],
    flawed_trace: dict[str, Any]
    ):
    """
    Build the flawed natural-language solution and its JSON label.

    Args:
        logical_steps: Logical steps of the solution, forwarded to
            ``reconstruct_solution_lines`` and ``find_error_line_number``.
        error_details: Error description; the keys "variable", "correct_value",
            "flawed_value" and "explanation_type" are read.
        flawed_trace: Trace containing the corrupted value, forwarded to
            ``reconstruct_solution_lines``.

    Returns:
        A ``(flawed_nl_solution, target_json)`` tuple, or None when the
        solution lines cannot be reconstructed or the erroneous line cannot
        be located.
    """
    def _line_index(entry):
        # Keys drop their first character and the remainder is read as an int
        # (e.g. "L10" -> 10), so lines sort numerically rather than as text.
        return int(entry[0][1:])

    line_map = reconstruct_solution_lines(logical_steps, flawed_trace)
    if not line_map:
        return None

    ordered_entries = sorted(line_map.items(), key=_line_index)
    flawed_nl_solution = "\n".join(text for _, text in ordered_entries)

    erroneous_line = find_error_line_number(error_details["variable"], logical_steps)
    if not erroneous_line:
        return None

    # Explanation = statement of the correct vs. flawed value, followed by the
    # error-type sentence supplied by the generator.
    final_explanation = (
        f"The result of this computation should be {error_details['correct_value']}, "
        f"not {error_details['flawed_value']}."
        f" {error_details['explanation_type']}"
    )

    target_json = {
        "verdict": "Flawed",
        "error_details": {
            "error_type": "computational_error",
            "erroneous_line_number": erroneous_line,
            "explanation": final_explanation,
        }
    }
    return flawed_nl_solution, target_json

# You can now re-run the full pipeline:
# 1. `new_error_details = generate_computational_error(...)`
# 2. `flawed_trace = generate_flawed_trace(..., new_error_details)`
# 3. `flawed_solution, label = generate_training_artifacts_enhanced(..., new_error_details, flawed_trace)`

In [200]:
# --- Example of use ---
# Requires solve_module, logical_steps and correct_trace to have been loaded by earlier cells.
solve_function = solve_module.solve
targets = get_target_variables(logical_steps)

# Start from None so the names exist even when no error can be generated.
flawed_trace = None
flawed_solution = None
target_json_label = None

new_error_details = generate_computational_error(solve_function, correct_trace, targets)
if new_error_details:
    print(json.dumps(new_error_details, indent=2))

    # Only build the flawed trace and artifacts when an error was actually
    # produced; passing None onwards would fail further down the pipeline.
    flawed_trace = generate_flawed_trace(solve_function, new_error_details)
    if flawed_trace is not None:
        artifacts = generate_training_artifacts_enhanced(
            logical_steps, new_error_details, flawed_trace
        )
        # generate_training_artifacts_enhanced returns None instead of a tuple
        # on failure, so unpack only a real result.
        if artifacts is not None:
            flawed_solution, target_json_label = artifacts

print("\n--- ORIGINAL SOLUTION ---")
original_solution = build_solution_mapping(index=example_index, dataset=GSM8K_TRAIN)
print(json.dumps(original_solution, indent=2))

if flawed_solution is None:
    print("\nWARNING: No flawed solution was generated for this example.")
else:
    print("\n--- FLAWED SOLUTION ---")
    print(flawed_solution)
    print("\n--- TARGET JSON LABEL ---")
    print(json.dumps(target_json_label, indent=2))


{
  "variable": "total_available",
  "correct_value": 580,
  "flawed_value": 581,
  "explanation_type": "This appears to be a minor miscalculation."
}

--- ORIGINAL SOLUTION ---
{
  "L1": "Making 100 pencils a day, he made 100 pencils/day * 5 days/week = <<100*5=500>>500 pencils during the week.",
  "L2": "With the pencils from the stock, he got 500 pencils + 80 pencils = <<500+80=580>>580 pencils.",
  "L3": "Subtracting the ones he sold, he has 580 pencils - 350 pencils = <<580-350=230>>230 pencils."
}

--- FLAWED SOLUTION ---
Making 100 pencils a day, he made 100 pencils/day * 5 days/week = <<100*5=500>>500 pencils during the week.
With the pencils from the stock, he got 500 pencils + 80 pencils = <<500+80=581>>581 pencils.
Subtracting the ones he sold, he has 581 pencils - 350 pencils = <<581-350=231>>231 pencils.

--- TARGET JSON LABEL ---
{
  "verdict": "Flawed",
  "error_details": {
    "error_type": "computational_error",
    "erroneous_line_number": "L2",
    "explanation": "Th