In [99]:
# ---------------------------------------------------------------------- #
#  Imports
# ---------------------------------------------------------------------- #

from datasets import load_dataset
import re
import json
import importlib.util
import copy
import ast
import inspect
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, List
from num2words import num2words

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

def find_project_root():
    """Locate the project root by walking upwards from the working directory.

    The root is the nearest directory -- starting at ``Path.cwd()`` and moving
    through every ancestor up to and including the filesystem root -- that
    contains a ``.git`` directory.

    Returns:
        Path: The first directory found that contains a ``.git`` folder.

    Raises:
        FileNotFoundError: If neither the working directory nor any of its
            ancestors contains a ``.git`` folder.
    """
    start = Path.cwd()
    # ``Path.parents`` ends with the filesystem root, so the top-most
    # directory is inspected too before giving up.
    for candidate in (start, *start.parents):
        if (candidate / ".git").is_dir():
            return candidate
    raise FileNotFoundError("Could not find project root. Is this a git repository?")


# Resolved once at import/cell-execution time; raises FileNotFoundError if the
# working directory is not inside a git checkout.
PROJECT_ROOT = find_project_root()

# Define project paths
# Directory of the reference ("oracle") code files.
CORRECT_CODE_DIR = PROJECT_ROOT / 'data' / 'code_gen_outputs_traced'
# The 'FLAWED_CODE_DIR' now serves as the location for both the flawed
# code files and their corresponding metadata files.
FLAWED_CODE_DIR = PROJECT_ROOT / 'data' / 'code_with_error_traced'

# Confirm the paths
print(f"Project root found: {PROJECT_ROOT}")
print(f"Path to correct code: {CORRECT_CODE_DIR}")
print(f"Path to flawed code & metadata: {FLAWED_CODE_DIR}")


# Load dataset and define model lists
# "main" configuration, training split of GSM8K.
gsm8k_train = load_dataset("gsm8k", "main", split="train")

# Provider name -> list of model identifiers for that provider.
MODEL_DICT = {
  "anthropic": ["claude-3-5-haiku-20241022"],
  "openai": ["gpt-4.1-mini"],
  "google": ["gemini-2.0-flash-thinking-exp",
             "gemini-2.5-flash-lite-preview-06-17",
             "gemini-2.5-flash"]
}

# Flattened "<provider>_<model>" names, in MODEL_DICT order.
MODELS = [f"{provider}_{model}" for provider, sublist in MODEL_DICT.items() for model in sublist]


# ==============================================================================
# Utility Functions
# ==============================================================================

def build_solution_mapping(index: int, dataset: "datasets.Dataset") -> Dict[str, str]:
    """Turn the natural-language solution of one problem into a labelled dict.

    The ``"answer"`` field of ``dataset[index]`` is split into stripped,
    non-empty lines. A trailing ``#### <number>`` line (optionally negative,
    may contain ``.`` and ``,``) is stored under ``"FA"``; every remaining
    line is stored under ``"L1"``, ``"L2"``, ... in order, with calculator
    annotations ``<<...>>`` rewritten as ``[[...]]``.

    Args:
        index: Position of the problem in ``dataset``.
        dataset: Any indexable collection whose items expose an ``"answer"``
            string.

    Returns:
        Mapping from line label (``"FA"``, ``"L<n>"``) to line text. ``"FA"``
        is only present when the last line matches the final-answer pattern.
    """
    solution_text = dataset[index]["answer"]
    lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

    solution_mapping: Dict[str, str] = {}

    # Final-answer marker; the optional "-" keeps negative answers from being
    # mislabelled as an ordinary solution step.
    if lines and re.match(r"^####\s*-?[\d\.,]+$", lines[-1]):
        solution_mapping["FA"] = lines.pop()

    # Normalize calculator annotation brackets for consistent parsing
    for i, line in enumerate(lines, 1):
        solution_mapping[f"L{i}"] = re.sub(r"<<([^>]+)>>", r"[[\1]]", line)

    return solution_mapping

def execution_trace(func) -> Dict[str, Any]:
    """Simulate a function's straight-line body and map variables to values.

    The function's source is parsed; parameter defaults are evaluated and
    every top-level plain assignment (``ast.Assign``) in the body is executed
    in order. Other statement kinds (loops, conditionals, augmented or
    annotated assignments, ``return``) are ignored.

    NOTE: this runs ``eval``/``exec`` on the function's own source code, so it
    must only be used on trusted functions.

    Args:
        func: A Python function whose source is retrievable with
            ``inspect.getsource``.

    Returns:
        Dict mapping each defaulted parameter and each assigned variable to
        its final value, in first-definition order.
    """
    import textwrap

    # Nested functions and methods come back indented from getsource(), which
    # ast.parse() rejects; dedent first.
    src = textwrap.dedent(inspect.getsource(func))
    func_def = ast.parse(src).body[0]

    # One shared namespace, seeded with the function's module globals, so that
    # comprehensions can see earlier variables and module-level helpers
    # (e.g. ``math``) resolve. It is a copy: the real module is never mutated.
    namespace: Dict[str, Any] = dict(getattr(func, "__globals__", {}))
    trace: Dict[str, Any] = {}

    def evaluate(node):
        return eval(compile(ast.Expression(node), '<execution_trace>', 'eval'), namespace)

    # Positional defaults are right-aligned over positional-only + regular args.
    positional = [a.arg for a in (*func_def.args.posonlyargs, *func_def.args.args)]
    defaults = func_def.args.defaults
    if defaults:
        for name, val_node in zip(positional[-len(defaults):], defaults):
            trace[name] = namespace[name] = evaluate(val_node)

    # Keyword-only defaults; ``None`` marks a keyword-only arg without default.
    for kwarg, val_node in zip(func_def.args.kwonlyargs, func_def.args.kw_defaults):
        if val_node is not None:
            trace[kwarg.arg] = namespace[kwarg.arg] = evaluate(val_node)

    # Execute body
    for stmt in func_def.body:
        if not isinstance(stmt, ast.Assign):
            continue
        code_obj = compile(ast.Module(body=[stmt], type_ignores=[]), '<execution_trace>', 'exec')
        exec(code_obj, namespace)

        # Names bound by this statement: assignment targets plus any walrus
        # targets inside the right-hand side.
        bound = [
            node.id
            for target in stmt.targets
            for node in ast.walk(target)
            if isinstance(node, ast.Name)
        ]
        bound += [
            node.target.id
            for node in ast.walk(stmt.value)
            if isinstance(node, ast.NamedExpr)
        ]
        for name in bound:
            if name in namespace:
                trace[name] = namespace[name]

    return trace

Project root found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Path to correct code: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_traced
Path to flawed code & metadata: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_with_error_traced


In [100]:
class NaturalLanguageErrorInjector:
    """
    Generates a flawed natural language (NL) solution and its corresponding
    structured JSON label by injecting a programmatic error from a source
    code function.

    Inputs are read from disk by ``_load_data_sources``: the oracle module,
    the flawed module (optional for ``skipped_step``), and a metadata JSON
    file describing the injected error. The NL solution comes from
    ``build_solution_mapping``.

    This version is instrumented with print statements for detailed debugging.
    """

    def __init__(self, problem_index: int, model_name: str, error_type: str,
                 correct_code_dir: Path, flawed_code_dir: Path,
                 dataset: "datasets.Dataset"):
        """Initializes the injector for a specific error instance.

        Args:
            problem_index: Index of the problem in ``dataset``; also the name
                of the per-problem sub-directory on disk.
            model_name: File stem of the code files and key into the metadata
                JSON.
            error_type: One of the error categories handled below
                ('computational_error', 'incorrect_operation',
                'skipped_step').
            correct_code_dir: Root directory of the oracle code files.
            flawed_code_dir: Root directory of flawed code and metadata files.
            dataset: Dataset passed through to ``build_solution_mapping``.
        """
        self.problem_index = problem_index
        self.model_name = model_name
        self.error_type = error_type
        self.correct_code_dir = correct_code_dir
        self.flawed_code_dir = flawed_code_dir
        self.dataset = dataset

        # Data attributes (populated by _load_data_sources)
        self.f_oracle = None
        self.f_flawed = None
        self.correct_stepwise_trace = None
        self.flawed_stepwise_trace = None
        self.metadata = None
        self.original_nl_solution = None
        # Text of the NL line removed for a 'skipped_step' error.
        self.deleted_nl_line_text = ""

        # Helper maps: AST operator name -> symbol written into the text, and
        # AST operator name -> symbols that may represent it in the text.
        self.op_map = {'Mult': '*', 'Add': '+', 'Sub': '-', 'Div': '/'}
        self.op_synonyms = {'Mult': ['*', 'x'], 'Add': ['+'], 'Sub': ['-'], 'Div': ['/']}

    def _load_module_from_path(self, file_path: Path, module_name: str) -> Optional[Any]:
        """Loads a Python module dynamically from a given file path.

        Returns:
            The executed module, or None if the file does not exist or no
            import spec/loader could be created for it.
        """
        if not file_path.exists():
            print(f"  [LOAD_DETAIL] ❌ File not found at: {file_path}")
            return None
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec and spec.loader:
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module
        print(f"  [LOAD_DETAIL] ❌ Could not create module spec for: {file_path}")
        return None

    def _validate_sample_structure(self, source_code: str) -> bool:
        """
        Validates the function body using a simple line-based distance check.

        It ensures there is at most one line of code between any two `#: L<n>`
        labels by checking that the index distance between consecutive labels
        in a list of non-empty lines is not greater than 2.

        Args:
            source_code: The string content of the Python script to validate.

        Returns:
            True if the structure is valid, False otherwise (including when
            no line ending in a triple double-quote is found).
        """
        all_lines = source_code.splitlines()

        # Find the start of the function body (the line after the first line
        # that ends with a triple double-quote).
        body_start_index = next(
            (i + 1 for i, line in enumerate(all_lines) if line.strip().endswith('"""')),
            -1,
        )
        if body_start_index == -1:
            return False  # No valid docstring/body structure found

        # 1. All non-empty lines from the function body
        body_lines = [line.strip() for line in all_lines[body_start_index:] if line.strip()]

        # 2. Positions of the label lines
        label_indices = [i for i, line in enumerate(body_lines) if line.startswith("#:")]

        # 3. Every gap between consecutive labels must be at most 2, i.e. a
        #    label is followed by at most one line before the next label.
        #    (Checking only that the gaps are *uniform* would accept a sample
        #    with two or more statements under every label.)
        return all(nxt - cur <= 2 for cur, nxt in zip(label_indices, label_indices[1:]))

    def _load_data_sources(self) -> bool:
        """Loads and validates all necessary inputs from disk.

        Populates ``f_oracle``, ``f_flawed``, both stepwise traces,
        ``metadata`` and ``original_nl_solution``.

        Returns:
            True when everything required was loaded, False otherwise (the
            reason is printed).
        """
        print("[LOAD] Attempting to load data sources...")

        # 1. Load and Validate Oracle Code
        correct_path = self.correct_code_dir / str(self.problem_index) / f"{self.model_name}.py"
        print(f"  - Loading oracle code from: {correct_path}")
        self.f_oracle = self._load_module_from_path(correct_path, "f_oracle")
        if not self.f_oracle:
            print("❌ LOAD FAILED: Oracle module could not be loaded.")
            return False
        print("  - Oracle module loaded successfully.")

        # 2. Perform Structure Validation
        print("[VALIDATE] Validating oracle code structure...")
        oracle_source = inspect.getsource(self.f_oracle.solve)
        if not self._validate_sample_structure(oracle_source):
            print("❌ VALIDATION FAILED: Sample has an invalid structure.")
            return False
        print("✅ VALIDATION SUCCEEDED.")

        # 3. Trace Oracle Code
        # NOTE(review): execution_trace_stepwise is defined outside this
        # class; it is used here as returning a mapping keyed by line label.
        print("[TRACE] Generating stepwise trace for oracle code...")
        self.correct_stepwise_trace = execution_trace_stepwise(self.f_oracle.solve)
        print(f"  - Oracle trace generated with keys: {list(self.correct_stepwise_trace.keys())}")

        # 4. Load Flawed Code
        base_flawed_dir = self.flawed_code_dir / self.error_type / str(self.problem_index)
        flawed_path = base_flawed_dir / f"{self.model_name}.py"
        print(f"  - Loading flawed code from: {flawed_path}")
        self.f_flawed = self._load_module_from_path(flawed_path, "f_flawed")
        if not self.f_flawed:
            if self.error_type == 'skipped_step':
                print("  - Note: Flawed code not found for skipped_step, which is expected.")
                self.flawed_stepwise_trace = {}
            else:
                print("❌ LOAD FAILED: Flawed code module could not be loaded.")
                return False
        else:
            print("[TRACE] Generating stepwise trace for flawed code...")
            self.flawed_stepwise_trace = execution_trace_stepwise(self.f_flawed.solve)
            print(f"  - Flawed trace generated with keys: {list(self.flawed_stepwise_trace.keys())}")

        # 5. Load Metadata (lives next to the flawed code file)
        metadata_path = base_flawed_dir / f"metadata_{self.problem_index}_{self.error_type}.json"
        print(f"  - Loading metadata from: {metadata_path}")
        if not metadata_path.exists():
            print(f"❌ LOAD FAILED: Metadata file not found.")
            return False
        # Explicit encoding so the result does not depend on the locale.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f).get(self.model_name)
        if not self.metadata:
            print(f"❌ LOAD FAILED: Metadata for model '{self.model_name}' not found in file.")
            return False
        print(f"  - Metadata loaded successfully: {self.metadata}")

        # 6. Load NL Solution
        print(f"  - Loading original NL solution for index {self.problem_index}...")
        self.original_nl_solution = build_solution_mapping(self.problem_index, self.dataset)
        if not self.original_nl_solution:
            print("❌ LOAD FAILED: Could not build original NL solution.")
            return False
        print("  - NL solution loaded successfully.")

        print("✅ LOAD SUCCEEDED: All data sources loaded and validated.")
        return True

    def _replace_number_in_string(self, text: str, old_num: float, new_num: float) -> str:
        """Helper to replace numbers (numeral and word form).

        Every standalone occurrence of ``old_num`` in ``text`` is replaced by
        ``new_num``. Whole values are written without a decimal part; other
        values are rounded to two decimals with trailing zeros removed. When
        ``old_num`` is whole, its English word form (via ``num2words``) is
        replaced as well, case-insensitively.

        Returns:
            The rewritten text, or ``text`` unchanged if the two numbers
            differ by less than 1e-9.
        """
        if abs(old_num - new_num) < 1e-9:
            return text

        # float.is_integer() is False for inf/nan instead of raising the way
        # int(inf) / int(nan) would.
        old_is_whole = float(old_num).is_integer()
        new_is_whole = float(new_num).is_integer()

        if new_is_whole:
            new_num_str = str(int(new_num))
        else:
            new_num_str = f"{new_num:.2f}".rstrip('0').rstrip('.')

        if old_is_whole:
            old_num_word = num2words(int(old_num))
            text = re.sub(r'\b' + re.escape(old_num_word) + r'\b', new_num_str, text, flags=re.IGNORECASE)

        old_num_str = str(int(old_num)) if old_is_whole else str(old_num)
        # Match the numeral only when it is not a fragment of a larger number:
        # not preceded by a digit, a decimal point or a thousands separator,
        # and not followed by a digit or by a separator plus digit. Unlike
        # \b, this also lets a leading minus sign match (e.g. "=-5").
        standalone = (
            r'(?<![\d.])(?<!\d,)'
            + re.escape(old_num_str)
            + r'(?!\d|[.,]\d)'
        )
        return re.sub(standalone, new_num_str, text)

    def _swap_operator(self, text: str, old_symbols: List[str], new_symbol: str) -> str:
        """Replace arithmetic operator symbols that sit between two operands.

        A symbol is only replaced when it follows a digit, ')' or '%' and is
        followed by a digit, '(', '.' or '$' (optional whitespace on either
        side). This keeps letters and punctuation in ordinary words intact,
        e.g. the 'x' in "six" or the '-' in "twenty-four".
        """
        if not old_symbols:
            return text
        alternatives = "|".join(re.escape(symbol) for symbol in old_symbols)
        operator_between_operands = re.compile(
            r'(?P<lhs>[\d)%]\s*)(?:' + alternatives + r')(?=\s*[\d(.$])'
        )
        return operator_between_operands.sub(lambda m: m.group('lhs') + new_symbol, text)

    def _generate_flawed_nl_solution(self) -> Dict[str, str]:
        """Creates the flawed NL solution using stepwise traces.

        For 'skipped_step' the labelled line is removed (its text is kept in
        ``deleted_nl_line_text``). Otherwise each ``L<n>`` line whose label
        appears in both traces has the oracle value replaced by the flawed
        value, and for 'incorrect_operation' the operator on the labelled
        line is swapped as well.

        Returns:
            A modified deep copy of ``original_nl_solution``.
        """
        print("[GENERATE] Generating flawed NL solution...")
        flawed_nl = copy.deepcopy(self.original_nl_solution)

        if self.error_type == 'skipped_step':
            line_to_delete = self.metadata['line_label']
            print(f"  [GENERATE_DETAIL] Skipping step: Deleting line '{line_to_delete}'.")
            if line_to_delete in flawed_nl:
                self.deleted_nl_line_text = flawed_nl.pop(line_to_delete)
            return flawed_nl

        # Only the numbered lines are rewritten, in numeric order.
        line_keys = sorted(
            (k for k in flawed_nl if k.startswith('L')),
            key=lambda k: int(k[1:]),
        )

        for key in line_keys:
            line_text = flawed_nl[key]
            print(f"\n  [GENERATE_DETAIL] Processing {key}: '{line_text}'")

            correct_snapshot = self.correct_stepwise_trace.get(key)
            flawed_snapshot = self.flawed_stepwise_trace.get(key)

            if not correct_snapshot or not flawed_snapshot:
                print("    - No trace found for this line. Skipping modification.")
                continue

            original_result = correct_snapshot.get('value')
            flawed_result = flawed_snapshot.get('value')
            print(f"    - Correct value: {original_result}, Flawed value: {flawed_result}")

            if original_result is not None and flawed_result is not None:
                try:
                    old_value, new_value = float(original_result), float(flawed_result)
                except (TypeError, ValueError):
                    # Traced values that are not numbers (strings, lists, ...)
                    # cannot be propagated into the text.
                    print("    - Non-numeric trace value. Skipping numerical propagation.")
                else:
                    modified_line = self._replace_number_in_string(line_text, old_value, new_value)
                    if modified_line != line_text:
                        print(f"    - Propagated numerical error. New line: '{modified_line}'")
                        line_text = modified_line
                    else:
                        print("    - No numerical change needed for this line.")

            if self.error_type == 'incorrect_operation' and key == self.metadata.get('line_label'):
                original_op_name = self.metadata['original_op']
                new_op_name = self.metadata['new_op']
                symbols_to_replace = self.op_synonyms.get(original_op_name, [])
                new_op_symbol = self.op_map.get(new_op_name, '?')

                print(f"    - Operator swap: Attempting to replace '{symbols_to_replace}' with '{new_op_symbol}'.")

                if new_op_symbol != '?':
                    line_text = self._swap_operator(line_text, symbols_to_replace, new_op_symbol)
                    print(f"    - Line after operator swap: '{line_text}'")

            flawed_nl[key] = line_text

        return flawed_nl

    def _generate_json_label(self) -> Dict[str, Any]:
        """Constructs the final structured JSON label.

        Returns:
            Dict with a fixed ``"verdict": "Flawed"`` and an ``error_details``
            entry (error type, erroneous line label, explanation, correction)
            templated from ``self.metadata``.

        Raises:
            KeyError: If the metadata lacks a key needed by the template for
                the current error type (e.g. ``'line_label'``).
        """
        print("[LABEL] Generating final JSON label...")
        m = self.metadata
        # Fallback text for error types without a template below.
        explanation, correction = "Error: Could not generate.", "Error: Could not generate."
        if self.error_type == 'computational_error':
            explanation = f"There is a computational error. The solution states the result is {m['new_value']}, but the correct value is {m['original_value']}."
            correction = f"To correct this, replace the incorrect value {m['new_value']} with the correct value {m['original_value']}."
        elif self.error_type == 'incorrect_operation':
            original_op_symbol = self.op_map.get(m['original_op'], '?')
            new_op_symbol = self.op_map.get(m['new_op'], '?')
            explanation = f"The solution incorrectly uses a '{new_op_symbol}' operation where a '{original_op_symbol}' operation was needed."
            correction = f"To correct this, the operator should be changed from '{new_op_symbol}' to '{original_op_symbol}'."
        elif self.error_type == 'skipped_step':
            explanation = f"A necessary calculation step is missing. The solution fails to perform the step that would have defined the variable '{m['deleted_variable']}'."
            correction = f"To correct this, the following step must be inserted back into the solution: '{self.deleted_nl_line_text}'"
        print("✅ LABEL SUCCEEDED.")
        return {
            "verdict": "Flawed",
            "error_details": {
                "error_type": self.error_type,
                "erroneous_line_number": m['line_label'],
                "explanation": explanation,
                "correction": correction,
            },
        }

    def inject_nl_error(self) -> Optional[Tuple[Dict[str, str], Dict[str, Any]]]:
        """Orchestrates the end-to-end injection process.

        Returns:
            ``(flawed_nl_solution, json_label)`` on success, or None when
            loading/validation fails.
        """
        print(f"\n--- Starting injection for Index: {self.problem_index}, Model: {self.model_name}, Error: {self.error_type} ---")
        if not self._load_data_sources():
            print("--- ❌ INJECTION HALTED: Data loading or validation failed. ---")
            return None

        flawed_nl_solution = self._generate_flawed_nl_solution()
        json_label = self._generate_json_label()

        print("--- ✅ INJECTION PROCESS COMPLETED SUCCESSFULLY. ---")
        return flawed_nl_solution, json_label

In [101]:
def test_single_injection(
        problem_index: int,
        model_name: str,
        error_type: str,
        print_original_solution: bool = True,
        verbose: bool = True
    ):
    """
    Run the NaturalLanguageErrorInjector once and optionally report the outcome.

    The injector is built from the module-level CORRECT_CODE_DIR,
    FLAWED_CODE_DIR and gsm8k_train globals.

    Args:
        problem_index: The GSM8K index to test.
        model_name: The model whose output will be used.
        error_type: The error type to inject.
        print_original_solution: If True, prints the original NL solution.
        verbose: If True, prints detailed step-by-step outputs.

    Returns:
        The value returned by ``inject_nl_error`` when it is truthy,
        otherwise None.
    """
    if verbose:
        rule = "----------------------"
        print(rule)
        print(f" {error_type.upper()}")
        print(f"{rule}\n")

    injector = NaturalLanguageErrorInjector(
        problem_index=problem_index,
        model_name=model_name,
        error_type=error_type,
        correct_code_dir=CORRECT_CODE_DIR,
        flawed_code_dir=FLAWED_CODE_DIR,
        dataset=gsm8k_train
    )
    outcome = injector.inject_nl_error()

    # Failure path first: nothing to display beyond the notice.
    if not outcome:
        if verbose:
            print(f"Injection Failed for {error_type}.\n")
        return None

    flawed_solution, final_label = outcome
    if not verbose:
        return outcome

    if print_original_solution:
        print("--- Original NL Solution ---")
        print(json.dumps(injector.original_nl_solution, indent=4))

    report = (
        ("--- Generated Flawed NL Solution ---", flawed_solution),
        ("\n--- Generated JSON Label ---", final_label),
    )
    for heading, payload in report:
        print(heading)
        print(json.dumps(payload, indent=4))
    print("\nInjection Successful.\n")
    return outcome

In [102]:
# Smoke test: run every error type against the first ten GSM8K problems for a
# single model. Relies on build_solution_mapping, test_single_injection and
# gsm8k_train from the cells above.
test_error_types = ['computational_error', 'incorrect_operation', 'skipped_step']
# "<provider>_<model>" name, used as the code file stem and metadata key.
test_model = "anthropic_claude-3-5-haiku-20241022"

for index in range(10):
    print(f"=====================")
    print(f"  PROBLEM: {index}  ")
    print(f"=====================\n")
    
    # Print the original solution once per problem
    original_solution = build_solution_mapping(index, gsm8k_train)
    print("--- Original NL Solution ---")
    print(json.dumps(original_solution, indent=4))
    
    # Test all error types for this problem
    for error_type in test_error_types:
        test_single_injection(
            problem_index=index,
            model_name=test_model,
            error_type=error_type,
            # Pass False to prevent re-printing the original solution
            print_original_solution=False 
        )

  PROBLEM: 0  

--- Original NL Solution ---
{
    "FA": "#### 72",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=72]]72 clips altogether in April and May."
}
----------------------
 COMPUTATIONAL_ERROR
----------------------


--- Starting injection for Index: 0, Model: anthropic_claude-3-5-haiku-20241022, Error: computational_error ---
[LOAD] Attempting to load data sources...
  - Loading oracle code from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_traced/0/anthropic_claude-3-5-haiku-20241022.py
  - Oracle module loaded successfully.
[VALIDATE] Validating oracle code structure...
✅ VALIDATION SUCCEEDED.
[TRACE] Generating stepwise trace for oracle code...
  - Oracle trace generated with keys: ['L1', 'L2', 'FA']
  - Loading flawed code from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_with_error_traced/computational_error/0/anthropic_claude-3-5-haiku-20241022.py
[TRACE] 

In [103]:
# Test suite for debugging the NL error injection pipeline.
# Selected to test a range of scenarios including:
# - Previous complete failures (0, 6)
# - Error propagation in multi-step problems (14, 23, 95)
# - Edge cases with unusual formatting or types (8, 35, 90)
# - Regressions on known cosmetic bugs (54, 61)
#
# NOTE: this cell reuses `test_error_types` and `test_model`, which it does
# not define itself; they must already exist in the session.

test_indices = [0, 6, 8, 14, 23, 35, 54, 61, 90, 95]

for index in test_indices:
    print(f"=====================")
    print(f"  PROBLEM: {index}  ")
    print(f"=====================\n")
    
    # Print the original solution once per problem
    original_solution = build_solution_mapping(index, gsm8k_train)
    print("--- Original NL Solution ---")
    print(json.dumps(original_solution, indent=4))
    
    # Test all error types for this problem
    for error_type in test_error_types:
        test_single_injection(
            problem_index=index,
            model_name=test_model,
            error_type=error_type,
            # Pass False to prevent re-printing the original solution
            print_original_solution=False 
        )

  PROBLEM: 0  

--- Original NL Solution ---
{
    "FA": "#### 72",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=72]]72 clips altogether in April and May."
}
----------------------
 COMPUTATIONAL_ERROR
----------------------


--- Starting injection for Index: 0, Model: anthropic_claude-3-5-haiku-20241022, Error: computational_error ---
[LOAD] Attempting to load data sources...
  - Loading oracle code from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_traced/0/anthropic_claude-3-5-haiku-20241022.py
  - Oracle module loaded successfully.
[VALIDATE] Validating oracle code structure...
✅ VALIDATION SUCCEEDED.
[TRACE] Generating stepwise trace for oracle code...
  - Oracle trace generated with keys: ['L1', 'L2', 'FA']
  - Loading flawed code from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_with_error_traced/computational_error/0/anthropic_claude-3-5-haiku-20241022.py
[TRACE] 