In [33]:
# ---------------------------------------------------------------------- #
#  Imports
# ---------------------------------------------------------------------- #

from datasets import load_dataset
import re
import json
import importlib.util
import copy
import ast
import inspect
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, List
from num2words import num2words

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

def find_project_root(start: Optional[Path] = None, marker: str = ".git") -> Path:
    """Walk upwards from *start* to the first directory containing *marker*.

    The marker may be a directory or a file. In git worktrees and submodules
    ``.git`` is a plain file pointing at the real git directory, so
    ``exists()`` is used rather than ``is_dir()``. The filesystem root itself
    is checked as well.

    Args:
        start: Directory to start from; defaults to the current working
            directory.
        marker: Name of the entry that marks the project root.

    Returns:
        The first ancestor (including *start* itself) that contains *marker*.

    Raises:
        FileNotFoundError: If no ancestor contains *marker*.
    """
    current_path = Path.cwd() if start is None else Path(start).resolve()
    for candidate in (current_path, *current_path.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError("Could not find project root. Is this a git repository?")


PROJECT_ROOT = find_project_root()

# Data locations under <project root>/data. FLAWED_CODE_DIR holds both the
# flawed code files and the metadata JSON files describing each injection.
CORRECT_CODE_DIR = PROJECT_ROOT.joinpath('data', 'code_gen_outputs_traced')
FLAWED_CODE_DIR = PROJECT_ROOT.joinpath('data', 'code_with_error_traced')

# Echo the resolved locations so a wrong working directory is easy to spot.
print(f"Project root found: {PROJECT_ROOT}")
print(f"Path to correct code: {CORRECT_CODE_DIR}")
print(f"Path to flawed code & metadata: {FLAWED_CODE_DIR}")

# GSM8K training split; its "answer" field supplies the NL reference solutions.
gsm8k_train = load_dataset("gsm8k", "main", split="train")

# Provider -> model identifiers whose generated code is analysed.
MODEL_DICT = {
    "anthropic": ["claude-3-5-haiku-20241022"],
    "openai": ["gpt-4.1-mini"],
    "google": [
        "gemini-2.0-flash-thinking-exp",
        "gemini-2.5-flash-lite-preview-06-17",
        "gemini-2.5-flash",
    ],
}

# Flattened "<provider>_<model>" names, the same form as the model_name
# used for per-model code file lookups.
MODELS = [
    f"{provider}_{model_id}"
    for provider, model_ids in MODEL_DICT.items()
    for model_id in model_ids
]


# ==============================================================================
# Utility Functions
# ==============================================================================

# Final-answer marker line, e.g. "#### 72", "#### 1,000", "#### 2.5", "#### -3".
_FINAL_ANSWER_RE = re.compile(r"^####\s*-?[\d.,]+$")
# GSM8K calculator annotation, e.g. "<<48/2=24>>".
_CALC_ANNOTATION_RE = re.compile(r"<<([^>]+)>>")


def build_solution_mapping(index: int, dataset: "datasets.Dataset") -> Dict[str, str]:
    """
    Extract the natural-language solution for ``dataset[index]`` and structure
    it into a line-numbered dictionary.

    Blank lines are dropped and every remaining line is stripped. If the last
    line is a final-answer marker (``#### 72``, ``#### 1,000``, ``#### -3``),
    it is stored under ``"FA"``. Calculator annotations ``<<expr>>`` are
    normalized to ``[[expr]]`` for consistent parsing. The remaining lines
    become ``"L1"``, ``"L2"``, ...

    Args:
        index: Row index into *dataset*.
        dataset: Any indexable collection whose rows map ``"answer"`` to the
            solution text (e.g. a Hugging Face ``Dataset`` or a list of dicts).

    Returns:
        Dict with an optional ``"FA"`` entry first, followed by ``"L1"``..``"Ln"``.
    """
    solution_text = dataset[index]["answer"]
    lines = [ln.strip() for ln in solution_text.splitlines() if ln.strip()]

    solution_mapping: Dict[str, str] = {}
    # The optional sign lets negative answers be recognised instead of being
    # mis-filed as an ordinary step line.
    if lines and _FINAL_ANSWER_RE.match(lines[-1]):
        solution_mapping["FA"] = lines.pop()

    for i, line in enumerate(lines, 1):
        solution_mapping[f"L{i}"] = _CALC_ANNOTATION_RE.sub(r"[[\1]]", line)

    return solution_mapping

def _bound_names(target: ast.AST) -> List[str]:
    """Return the plain variable names bound by an assignment target node."""
    if isinstance(target, ast.Name):
        return [target.id]
    if isinstance(target, ast.Starred):
        return _bound_names(target.value)
    if isinstance(target, (ast.Tuple, ast.List)):
        return [name for elt in target.elts for name in _bound_names(elt)]
    # Attribute / subscript targets mutate an object but bind no new name.
    return []


def execution_trace(func) -> Dict[str, Any]:
    """Simulates execution of a function and returns a variable-to-value map.

    Parameters that have default values are bound first. Defaults are
    deep-copied so the function's own defaults are never mutated. Then every
    top-level assignment in the body (``x = ...``, ``x += ...``,
    ``x: T = ...``) is executed in source order. All other statements
    (``if``, ``for``, ``return``, ...) are skipped.

    Statements run in one namespace seeded with a copy of the function's
    module globals. This means module-level imports (e.g. ``math``) resolve.
    Because globals and locals are the same dict, generator expressions and
    comprehensions can see earlier variables.

    Args:
        func: A Python function whose source is available to
            ``inspect.getsource``.

    Returns:
        Insertion-ordered dict: parameter defaults first, then every name
        bound by an executed assignment, each holding its final value.
    """
    import textwrap  # only needed to dedent nested-function / method sources

    src = textwrap.dedent(inspect.getsource(func))
    func_def = ast.parse(src).body[0]

    env: Dict[str, Any] = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            env[name] = copy.deepcopy(param.default)

    namespace: Dict[str, Any] = dict(func.__globals__)
    namespace.update(env)

    for stmt in func_def.body:
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, ast.AugAssign):
            targets = [stmt.target]
        elif isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
            targets = [stmt.target]
        else:
            continue
        code_obj = compile(ast.Module([stmt], []), '', 'exec')
        exec(code_obj, namespace)
        for target in targets:
            for name in _bound_names(target):
                env[name] = namespace[name]
    return env

Project root found: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Path to correct code: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_gen_outputs_traced
Path to flawed code & metadata: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/code_with_error_traced


In [34]:
class NaturalLanguageErrorInjector:
    """
    Generates a flawed natural language (NL) solution and its corresponding
    structured JSON label by injecting a programmatic error from a source
    code function.

    ``inject_nl_error`` drives the process:
      1. Load and trace the correct and flawed ``solve`` functions and load
         the injection metadata.
      2. Rewrite the GSM8K NL solution so its numbers follow the flawed
         trace, and swap an operator or delete a line where the error type
         requires it.
      3. Build a ``"Flawed"`` JSON label describing the error.
    """

    def __init__(self, problem_index: int, model_name: str, error_type: str,
                 correct_code_dir: Path, flawed_code_dir: Path,
                 dataset: "datasets.Dataset"):
        """
        Initializes the injector for a specific error instance. No files are
        read here; loading happens in ``inject_nl_error``.

        Args:
            problem_index: Row index of the problem in *dataset*.
            model_name: Stem of the generated code files (``<model_name>.py``).
            error_type: Injected error category, e.g. ``'computational_error'``;
                also the sub-directory name under *flawed_code_dir*.
            correct_code_dir: Root holding ``<problem_index>/<model_name>.py``.
            flawed_code_dir: Root holding
                ``<error_type>/<problem_index>/<model_name>.py`` and
                ``metadata_<problem_index>_<error_type>.json``.
            dataset: Dataset whose rows carry an ``"answer"`` string.
        """
        self.problem_index = problem_index
        self.model_name = model_name
        self.error_type = error_type

        # Paths and data sources are passed in directly
        self.correct_code_dir = correct_code_dir
        self.flawed_code_dir = flawed_code_dir
        self.dataset = dataset

        # Populated by _load_data_sources / _generate_flawed_nl_solution
        self.f_oracle = None
        self.f_flawed = None
        self.correct_trace = None
        self.flawed_trace = None
        self.metadata = None
        self.original_nl_solution = None
        self.deleted_nl_line_text = ""

        # Op name from the metadata -> the single symbol used in explanations
        # and written into the NL text when swapping operators.
        self.op_map = {
            'Mult': '*', 'Add': '+', 'Sub': '-', 'Div': '/'
        }
        # Op name from the metadata -> every symbol that may denote it in NL text.
        self.op_synonyms = {
            'Mult': ['*', 'x'],
            'Add': ['+'],
            'Sub': ['-'],
            'Div': ['/']
        }

        # True once an operator has actually been swapped in the NL text.
        self.operator_swap_successful = False

    @staticmethod
    def _format_number(value: Any) -> str:
        """
        Format a number the way it is written into NL text: integers without
        a decimal part, other floats with at most two decimals and no
        trailing zeros. Non-numeric values are returned via ``str``.
        """
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return str(value)
        value = float(value)
        if value.is_integer():
            return str(int(value))
        return f"{value:.2f}".rstrip('0').rstrip('.')

    def _load_module_from_path(self, file_path: Path, module_name: str):
        """
        Import the Python file at *file_path* as a module named *module_name*.

        Returns:
            The module, or ``None`` (after printing a message) if the file is
            missing or importing it raises.
        """
        if not file_path.exists():
            print(f"❌ Error: File not found at {file_path}")
            return None
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if not (spec and spec.loader):
            return None
        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except Exception as exc:
            # One broken generated file should not abort a whole batch run.
            print(f"❌ Error: Could not import {file_path}: {exc}")
            return None
        return module

    def _load_data_sources(self) -> bool:
        """
        Load both traces, the injection metadata and the original NL solution.

        Returns:
            True on success. Otherwise prints a message where applicable and
            returns False.
        """
        # 1. Load correct/oracle code and trace
        correct_path = self.correct_code_dir / str(self.problem_index) / f"{self.model_name}.py"
        self.f_oracle = self._load_module_from_path(correct_path, "f_oracle")
        if self.f_oracle is None:
            return False
        try:
            self.correct_trace = execution_trace(self.f_oracle.solve)
        except Exception as exc:
            print(f"❌ Error: Could not trace correct function at {correct_path}: {exc}")
            return False

        # 2. Load flawed code and trace. A skipped_step flawed function may
        #    not be traceable or even present, so that case falls back to an
        #    empty trace.
        base_flawed_dir = self.flawed_code_dir / self.error_type / str(self.problem_index)
        flawed_path = base_flawed_dir / f"{self.model_name}.py"
        self.f_flawed = self._load_module_from_path(flawed_path, "f_flawed")
        if self.f_flawed is not None:
            try:
                self.flawed_trace = execution_trace(self.f_flawed.solve)
            except Exception:
                if self.error_type != 'skipped_step':
                    print(f"⚠️ Warning: Could not trace flawed function for {self.error_type}")
                self.flawed_trace = {}
        else:
            if self.error_type != 'skipped_step':
                return False
            self.flawed_trace = {}

        # 3. Load injection metadata from the same directory as the flawed code
        metadata_path = base_flawed_dir / f"metadata_{self.problem_index}_{self.error_type}.json"
        if not metadata_path.exists():
            print(f"❌ Error: Metadata file not found at {metadata_path}")
            return False
        with open(metadata_path, 'r', encoding='utf-8') as f:
            full_metadata = json.load(f)
        self.metadata = full_metadata.get(self.model_name)
        if not self.metadata:
            print(f"❌ Error: No metadata found for model '{self.model_name}' in {metadata_path}")
            return False

        # 4. Load original NL solution from the dataset
        self.original_nl_solution = build_solution_mapping(self.problem_index, self.dataset)
        if not self.original_nl_solution:
            print(f"❌ Error: Could not build NL solution for index {self.problem_index}")
            return False

        return True

    def _replace_numbers(self, text: str, replacements: Dict[float, float]) -> str:
        """
        Replace several numbers in *text* in a single regex pass.

        A single pass prevents chained substitutions. With ``{24: 30, 30: 40}``,
        a ``30`` that was just written must not be rewritten again to ``40``.

        Matching rules:
          - Numerals match only as whole numbers: ``4`` does not match inside
            ``2.4``, ``45`` or ``4x``.
          - Integers also match their English spelling from ``num2words``,
            case-insensitively, but not inside hyphenated compounds such as
            ``twenty-two``.
          - Pairs whose old and new values are equal are ignored.
        """
        token_to_new: Dict[str, str] = {}
        numeral_tokens: List[str] = []
        word_tokens: List[str] = []
        for old_num, new_num in replacements.items():
            old_num, new_num = float(old_num), float(new_num)
            if abs(old_num - new_num) < 1e-9:
                continue
            new_num_str = self._format_number(new_num)
            if old_num.is_integer():
                numeral = str(int(old_num))
                word = num2words(int(old_num)).lower()
                if word not in token_to_new:
                    token_to_new[word] = new_num_str
                    word_tokens.append(word)
            else:
                numeral = str(old_num)
            if numeral not in token_to_new:
                token_to_new[numeral] = new_num_str
                numeral_tokens.append(numeral)

        if not token_to_new:
            return text

        def alternation(tokens: List[str]) -> str:
            # Longest first, so a longer token wins over one of its prefixes.
            return "|".join(re.escape(t) for t in sorted(tokens, key=len, reverse=True))

        parts = []
        if numeral_tokens:
            parts.append(r"(?<![\w.])(?:" + alternation(numeral_tokens) + r")(?!\w|\.\d)")
        if word_tokens:
            parts.append(r"(?<![\w-])(?:" + alternation(word_tokens) + r")(?![\w-])")
        pattern = re.compile("|".join(parts), re.IGNORECASE)
        return pattern.sub(lambda m: token_to_new[m.group(0).lower()], text)

    def _replace_number_in_string(self, text: str, old_num: float, new_num: float) -> str:
        """Replace every occurrence of *old_num* in *text* with *new_num*."""
        return self._replace_numbers(text, {old_num: new_num})

    def _modify_nl_line(self, line_text: str) -> str:
        """
        Modifies a single NL line by replacing all numbers (operands and results)
        with their corresponding flawed values from the execution trace.

        Each number found in the line (largest first) is matched to the first
        not-yet-used numeric variable in the correct trace with the same
        value. The number then takes that variable's value from the flawed
        trace. Numbers with no matching variable, or whose flawed value is
        missing or non-numeric, are left unchanged.
        """
        numbers_in_text = {float(tok) for tok in re.findall(r'\b\d+\.?\d*\b', line_text)}
        if not numbers_in_text:
            return line_text

        # Each variable is used at most once, so a value like 2 is not mapped
        # to several variables.
        unused_vars = dict(self.correct_trace)
        replacements: Dict[float, float] = {}
        for num in sorted(numbers_in_text, reverse=True):
            var_name = next(
                (var for var, val in unused_vars.items()
                 if isinstance(val, (int, float)) and abs(val - num) < 1e-6),
                None,
            )
            if var_name is None:
                continue
            del unused_vars[var_name]
            flawed_val = self.flawed_trace.get(var_name)
            if isinstance(flawed_val, (int, float)):
                replacements[num] = float(flawed_val)

        if not replacements:
            return line_text
        return self._replace_numbers(line_text, replacements)

    @staticmethod
    def _swap_operator(line_text: str, old_symbols: List[str],
                       new_symbol: str) -> Tuple[str, bool]:
        """
        Replace arithmetic uses of any of *old_symbols* in *line_text* with
        *new_symbol*.

        A symbol only counts as an operator when it sits between operands:
        it must follow a digit or ``)``, and be followed by optional
        ``$``/``(`` characters and then a digit. Whitespace is allowed around
        it. Letters such as the ``x`` in "six", hyphens in "twenty-two" and
        slashes in "km/h" are therefore left alone.

        Returns:
            The rewritten line and whether at least one operator was swapped.
        """
        if not old_symbols:
            return line_text, False
        alternation = "|".join(re.escape(s) for s in old_symbols)
        pattern = re.compile(r"(?<=[\d)])(\s*)(?:" + alternation + r")(\s*)(?=[$(]*\d)")
        new_line, count = pattern.subn(
            lambda m: m.group(1) + new_symbol + m.group(2), line_text)
        return new_line, count > 0

    def _generate_flawed_nl_solution(self) -> Dict[str, str]:
        """
        Creates the flawed natural language solution mapping by propagating numerical
        errors and handling specific error type modifications.

        - ``skipped_step``: only the NL line named by ``metadata['line_label']``
          is removed, and its text is kept in ``deleted_nl_line_text``.
        - Other types: every line's numbers are rewritten from the flawed
          trace. For ``incorrect_operation``, the operator on the labelled
          line is also swapped, and ``operator_swap_successful`` records
          whether that happened.
        """
        flawed_nl = copy.deepcopy(self.original_nl_solution)
        self.operator_swap_successful = False  # reset for each run
        self.deleted_nl_line_text = ""

        if self.error_type == 'skipped_step':
            line_to_delete = self.metadata['line_label']
            if line_to_delete in flawed_nl:
                self.deleted_nl_line_text = flawed_nl.pop(line_to_delete)
            return flawed_nl

        swap_line = (self.metadata.get('line_label')
                     if self.error_type == 'incorrect_operation' else None)

        # Lines are modified independently, and assigning to existing keys
        # keeps the dict order, so iteration order does not matter here.
        for key, line_text in list(flawed_nl.items()):
            modified_line = self._modify_nl_line(line_text)

            if swap_line is not None and key == swap_line:
                symbols_to_replace = self.op_synonyms.get(self.metadata.get('original_op'), [])
                new_op_symbol = self.op_map.get(self.metadata.get('new_op'), '?')
                if new_op_symbol != '?':
                    modified_line, swapped = self._swap_operator(
                        modified_line, symbols_to_replace, new_op_symbol)
                    self.operator_swap_successful = self.operator_swap_successful or swapped

            flawed_nl[key] = modified_line

        return flawed_nl

    def _generate_json_label(self) -> Dict[str, Any]:
        """
        Constructs the final structured JSON label using metadata-driven
        templates based on the error type.

        Numeric values are formatted like the NL text (``82`` not ``82.0``).
        Unknown error types keep the "Could not generate" placeholder texts.
        """
        m = self.metadata
        explanation = "Error: Could not generate explanation."
        correction = "Error: Could not generate correction."

        if self.error_type == 'computational_error':
            new_value = self._format_number(m['new_value'])
            original_value = self._format_number(m['original_value'])
            explanation = f"There is a computational error. The solution states the result is {new_value}, but the correct value is {original_value}."
            correction = f"To correct this, replace the incorrect value {new_value} with the correct value {original_value}."

        elif self.error_type == 'incorrect_operation':
            original_op_symbol = self.op_map.get(m.get('original_op'), '?')
            new_op_symbol = self.op_map.get(m.get('new_op'), '?')

            explanation = f"The solution incorrectly uses a '{new_op_symbol}' operation where a '{original_op_symbol}' operation was needed."
            correction = f"To correct this, the operator should be changed from '{new_op_symbol}' to '{original_op_symbol}'."

        elif self.error_type == 'skipped_step':
            explanation = f"A necessary calculation step is missing. The solution fails to perform the step that would have defined the variable '{m['deleted_variable']}'."
            correction = f"To correct this, the following step must be inserted back into the solution: '{self.deleted_nl_line_text}'"

        elif self.error_type == 'incorrect_operand':
            explanation = f"The calculation uses an incorrect variable. It incorrectly references '{m['new_operand']}' instead of '{m['original_operand']}'."
            correction = f"To correct this, the variable '{m['new_operand']}' should be replaced with the correct variable, '{m['original_operand']}'."

        return {
            "verdict": "Flawed",
            "error_details": {
                "error_type": self.error_type,
                "erroneous_line_number": m['line_label'],
                "explanation": explanation,
                "correction": correction
            }
        }

    def inject_nl_error(self) -> Optional[Tuple[Dict[str, str], Dict[str, Any]]]:
        """
        Orchestrates the end-to-end process of generating a flawed NL solution
        and its corresponding JSON label, with quality control checks.

        Returns:
            ``(flawed_nl_solution, json_label)``, or ``None`` if loading
            failed or a quality gate discarded the sample.
        """
        if not self._load_data_sources():
            return None

        flawed_nl_solution = self._generate_flawed_nl_solution()

        # Quality gate 1: for incorrect_operation, the operator must really
        # have been swapped in the NL text (code/NL mismatch otherwise).
        if self.error_type == 'incorrect_operation' and not self.operator_swap_successful:
            return None

        # Quality gate 2: a solution identical to the original must never be
        # labelled "Flawed" (e.g. the erroneous value or deleted line does
        # not appear in the NL text).
        if flawed_nl_solution == self.original_nl_solution:
            return None

        json_label = self._generate_json_label()
        return flawed_nl_solution, json_label

In [None]:
def test_single_injection(
        problem_index: int,
        model_name: str,
        error_type: str,
        verbose: bool = True):
    """
    Run NaturalLanguageErrorInjector for one (problem, model, error type) case.

    Relies on the module-level CORRECT_CODE_DIR, FLAWED_CODE_DIR and
    gsm8k_train defined in the configuration cell.

    Returns:
        The ``(flawed_solution, label)`` tuple from ``inject_nl_error``, or
        ``None`` when the injection failed or was discarded.
    """
    if verbose:
        print("============================================================")
        print(f"  Testing Injection for: {error_type.upper()}")
        print(f"  Problem Index: {problem_index}, Model: {model_name}")
        print("============================================================\n")

    injector = NaturalLanguageErrorInjector(
        problem_index=problem_index,
        model_name=model_name,
        error_type=error_type,
        correct_code_dir=CORRECT_CODE_DIR,
        flawed_code_dir=FLAWED_CODE_DIR,
        dataset=gsm8k_train,
    )
    result = injector.inject_nl_error()

    # Guard clause: report the failure and stop early.
    if not result:
        if verbose:
            print(f"❌ Injection Failed for {error_type}.\n")
        return None

    flawed_solution, final_label = result
    if verbose:
        sections = (
            ("--- Original NL Solution ---", injector.original_nl_solution),
            ("\n--- Generated Flawed NL Solution ---", flawed_solution),
            ("\n--- Generated JSON Label ---", final_label),
        )
        for heading, payload in sections:
            print(heading)
            print(json.dumps(payload, indent=4))
        print("\n✅ Injection Successful.\n")
    return result

In [36]:
# --- Example Usage ---
# Each case needs, next to the flawed code, a metadata file at
# data/code_with_error_traced/<error_type>/<index>/metadata_<index>_<error_type>.json
test_model = "anthropic_claude-3-5-haiku-20241022"
test_error_types = ['computational_error', 'incorrect_operation', 'skipped_step']

# Problems 0-9 (range end is exclusive), every error type each.
for test_problem_index in range(10):
    for error_type in test_error_types:
        test_single_injection(test_problem_index, test_model, error_type)

  Testing Injection for: COMPUTATIONAL_ERROR
  Problem Index: 0, Model: anthropic_claude-3-5-haiku-20241022

--- Original NL Solution ---
{
    "FA": "#### 72",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=72]]72 clips altogether in April and May."
}

--- Generated Flawed NL Solution ---
{
    "FA": "#### 82",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=82]]82 clips altogether in April and May."
}

--- Generated JSON Label ---
{
    "verdict": "Flawed",
    "error_details": {
        "error_type": "computational_error",
        "erroneous_line_number": "L2",
        "explanation": "There is a computational error. The solution states the result is 82.0, but the correct value is 72.0.",
        "correction": "To correct this, replace the incorrect value 82.0 with the correct value 72.0."
    }
}

✅ Injection Successful.

  Testing Injection for: INCORRECT_OPERATION
  Problem I

In [37]:
# Problems 11-20, reusing test_model / test_error_types from the previous cell.
# NOTE(review): index 10 is covered by neither this batch nor the previous
# range(10) batch; confirm whether that is intentional.
for test_problem_index in range(11, 21):
    for error_type in test_error_types:
        test_single_injection(test_problem_index, test_model, error_type)

  Testing Injection for: COMPUTATIONAL_ERROR
  Problem Index: 11, Model: anthropic_claude-3-5-haiku-20241022

--- Original NL Solution ---
{
    "FA": "#### 5",
    "L1": "He saved up $110 total because 95 + 15 = [[95+15=110]]110",
    "L2": "He saved $15 from his allowance because 3 x 5 = [[3*5=15]]15",
    "L3": "He earned $60 mowing lawns because 4 x 15 = [[4*15=60]]60",
    "L4": "He earned $35 shoveling driveways because 110 - 60 - 15 = [[110-60-15=35]]35",
    "L5": "He shoveled 5 driveways because 35 / 7 = [[35/7=5]]5"
}

--- Generated Flawed NL Solution ---
{
    "FA": "#### 5",
    "L1": "He saved up $110 total because 95 + 15 = [[95+15=110]]110",
    "L2": "He saved $15 from his allowance because 3 x 5 = [[3*5=15]]15",
    "L3": "He earned $60 mowing lawns because 4 x 15 = [[4*15=60]]60",
    "L4": "He earned $25 shoveling driveways because 110 - 60 - 15 = [[110-60-15=25]]25",
    "L5": "He shoveled 5 driveways because 25 / 7 = [[25/7=5]]5"
}

--- Generated JSON Label ---
{
 

In [38]:
def test_single_injection(problem_index: int, model_name: str, error_type: str,
                          print_original_solution: bool = True, verbose: bool = True):
    """
    Run NaturalLanguageErrorInjector for one (problem, model, error type) case.

    Uses the module-level CORRECT_CODE_DIR, FLAWED_CODE_DIR and gsm8k_train.

    Args:
        problem_index: GSM8K row index to test.
        model_name: Model whose generated code is used.
        error_type: Error category to inject.
        print_original_solution: When verbose, also print the original NL
            solution before the flawed one.
        verbose: Print the banner and results (or the failure message).

    Returns:
        The ``(flawed_solution, label)`` tuple from ``inject_nl_error``, or
        ``None`` when the injection failed or was discarded.
    """
    if verbose:
        print("------------------------------------------------------------")
        print(f"  Testing Injection for: {error_type.upper()}")
        print("------------------------------------------------------------\n")

    injector = NaturalLanguageErrorInjector(
        problem_index=problem_index,
        model_name=model_name,
        error_type=error_type,
        correct_code_dir=CORRECT_CODE_DIR,
        flawed_code_dir=FLAWED_CODE_DIR,
        dataset=gsm8k_train,
    )
    result = injector.inject_nl_error()

    if not result:
        if verbose:
            print(f"❌ Injection Failed for {error_type}.\n")
        return None

    flawed_solution, final_label = result
    if not verbose:
        return result

    if print_original_solution:
        print("--- Original NL Solution ---")
        print(json.dumps(injector.original_nl_solution, indent=4))
    for heading, payload in (("--- Generated Flawed NL Solution ---", flawed_solution),
                             ("\n--- Generated JSON Label ---", final_label)):
        print(heading)
        print(json.dumps(payload, indent=4))
    print("\n✅ Injection Successful.\n")
    return result

# ==============================================================================
# Batch check: every error type for each problem in `indices_to_test`
# ==============================================================================

# --- Configuration ---
indices_to_test = range(50)  # any iterable of GSM8K indices works here
test_model = "anthropic_claude-3-5-haiku-20241022"
test_error_types = ['computational_error', 'incorrect_operation', 'skipped_step']

# --- Main Loop ---
for index in indices_to_test:
    print("\n\n============================================================")
    print(f"  ANALYZING PROBLEM INDEX: {index}  ")
    print("============================================================\n")

    # Show the reference solution once; the per-type calls below skip it.
    original_solution = build_solution_mapping(index, gsm8k_train)
    print("--- Original NL Solution ---")
    print(json.dumps(original_solution, indent=4))
    print("\n")

    for error_type in test_error_types:
        test_single_injection(index, test_model, error_type,
                              print_original_solution=False)



  ANALYZING PROBLEM INDEX: 0  

--- Original NL Solution ---
{
    "FA": "#### 72",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=72]]72 clips altogether in April and May."
}


------------------------------------------------------------
  Testing Injection for: COMPUTATIONAL_ERROR
------------------------------------------------------------

--- Generated Flawed NL Solution ---
{
    "FA": "#### 82",
    "L1": "Natalia sold 48/2 = [[48/2=24]]24 clips in May.",
    "L2": "Natalia sold 48+24 = [[48+24=82]]82 clips altogether in April and May."
}

--- Generated JSON Label ---
{
    "verdict": "Flawed",
    "error_details": {
        "error_type": "computational_error",
        "erroneous_line_number": "L2",
        "explanation": "There is a computational error. The solution states the result is 82.0, but the correct value is 72.0.",
        "correction": "To correct this, replace the incorrect value 82.0 with the correct value 72.