#### **Cell 1: Imports, paths, initial setup**

In [None]:
import json
import random
import re
from pathlib import Path
from typing import List, Dict, Set, Tuple

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from tqdm.auto import tqdm

# --- Path and Directory Definitions ---

def find_project_root(marker: str = ".git") -> Path:
    """Return the nearest ancestor of the working directory that contains ``marker``.

    The search starts at the resolved current working directory and walks
    upwards through every parent directory, including the filesystem root.

    Args:
        marker: Name of the file or directory that identifies the project root.

    Returns:
        The closest directory (to the working directory) that contains ``marker``.

    Raises:
        FileNotFoundError: If no directory on the way up contains ``marker``.
    """
    start = Path.cwd().resolve()
    # ``Path.parents`` ends with the filesystem root, so the root directory is
    # checked as well; a "stop when parent == self" loop would skip it.
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# --- Global Constants and Paths ---
# PROJECT_ROOT is resolved when this cell runs by walking up from the current
# working directory (find_project_root defaults to the ".git" marker).
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'
OUTPUT_DIR = DATA_DIR / "sft-datasets/verifier-v2-two-task"
PROCESSED_TEMPLATE_DIR = DATA_DIR / "template-generated-processed"

# --- Ensure output directory exists ---
# Side effect: creates OUTPUT_DIR (and any missing parents) on disk.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Configuration ---
# Seed for reproducibility of shuffling and sampling
RANDOM_SEED = 42

# Define the number of core problems to use as a base
NUM_CONCEPTUAL_PROBLEMS = 1000
NUM_COMPUTATIONAL_PROBLEMS = 1000

print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset output directory: {OUTPUT_DIR}")
print(f"Random seed set to: {RANDOM_SEED}")

# --- Path to the catalog of programmatically generated computational errors ---
# This catalog points to individual JSON files with generated flawed solutions.
# (Only the paths are defined here; the CSV itself is read further below.)
PROGRAMMATIC_COMPUTATIONAL_DIR = DATA_DIR / "computational-errors-generated"
PROGRAMMATIC_CATALOG_PATH = PROGRAMMATIC_COMPUTATIONAL_DIR / "computational_error_catalog.csv"

# --- Path to the manually generated conceptual/computational errors ---
# This file contains the final, human-approved conceptual/computational error text and explanations.
MANUAL_ERRORS_CSV_PATH = DATA_DIR / "manually_generated_errors_final.csv"

# --- Load Original GSM8K Dataset for 'Correct' examples and problem text ---
# Only the "train" split of the "main" configuration is kept.
GSM8K_DATASET: Dataset = load_dataset("gsm8k", "main")["train"]

# --- Loading and Basic Validation ---
# Each source falls back to an empty DataFrame when its file is missing, so the
# rest of the notebook can test `.empty` instead of crashing at load time.
# NOTE(review): the fallback DataFrame has no columns, so later code that
# selects columns such as 'error_type' or 'index' will raise KeyError.
try:
    programmatic_comp_df = pd.read_csv(PROGRAMMATIC_CATALOG_PATH)
    print(f"Loaded {len(programmatic_comp_df):,} records from programmatic computational error catalog.")
    # drop rows with missing "erroneous_line_number"
    programmatic_comp_df.dropna(subset=["erroneous_line_number"], inplace=True)
    print(f"Programmatic error records after dropping missing line numbers: {len(programmatic_comp_df):,}")
except FileNotFoundError:
    programmatic_comp_df = pd.DataFrame()
    print(f"WARNING: Programmatic computational catalog not found at {PROGRAMMATIC_CATALOG_PATH}")

try:
    # Opened explicitly so the file is always decoded as UTF-8.
    with open(MANUAL_ERRORS_CSV_PATH, 'r', encoding='utf-8') as f:
        manual_errors_data = pd.read_csv(f)
    print(f"Loaded {len(manual_errors_data):,} records from manually validated errors CSV.")
    # drop rows with missing "erroneous_line_number"
    manual_errors_data.dropna(subset=["erroneous_line_number"], inplace=True)
    print(f"Manual error records after dropping missing line numbers: {len(manual_errors_data):,}")
except FileNotFoundError:
    manual_errors_data = pd.DataFrame()
    print(f"WARNING: Manual errors data not found at {MANUAL_ERRORS_CSV_PATH}")

print(f"Loaded {len(GSM8K_DATASET):,} samples from gsm8k/main train split.")

# --- Create a quick-access lookup dictionary for original GSM8K problems ---
# Keys are the positional indices of the samples within the train split.
gsm8k_problem_lookup = {
    i: {"question": sample["question"], "answer": sample["answer"]}
    for i, sample in enumerate(tqdm(GSM8K_DATASET, desc="Creating GSM8K lookup"))
}
print("Created GSM8K problem lookup dictionary.")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Dataset output directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/sft-datasets/verifier-v2-two-task
Random seed set to: 42
Loaded 22,630 records from programmatic computational error catalog.
Loaded 1,963 records from manually validated errors CSV.
Loaded 7,473 samples from gsm8k/main train split.


Creating GSM8K lookup:   0%|          | 0/7473 [00:00<?, ?it/s]

Created GSM8K problem lookup dictionary.


#### **Cell 2: Core formatting utilities**

In [110]:
def sanitize_text(text: str) -> str:
    """
    Map typographic Unicode characters (math operators, curly quotes, dashes,
    ellipsis, no-break space) to plain ASCII so that later parsing and model
    generation only ever see ASCII punctuation.

    Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""

    # Code point -> ASCII replacement, applied in a single pass by str.translate.
    ascii_by_codepoint = {
        0x2212: "-",    # Minus Sign
        0x00D7: "*",    # Multiplication Sign
        0x00F7: "/",    # Division Sign
        0x22C5: "*",    # Dot Operator
        0x201C: '"',    # Left Double Quotation Mark
        0x201D: '"',    # Right Double Quotation Mark
        0x2018: "'",    # Left Single Quotation Mark
        0x2019: "'",    # Right Single Quotation Mark
        0x2014: "-",    # Em Dash
        0x2013: "-",    # En Dash
        0x2026: "...",  # Horizontal Ellipsis
        0x00A0: " ",    # No-Break Space
    }
    return text.translate(ascii_by_codepoint)

def clean_and_split_solution(raw_text: str) -> Tuple[str, str | None]:
    """
    Takes a raw solution text, sanitizes it, and separates the reasoning
    lines from the final answer line.
    
    Returns:
        A tuple containing (cleaned_reasoning_text, final_answer_string).
        final_answer_string is None if not found.
    """
    if not isinstance(raw_text, str):
        return "", None
        
    # 1. Sanitize all characters first
    sanitized_text = sanitize_text(raw_text)
    
    # 2. Remove calculator annotations
    text_no_annotations = re.sub(r'<<.*?>>', '', sanitized_text)
    
    # 3. Remove comma separators from numbers
    text_no_commas = re.sub(r'(\d),(\d)', r'\1\2', text_no_annotations)
    
    lines = text_no_commas.split('\n')
    final_answer = None
    
    # 4. Find and extract the final answer line
    if lines and re.match(r'^\s*####\s*.*$', lines[-1]):
        final_answer_line = lines.pop().strip()
        # Extract just the number part after ####
        match = re.search(r'####\s*(.*)', final_answer_line)
        if match:
            final_answer = match.group(1).strip()

    # 5. Process the remaining reasoning lines
    cleaned_lines = [line.strip() for line in lines if line.strip()]
    reasoning_text = '\n'.join(cleaned_lines)
    
    return reasoning_text, final_answer

def convert_solution_to_json_str(cleaned_reasoning: str, final_answer: str | None) -> str:
    """
    Takes cleaned reasoning text and a final answer, and converts them into
    a single JSON-formatted string.
    """
    lines = cleaned_reasoning.split('\n')
    solution_dict = {f"L{i+1}": line for i, line in enumerate(lines) if line}
    
    if final_answer is not None:
        solution_dict["FA"] = final_answer
        
    return json.dumps(solution_dict, indent=2)

def extract_equations_from_annotations(raw_text: str) -> str:
    """
    Parses a raw solution text to find all calculator annotations (<<...>>)
    and returns them as a JSON-formatted dictionary string.

    Keys are ``"L<n>"`` line labels. Lines are numbered the same way the
    cleaned solution is numbered (see ``clean_and_split_solution``): a line
    that is blank once its annotations are removed does not consume a line
    number, so the labels here line up with the labels of the solution JSON
    even when the raw text contains blank lines.

    Args:
        raw_text: The raw solution text, annotations included.

    Returns:
        A JSON object string (2-space indent) mapping line labels to the first
        annotation found on that line. Non-string input yields "{}".
    """
    if not isinstance(raw_text, str):
        return "{}"

    # Sanitize text first to handle different dash types in equations
    sanitized_text = sanitize_text(raw_text)
    equations = {}
    line_number = 0

    for line in sanitized_text.split('\n'):
        # Skip lines the cleaned solution drops, so numbering stays aligned.
        if not re.sub(r'<<.*?>>', '', line).strip():
            continue
        line_number += 1

        # We assume at most one annotation per line; only the first is kept.
        annotation = re.search(r'<<(.*?)>>', line)
        if annotation:
            equations[f"L{line_number}"] = annotation.group(1).strip()

    return json.dumps(equations, indent=2)

#### **Cell 3: Testing core formatting + equation extraction utilities**

In [111]:
def _print_utility_outputs(raw_solution) -> None:
    """Print ``raw_solution`` followed by the output of each formatting utility."""
    print(f"Original Raw Solution:\n---\n{raw_solution}\n---")
    reasoning, fa = clean_and_split_solution(raw_solution)
    print(f"Cleaned Reasoning:\n---\n{reasoning}\n---")
    print(f"Final Answer: '{fa}'")
    solution_json = convert_solution_to_json_str(reasoning, fa)
    print(f"Solution as JSON:\n---\n{solution_json}\n---")
    equations_json = extract_equations_from_annotations(raw_solution)
    print(f"Extracted Equations:\n---\n{equations_json}\n---")


def verify_utilities_on_samples(num_samples: int = 3, random_seed: int = RANDOM_SEED) -> None:
    """
    Selects random samples from each data source (programmatic, manual, original)
    and runs them through the REVISED data transformation utilities to verify their output.

    Args:
        num_samples: Maximum number of samples to show per source. Fewer are
            shown when a source holds fewer records.
        random_seed: Seed for every sampling step in this function. A private
            ``random.Random`` instance is used, so the global ``random`` state
            is neither read nor modified.

    Returns:
        None. All results are printed.
    """
    rng = random.Random(random_seed)

    print("="*80)
    print("--- Verifying REVISED Utility Functions on Live Data Samples ---")
    print("="*80)

    # --- 1. Test on Programmatically Generated Computational Errors ---
    print("\n\n--- Testing on: Programmatic Computational Errors ---")
    if not programmatic_comp_df.empty:
        candidates = programmatic_comp_df.dropna(subset=['filepath'])
        # DataFrame.sample raises if asked for more rows than exist.
        k = min(num_samples, len(candidates))
        sampled_rows = candidates.sample(n=k, random_state=random_seed)

        for i, (_, row) in enumerate(sampled_rows.iterrows()):
            print(f"\n--- Programmatic Sample {i+1}/{k} (Index: {row['index']}) ---")
            try:
                with open(PROJECT_ROOT / row['filepath'], 'r', encoding='utf-8') as f:
                    data = json.load(f)
                _print_utility_outputs(data['flawed_nl_solution'])
            except Exception as e:
                # Diagnostic helper: report the failure and move on to the next sample.
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping programmatic samples: DataFrame is empty.")

    # --- 2. Test on Manually Validated Errors ---
    print("\n\n--- Testing on: Manually Validated Errors ---")
    if not manual_errors_data.empty:
        manual_records = manual_errors_data.to_dict('records')
        k = min(num_samples, len(manual_records))

        for i, sample in enumerate(rng.sample(manual_records, k=k)):
            print(f"\n--- Manual Sample {i+1}/{k} (Index: {sample.get('index', 'N/A')}) ---")
            try:
                _print_utility_outputs(sample['wrong_answer'])
            except Exception as e:
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping manual samples: DataFrame is empty.")

    # --- 3. Test on Original GSM8K Correct Solutions ---
    print("\n\n--- Testing on: Original GSM8K Correct Solutions ---")
    if gsm8k_problem_lookup:
        all_indices = list(gsm8k_problem_lookup)
        k = min(num_samples, len(all_indices))

        for i, idx in enumerate(rng.sample(all_indices, k=k)):
            print(f"\n--- Original GSM8K Sample {i+1}/{k} (Index: {idx}) ---")
            try:
                _print_utility_outputs(gsm8k_problem_lookup[idx]['answer'])
            except Exception as e:
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping original GSM8K samples: Lookup dictionary is empty.")

In [112]:
# Show up to three samples per source; random_seed is left at its default
# (RANDOM_SEED). Pass random_seed=<other value> to inspect different samples.
verify_utilities_on_samples(num_samples=3)

--- Verifying REVISED Utility Functions on Live Data Samples ---


--- Testing on: Programmatic Computational Errors ---

--- Programmatic Sample 1/3 (Index: 3371) ---
Original Raw Solution:
---
The third chapter has P + 3 + 3 = P + <<+3+3=6>>6 pages.
The fourth chapter has P + 6 + 3 = P + <<+6+3=4>>4 pages.
The fifth chapter has P + 4 + 3 = P + <<+4+3=7>>7 pages.
The whole book has P + 3 + P + 6 + P + 4 + P + 7 = 5P + 20 = 95 pages.
Thus, without the 20 extra pages from additional chapters, 5P = 95 - 20 = 75 pages.
Therefore, the first chapter had P = 75 / 5 = <<75/5=15>>15 pages.
#### 15
---
Cleaned Reasoning:
---
The third chapter has P + 3 + 3 = P + 6 pages.
The fourth chapter has P + 6 + 3 = P + 4 pages.
The fifth chapter has P + 4 + 3 = P + 7 pages.
The whole book has P + 3 + P + 6 + P + 4 + P + 7 = 5P + 20 = 95 pages.
Thus, without the 20 extra pages from additional chapters, 5P = 95 - 20 = 75 pages.
Therefore, the first chapter had P = 75 / 5 = 15 pages.
---
Final Answer: '15'


In [None]:
# System prompt for the conceptual-error verification task. The model must
# answer either "None" or a single "<line label>: <explanation>" line. The
# example inside the prompt uses exactly that shape, so it agrees with the
# stated format and with the assistant targets built elsewhere in this notebook.
SYSTEM_PROMPT_CONCEPTUAL = """[ROLE] 
You are an expert in mathematical reasoning and student assessment. You will be given a math word problem ("Question") and a student's solution ("Answer"), formatted as a JSON dictionary where each key is a line number (e.g., "L1", "L2", ...) and each value is a step in the solution.

[TASK]
- Carefully review the solution for conceptual logic errors only. You MUST IGNORE any arithmetic or calculation mistakes.
- Focus on whether the solution uses the correct numbers from the problem, applies the correct mathematical operations, and follows a logically valid sequence of steps.

[RESPONSE FORMAT]
- If the solution is conceptually correct, respond with exactly: None
- If you find a conceptual error, respond with a single line in this format: `{line_number}: {brief explanation of the conceptual error}`. (E.g. `L2: The student incorrectly used "/" when they should have used "+".`)
- Do NOT include any other text, commentary, or extra lines.

Remember: Only comment on CONCEPTUAL LOGIC, not on arithmetic mistakes."""

# System prompt for the equation-extraction task: given a solution as a JSON
# dictionary of lines, the model must return a fenced JSON dictionary mapping
# line labels to the single equation performed on that line.
# NOTE(review): "a students' solution" in the prompt text is presumably meant
# to read "a student's solution"; left unchanged because it is runtime text.
SYSTEM_PROMPT_EXTRACTION = """[ROLE]
You are a data extraction tool. You will be given a students' solution to a math word problem, formatted as a JSON dictionary where each key is a line number (e.g., "L1", "L2", ...) and each value is a step in the solution.

[TASK]
Carefully examine the provided solution and extract every mathematical calculation into a new JSON dictionary.

[RESPONSE FORMAT]
- The response should begin with a fence (```json) and end with a closing fence (```).
- Each key should be a line number (e.g., "L1", "L2", ...) and each value should be a single equation that faithfully represents the calculation performed in that line (e.g. "2*3=6").
- If any line does not contain a calculation, it should be omitted from the output.
- Here is an example of the expected output format:
```json
{"L1": "2*3=6", "L2": "4*6=24"}
```"""

In [None]:
def format_messages_conceptual(
        question: str, 
        raw_solution: str, 
        assistant_content: str, 
        include_assistant_content=True
        ) -> list:
    """
    Build the chat messages for one conceptual-verification example.

    The raw solution is cleaned, split and converted to its JSON line
    dictionary before being placed in the user turn.

    Args:
        question (str): The math word problem.
        raw_solution (str): The student's solution in raw format.
        assistant_content (str): Target text for the assistant turn (only used for SFT).
        include_assistant_content (bool): When True (the default, for SFT) the
            assistant turn is appended; when False only system and user turns are returned.

    Returns:
        list: Message dictionaries with "role" and "content" keys, in
        system, user[, assistant] order.
    """
    solution_json = convert_solution_to_json_str(*clean_and_split_solution(raw_solution))

    system_turn = {"role": "system", "content": SYSTEM_PROMPT_CONCEPTUAL}
    user_turn = {"role": "user", "content": f"Question: {question}\nAnswer: {solution_json}"}

    if not include_assistant_content:
        return [system_turn, user_turn]
    return [system_turn, user_turn, {"role": "assistant", "content": assistant_content}]

def get_answer_and_assistant_content_from_record(record: Dict):
    """
    Pull the flawed solution and the formatted error label out of a record.

    Returns:
        A tuple ``(raw_solution, assistant_content)`` where assistant_content
        is ``"<erroneous_line_number>: <explanation>"``.

    Raises:
        ValueError: If 'wrong_answer', 'erroneous_line_number' or 'explanation'
            is absent from the record or is None.
    """
    required_fields = ('wrong_answer', 'erroneous_line_number', 'explanation')
    raw_solution, err_line_no, explanation = (record.get(field) for field in required_fields)

    if any(value is None for value in (raw_solution, err_line_no, explanation)):
        raise ValueError("Record does not contain valid 'wrong_answer', 'erroneous_line_number', or 'explanation'.")

    return raw_solution, f"{err_line_no}: {explanation}"

In [None]:
def _format_error_label(line_number, explanation):
    """Return ``"<line_number>: <explanation>"``, or None when either part is missing."""
    for part in (line_number, explanation):
        # NaN (what pandas yields for an empty CSV cell) is the only value that
        # is not equal to itself; None, "" and 0 are caught by the other checks.
        if part is None or part != part or not part:
            return None
    return f"{line_number}: {explanation}"


def get_phi4_message(sample_index, source, include_assistant_content=True):
    """
    Returns a formatted 'messages' list for Phi-4-mini-instruct inference.

    Records are looked up by the value of their 'index' column, so the same
    ``sample_index`` refers to the same GSM8K problem for every source.

    Args:
        sample_index (int): The GSM8K problem index of the sample to use.
        source (str): One of "manual", "programmatic_computational", or "original".
            "manual" uses the first manually written error for the problem whose
            'error_type' is 'conceptual'; "programmatic_computational" uses the
            first catalog row for the problem that has a filepath; "original"
            uses the reference GSM8K solution.
        include_assistant_content (bool): When True, an assistant turn is appended:
            "<line>: <explanation>" for flawed solutions when both parts are
            available, otherwise the literal string "None".
    Returns:
        messages (list of dict): Formatted for transformers chat inference.
    Raises:
        ValueError: If ``source`` is unknown, the source data is empty, or no
            record exists for ``sample_index`` in that source.
        KeyError: If ``sample_index`` is not a key of ``gsm8k_problem_lookup``.
    """
    if source not in ("manual", "programmatic_computational", "original"):
        raise ValueError("source must be one of 'manual', 'programmatic_computational', or 'original'.")

    # The question text always comes from the original GSM8K problem.
    problem = gsm8k_problem_lookup[sample_index]
    question = problem["question"]
    assistant_content = None

    if source == "original":
        # Reference solutions are correct, so the target stays "None".
        raw_solution = problem["answer"]
    elif source == "manual":
        if manual_errors_data.empty:
            raise ValueError("manual_errors_data is empty.")
        matches = manual_errors_data[
            (manual_errors_data['index'] == sample_index)
            & (manual_errors_data['error_type'] == 'conceptual')
        ]
        if matches.empty:
            raise ValueError(f"No manual conceptual error found for index {sample_index}.")
        row = matches.iloc[0]
        raw_solution = row["wrong_answer"]
        assistant_content = _format_error_label(
            row.get("erroneous_line_number"), row.get("explanation")
        )
    else:  # source == "programmatic_computational"
        if programmatic_comp_df.empty:
            raise ValueError("programmatic_comp_df is empty.")
        matches = programmatic_comp_df[
            (programmatic_comp_df['index'] == sample_index)
            & programmatic_comp_df['filepath'].notna()
        ]
        if matches.empty:
            raise ValueError(f"No programmatic computational error found for index {sample_index}.")
        filepath = PROJECT_ROOT / matches.iloc[0]['filepath']
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        raw_solution = data['flawed_nl_solution']
        # NOTE(review): the system prompt asks the model to ignore arithmetic
        # mistakes, yet this target labels a computational error - confirm
        # whether "None" is the intended target for this source.
        error_details = data.get("target_json", {}).get("error_details", {})
        assistant_content = _format_error_label(
            error_details.get("erroneous_line_number"), error_details.get("explanation")
        )

    reasoning, fa = clean_and_split_solution(raw_solution)
    solution_json = convert_solution_to_json_str(reasoning, fa)

    user_content = f"""### Question:
{question}
### Answer: 
```json
{solution_json}
```"""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_CONCEPTUAL},
        {"role": "user", "content": user_content},
    ]
    if include_assistant_content:
        # A missing label means "no conceptual error", spelled "None" in the prompt.
        messages.append({"role": "assistant", "content": assistant_content or "None"})

    return messages

In [115]:
def print_a_few_message(
        indices: List[int], 
        include_assistant_content: bool = True
        ) -> None:
    """
    Pretty-print the chat messages built for each given index and each source.

    For every (index, source) pair the system and user messages are printed,
    followed by the assistant message when one is present. Any failure for a
    pair is reported on one line and the loop continues with the next pair.
    """
    for index in indices:
        for source in ["manual", "programmatic_computational", "original"]:
            try:
                messages = get_phi4_message(
                    sample_index=index, 
                    source=source, 
                    include_assistant_content=include_assistant_content
                )
                print(f"\n=== Messages for Sample {index} (Source: {source}) ===\n")

                # Collect (label, message) pairs first, then print them uniformly.
                sections = [("System", messages[0]), ("User", messages[1])]
                if len(messages) > 2 and messages[2]:
                    sections.append(("Assistant", messages[2]))

                for label, message in sections:
                    print(f"--- {label} Message start---")
                    print(message['content'])
                    print(f"--- {label} Message end---\n")
            except Exception as e:
                print(f"ERROR processing sample {index} from source {source}: {e}")

In [116]:
# Draw 3 distinct indices from 0-99 (range(100) excludes 100); the global RNG
# is reseeded first so the same indices are selected on every run.
random.seed(RANDOM_SEED)
random_indices = random.sample(range(100), 3)
# Print the prompts without an assistant turn.
print_a_few_message(random_indices, include_assistant_content=False)


=== Messages for Sample 81 (Source: manual) ===

--- System Message start---
[ROLE] 
You are an expert in mathematical reasoning and student assessment. You will be given a math word problem ("Question") and a student's solution ("Answer"), formatted as a JSON dictionary where each key is a line number (e.g., "L1", "L2", ...) and each value is a step in the solution.

[TASK]
- Carefully review the solution for conceptual logic errors only. You MUST IGNORE any arithmetic or calculation mistakes.
- Focus on whether the solution uses the correct numbers from the problem, applies the correct mathematical operations, and follows a logically valid sequence of steps.

[RESPONSE FORMAT]
- If the solution is conceptually correct, respond with exactly: None
- If you find a conceptual error, respond with a single line in this format: `L{n}: {brief explanation of the conceptual error}` (replace `{n}` with the line number, and provide a concise explanation).
- Do NOT include any other text, commen

In [117]:
def build_phi4_message_batch(
    indices: list,
    sources: list = ["manual", "programmatic_computational", "original"],
    include_assistant_content: bool = False
):
    """
    Returns a list of message lists, one per (index, source) pair, suitable for batch inference.
    Each element is a list of dicts: [{"role": ..., "content": ...}, ...]
    """
    batch = []
    for idx in indices:
        for source in sources:
            try:
                messages = get_phi4_message(
                    sample_index=idx,
                    source=source,
                    include_assistant_content=include_assistant_content
                )
                batch.append(messages)
            except Exception as e:
                print(f"Skipping sample {idx} from source {source}: {e}")
    return batch

In [118]:
# Example usage:
# Reseed so the same 10 indices (drawn from 0-99) are selected on every run.
random.seed(RANDOM_SEED)
indices = random.sample(range(100), 10)
message_batch = build_phi4_message_batch(indices, include_assistant_content=False)
# Now message_batch can be passed to the Hugging Face pipeline in Colab:
# NOTE: `pipe` and `generation_args` are not defined in the cells shown here;
# this line raises NameError unless they have been created beforehand.
outputs = pipe(message_batch, **generation_args)

NameError: name 'pipe' is not defined

In [None]:
# Distinct values of the 'index' column among manual records whose
# 'error_type' is 'conceptual', in order of first appearance.
MANUAL_CONCEPTUAL_INDICES = (
    manual_errors_data
    .loc[manual_errors_data['error_type'] == 'conceptual', 'index']
    .unique()
    .tolist()
)

In [None]:
# Notebook display: number of unique 'index' values that have a conceptual error.
len(MANUAL_CONCEPTUAL_INDICES)

956