In [121]:
import json
import random
import re
from pathlib import Path
from typing import List, Dict, Set, Tuple

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from tqdm.auto import tqdm

# --- Path and Directory Definitions ---

def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by the git repository."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# --- Global Constants and Paths ---
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'
OUTPUT_DIR = DATA_DIR / "sft-datasets/verifier-v2-two-task"
PROCESSED_TEMPLATE_DIR = DATA_DIR / "template-generated-processed"

# --- Ensure output directory exists ---
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Configuration ---
# Seed for reproducibility of shuffling and sampling
RANDOM_SEED = 42

# Define the number of core problems to use as a base
NUM_CONCEPTUAL_PROBLEMS = 1000
NUM_COMPUTATIONAL_PROBLEMS = 1000

print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset output directory: {OUTPUT_DIR}")
print(f"Random seed set to: {RANDOM_SEED}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Dataset output directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/sft-datasets/verifier-v2-two-task
Random seed set to: 42


In [122]:
# --- Load Catalog of Programmatically generated Computational Errors ---
# This catalog points to individual JSON files with generated flawed solutions.
PROGRAMMATIC_COMPUTATIONAL_DIR = DATA_DIR / "computational-errors-generated"
PROGRAMMATIC_CATALOG_PATH = PROGRAMMATIC_COMPUTATIONAL_DIR / "computational_error_catalog.csv"

# --- Load Catalog of Programmatically generated Conceptual Errors ---
# This catalog points to individual JSON files with generated conceptual candidates.
CONCEPTUAL_CANDIDATES_DIR = DATA_DIR / "conceptual-error-candidates"
CONCEPTUAL_CATALOG_PATH = CONCEPTUAL_CANDIDATES_DIR / "conceptual_candidate_catalog.csv"

# --- Load manually validated/corrected conceptual errors ---
# This file contains the final, human-approved conceptual error text and explanations.
MANUAL_CONCEPTUAL_PATH = DATA_DIR / "final-datasets/conceptual_errors_final.json"

# --- Load Original GSM8K Dataset for 'Correct' examples and problem text ---
GSM8K_DATASET: Dataset = load_dataset("gsm8k", "main")["train"]

# --- Loading and Basic Validation ---
try:
    programmatic_comp_df = pd.read_csv(PROGRAMMATIC_CATALOG_PATH)
    print(f"Loaded {len(programmatic_comp_df):,} records from programmatic computational error catalog.")
except FileNotFoundError:
    programmatic_comp_df = pd.DataFrame()
    print(f"WARNING: Programmatic computational catalog not found at {PROGRAMMATIC_CATALOG_PATH}")

try:
    conceptual_cand_df = pd.read_csv(CONCEPTUAL_CATALOG_PATH)
    print(f"Loaded {len(conceptual_cand_df):,} records from conceptual error candidate catalog.")
except FileNotFoundError:
    conceptual_cand_df = pd.DataFrame()
    print(f"WARNING: Conceptual candidate catalog not found at {CONCEPTUAL_CATALOG_PATH}")

try:
    with open(MANUAL_CONCEPTUAL_PATH, 'r', encoding='utf-8') as f:
        manual_conceptual_data = json.load(f)
    print(f"Loaded {len(manual_conceptual_data):,} records from manually validated conceptual errors JSON.")
except FileNotFoundError:
    manual_conceptual_data = []
    print(f"WARNING: Manual conceptual data not found at {MANUAL_CONCEPTUAL_PATH}")

print(f"Loaded {len(GSM8K_DATASET):,} samples from gsm8k/main train split.")

# --- Create a quick-access lookup dictionary for original GSM8K problems ---
gsm8k_problem_lookup = {
    i: {"question": sample["question"], "answer": sample["answer"]}
    for i, sample in enumerate(tqdm(GSM8K_DATASET, desc="Creating GSM8K lookup"))
}
print("Created GSM8K problem lookup dictionary.")

Loaded 22,645 records from programmatic computational error catalog.
Loaded 33,716 records from conceptual error candidate catalog.
Loaded 7,473 samples from gsm8k/main train split.


Creating GSM8K lookup:   0%|          | 0/7473 [00:00<?, ?it/s]

Created GSM8K problem lookup dictionary.


In [123]:
def sanitize_text(text: str) -> str:
    """
    Replaces a comprehensive set of problematic Unicode characters with their
    ASCII equivalents to prevent model generation and string parsing errors.
    """
    if not isinstance(text, str):
        return ""
        
    replacements = {
        "\u2212": "-",  # Minus Sign
        "\u00d7": "*",  # Multiplication Sign
        "\u00f7": "/",  # Division Sign
        "\u22c5": "*",  # Dot Operator
        "\u201c": '"',  # Left Double Quotation Mark
        "\u201d": '"',  # Right Double Quotation Mark
        "\u2018": "'",  # Left Single Quotation Mark
        "\u2019": "'",  # Right Single Quotation Mark
        "\u2014": "-",  # Em Dash
        "\u2013": "-",  # En Dash
        "\u2026": "...",# Horizontal Ellipsis
        "\u00a0": " ",  # No-Break Space
    }
    for uni, ascii_char in replacements.items():
        text = text.replace(uni, ascii_char)
    return text

def clean_and_split_solution(raw_text: str) -> Tuple[str, str | None]:
    """
    Takes a raw solution text, sanitizes it, and separates the reasoning
    lines from the final answer line.
    
    Returns:
        A tuple containing (cleaned_reasoning_text, final_answer_string).
        final_answer_string is None if not found.
    """
    if not isinstance(raw_text, str):
        return "", None
        
    # 1. Sanitize all characters first
    sanitized_text = sanitize_text(raw_text)
    
    # 2. Remove calculator annotations
    text_no_annotations = re.sub(r'<<.*?>>', '', sanitized_text)
    
    # 3. Remove comma separators from numbers
    text_no_commas = re.sub(r'(\d),(\d)', r'\1\2', text_no_annotations)
    
    lines = text_no_commas.split('\n')
    final_answer = None
    
    # 4. Find and extract the final answer line
    if lines and re.match(r'^\s*####\s*.*$', lines[-1]):
        final_answer_line = lines.pop().strip()
        # Extract just the number part after ####
        match = re.search(r'####\s*(.*)', final_answer_line)
        if match:
            final_answer = match.group(1).strip()

    # 5. Process the remaining reasoning lines
    cleaned_lines = [line.strip() for line in lines if line.strip()]
    reasoning_text = '\n'.join(cleaned_lines)
    
    return reasoning_text, final_answer

def convert_solution_to_json_str(cleaned_reasoning: str, final_answer: str | None) -> str:
    """
    Takes cleaned reasoning text and a final answer, and converts them into
    a single JSON-formatted string.
    """
    lines = cleaned_reasoning.split('\n')
    solution_dict = {f"L{i+1}": line for i, line in enumerate(lines) if line}
    
    if final_answer is not None:
        solution_dict["FA"] = final_answer
        
    return json.dumps(solution_dict, indent=2)

def extract_equations_from_annotations(raw_text: str) -> str:
    """
    Parses a raw solution text to find all calculator annotations (<<...>>)
    and returns them as a JSON-formatted dictionary string.
    """
    if not isinstance(raw_text, str):
        return "{}"
        
    # Sanitize text first to handle different dash types in equations
    sanitized_text = sanitize_text(raw_text)
    equations = {}
    lines = sanitized_text.split('\n')
    
    for i, line in enumerate(lines):
        annotations = re.findall(r'<<(.*?)>>', line)
        if annotations:
            # We assume at most one annotation per line
            equation_str = annotations[0].strip()
            equations[f"L{i+1}"] = equation_str
            
    return json.dumps(equations, indent=2)

def create_formatted_prompt(
    system_prompt: str,
    problem_text: str,
    solution_json_str: str
) -> str:
    """
    Assembles the final prompt string using the Phi-4-mini-instruct chat template.
    """
    user_prompt = f"Problem:\n```\n{problem_text}\n```\n\nSolution:\n```json\n{solution_json_str}\n```"
    return f"<|system|>{system_prompt}<|end|><|user|>{user_prompt}<|end|>"

# --- Verification of New Functions ---
print("Core utility functions (revised) defined. Running a quick test case...")

# Example usage:
test_raw_solution = "The first number is 1,200. The second is 10.\nTheir sum is 1,200 + 10 = <<1200+10=1210>>1210.\n#### 1210"
test_problem = "What is the sum of 1,200 and 10?"
test_system_prompt = "[CONCEPTUAL_CHECK]\nYou are an expert..."

# # 1. Test cleaning and splitting
# reasoning, fa = clean_and_split_solution(test_raw_solution)
# print(f"--- Cleaned Reasoning ---\n{reasoning}\n--------------------")
# print(f"--- Final Answer ---\n{fa}\n--------------------")

# # 2. Test JSON conversion with FA
# solution_json = convert_solution_to_json_str(reasoning, fa)
# print(f"--- Solution as JSON String ---\n{solution_json}\n-----------------------------")

# # 3. Test equation extraction (should be unchanged)
# equations_json = extract_equations_from_annotations(test_raw_solution)
# print(f"--- Extracted Equations ---\n{equations_json}\n-------------------------")

# # 4. Test prompt formatting (should be unchanged)
# final_prompt = create_formatted_prompt(test_system_prompt, test_problem, solution_json)
# print(f"--- Final Formatted Prompt ---\n{final_prompt}\n----------------------------")

Core utility functions (revised) defined. Running a quick test case...


In [124]:
def verify_utilities_on_samples(num_samples: int = 3):
    """
    Selects random samples from each data source (programmatic, manual, original)
    and runs them through the REVISED data transformation utilities to verify their output.
    """
    print("="*80)
    print("--- Verifying REVISED Utility Functions on Live Data Samples ---")
    print("="*80)

    # --- 1. Test on Programmatically Generated Computational Errors ---
    print("\n\n--- Testing on: Programmatic Computational Errors ---")
    if not programmatic_comp_df.empty:
        valid_programmatic_samples = programmatic_comp_df.dropna(subset=['filepath']).sample(n=num_samples, random_state=RANDOM_SEED)
        
        for i, (_, row) in enumerate(valid_programmatic_samples.iterrows()):
            print(f"\n--- Programmatic Sample {i+1}/{num_samples} (Index: {row['index']}) ---")
            try:
                filepath = PROJECT_ROOT / row['filepath']
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                raw_solution = data['flawed_nl_solution']
                problem_text = gsm8k_problem_lookup[row['index']]['question']
                
                print(f"Original Raw Solution:\n---\n{raw_solution}\n---")
                reasoning, fa = clean_and_split_solution(raw_solution)
                print(f"Cleaned Reasoning:\n---\n{reasoning}\n---")
                print(f"Final Answer: '{fa}'")
                solution_json = convert_solution_to_json_str(reasoning, fa)
                print(f"Solution as JSON:\n---\n{solution_json}\n---")
                equations_json = extract_equations_from_annotations(raw_solution)
                print(f"Extracted Equations:\n---\n{equations_json}\n---")
            except Exception as e:
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping programmatic samples: DataFrame is empty.")

    # --- 2. Test on Manually Validated Conceptual Errors ---
    print("\n\n--- Testing on: Manually Validated Conceptual Errors ---")
    if manual_conceptual_data:
        k = min(num_samples, len(manual_conceptual_data))
        random_manual_samples = random.sample(manual_conceptual_data, k=k)
        
        for i, sample in enumerate(random_manual_samples):
            print(f"\n--- Manual Sample {i+1}/{k} (Index: {sample['index']}) ---")
            try:
                raw_solution = sample['flawed_nl_solution']
                problem_text = sample['question']

                print(f"Original Raw Solution:\n---\n{raw_solution}\n---")
                reasoning, fa = clean_and_split_solution(raw_solution)
                print(f"Cleaned Reasoning:\n---\n{reasoning}\n---")
                print(f"Final Answer: '{fa}'")
                solution_json = convert_solution_to_json_str(reasoning, fa)
                print(f"Solution as JSON:\n---\n{solution_json}\n---")
                equations_json = extract_equations_from_annotations(raw_solution)
                print(f"Extracted Equations:\n---\n{equations_json}\n---")
            except Exception as e:
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping manual samples: Data list is empty.")

    # --- 3. Test on Original GSM8K Correct Solutions ---
    print("\n\n--- Testing on: Original GSM8K Correct Solutions ---")
    if gsm8k_problem_lookup:
        random_indices = random.sample(list(gsm8k_problem_lookup.keys()), k=num_samples)

        for i, idx in enumerate(random_indices):
            print(f"\n--- Original GSM8K Sample {i+1}/{num_samples} (Index: {idx}) ---")
            try:
                sample = gsm8k_problem_lookup[idx]
                raw_solution = sample['answer']
                problem_text = sample['question']
                
                print(f"Original Raw Solution:\n---\n{raw_solution}\n---")
                reasoning, fa = clean_and_split_solution(raw_solution)
                print(f"Cleaned Reasoning:\n---\n{reasoning}\n---")
                print(f"Final Answer: '{fa}'")
                solution_json = convert_solution_to_json_str(reasoning, fa)
                print(f"Solution as JSON:\n---\n{solution_json}\n---")
                equations_json = extract_equations_from_annotations(raw_solution)
                print(f"Extracted Equations:\n---\n{equations_json}\n---")
            except Exception as e:
                print(f"ERROR processing sample: {e}")
    else:
        print("Skipping original GSM8K samples: Lookup dictionary is empty.")

# Set a new random seed for sampling to get different samples each run if desired
random.seed(RANDOM_SEED + 2) 
verify_utilities_on_samples(num_samples=3)

--- Verifying REVISED Utility Functions on Live Data Samples ---


--- Testing on: Programmatic Computational Errors ---

--- Programmatic Sample 1/3 (Index: 3612) ---
Original Raw Solution:
---
He has 30*6=<<30*6=810>>810 students
So he needs to buy 810*10=<<810*10=8100>>8100 index cards
That means he needs to buy 8100/50=<<8100/50=162>>162 packs of index cards
So he spends 162*3=$<<162*3=486>>486
#### 486
---
Cleaned Reasoning:
---
He has 30*6=810 students
So he needs to buy 810*10=8100 index cards
That means he needs to buy 8100/50=162 packs of index cards
So he spends 162*3=$486
---
Final Answer: '486'
Solution as JSON:
---
{
  "L1": "He has 30*6=810 students",
  "L2": "So he needs to buy 810*10=8100 index cards",
  "L3": "That means he needs to buy 8100/50=162 packs of index cards",
  "L4": "So he spends 162*3=$486",
  "FA": "486"
}
---
Extracted Equations:
---
{
  "L1": "30*6=810",
  "L2": "810*10=8100",
  "L3": "8100/50=162",
  "L4": "162*3=486"
}
---

--- Programmatic Sample 2/

In [125]:
# --- Load the consolidated manual errors CSV ---
# This should be defined in Cell 2, but we'll place it here for clarity in this snippet.
MANUAL_ERRORS_CSV_PATH = DATA_DIR / "manually_generated_errors_final.csv"
try:
    manual_errors_df = pd.read_csv(MANUAL_ERRORS_CSV_PATH)
    print(f"Successfully loaded {len(manual_errors_df)} records from {MANUAL_ERRORS_CSV_PATH.name}")
except FileNotFoundError:
    manual_errors_df = pd.DataFrame()
    print(f"WARNING: Manual errors CSV not found at {MANUAL_ERRORS_CSV_PATH}")


def check_annotation_coverage(solution_text: str) -> bool:
    """
    Checks if every line containing an '=' sign also has a calculator annotation.
    Returns True if coverage is complete, False otherwise.
    """
    if not isinstance(solution_text, str) or pd.isna(solution_text):
        return True # Vacuously true for empty/invalid input

    # Remove the final answer line to avoid checking it
    lines = solution_text.split('\n')
    if lines and re.match(r'^\s*####\s*.*$', lines[-1]):
        lines.pop()
        
    for line in lines:
        if '=' in line:
            if '<<' not in line or '>>' not in line:
                return False # Found a line with a calculation but no annotation
                
    return True

# --- Analyze All Data Sources with Corrected Logic for Manual Errors ---
print("\n" + "="*80)
print("--- Analyzing Annotation Coverage in Source Datasets (v2) ---")
print("="*80)

# --- 1. Analyze Original GSM8K Dataset ---
total_gsm8k = len(GSM8K_DATASET)
incomplete_gsm8k_count = 0
for sample in tqdm(GSM8K_DATASET, desc="Analyzing Original GSM8K"):
    if not check_annotation_coverage(sample['answer']):
        incomplete_gsm8k_count += 1

print(f"\nOriginal GSM8K Dataset:")
print(f"  - Samples with at least one missing annotation: {incomplete_gsm8k_count} / {total_gsm8k} ({incomplete_gsm8k_count/total_gsm8k:.2%})")


# --- 2. Analyze Programmatic Computational Errors ---
total_programmatic = 0
incomplete_programmatic_count = 0
unique_programmatic_indices = programmatic_comp_df.dropna(subset=['filepath'])['index'].unique()

for idx in tqdm(unique_programmatic_indices, desc="Analyzing Programmatic Errors"):
    row = programmatic_comp_df[programmatic_comp_df['index'] == idx].iloc[0]
    try:
        filepath = PROJECT_ROOT / row['filepath']
        with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f)
        solution_text = data.get('flawed_nl_solution')
        if solution_text:
            total_programmatic += 1
            if not check_annotation_coverage(solution_text):
                incomplete_programmatic_count += 1
    except (FileNotFoundError, json.JSONDecodeError, KeyError):
        continue

if total_programmatic > 0:
    print(f"\nProgrammatic Computational Errors:")
    print(f"  - Samples with at least one missing annotation: {incomplete_programmatic_count} / {total_programmatic} ({incomplete_programmatic_count/total_programmatic:.2%})")
else:
    print("\nProgrammatic Computational Errors: No valid files to analyze.")


# --- 3. Analyze Manually Generated Errors (from the correct CSV) ---
total_manual = len(manual_errors_df)
incomplete_manual_correct_count = 0
incomplete_manual_flawed_count = 0

if total_manual > 0:
    for _, row in tqdm(manual_errors_df.iterrows(), total=total_manual, desc="Analyzing Manual Errors CSV"):
        # Check the 'answer' column (correct solutions)
        if not check_annotation_coverage(row['answer']):
            incomplete_manual_correct_count += 1
        # Check the 'wrong_answer' column (flawed solutions)
        if not check_annotation_coverage(row['wrong_answer']):
            incomplete_manual_flawed_count += 1
            
    print(f"\nManually Generated Errors CSV:")
    print(f"  - Correct solutions ('answer' col) with missing annotations: {incomplete_manual_correct_count} / {total_manual} ({incomplete_manual_correct_count/total_manual:.2%})")
    print(f"  - Flawed solutions ('wrong_answer' col) with missing annotations: {incomplete_manual_flawed_count} / {total_manual} ({incomplete_manual_flawed_count/total_manual:.2%})")
else:
    print("\nManually Generated Errors CSV: No data to analyze.")

Successfully loaded 1963 records from manually_generated_errors_final.csv

--- Analyzing Annotation Coverage in Source Datasets (v2) ---


Analyzing Original GSM8K:   0%|          | 0/7473 [00:00<?, ?it/s]


Original GSM8K Dataset:
  - Samples with at least one missing annotation: 1322 / 7473 (17.69%)


Analyzing Programmatic Errors:   0%|          | 0/6636 [00:00<?, ?it/s]


Programmatic Computational Errors:
  - Samples with at least one missing annotation: 980 / 6636 (14.77%)


Analyzing Manual Errors CSV:   0%|          | 0/1963 [00:00<?, ?it/s]


Manually Generated Errors CSV:
  - Correct solutions ('answer' col) with missing annotations: 349 / 1963 (17.78%)
  - Flawed solutions ('wrong_answer' col) with missing annotations: 344 / 1963 (17.52%)


In [126]:
# Define the system prompts for each task as constants
# These will be used by the create_formatted_prompt function.

SYSTEM_PROMPT_CONCEPTUAL = """[CONCEPTUAL_CHECK]

You are an expert mathematical reasoning analyst. Your task is to verify the logical soundness of the provided solution based on the problem statement. The solution is formatted as a JSON dictionary mapping line numbers to text. You must IGNORE any potential arithmetic errors in the final calculations and focus ONLY on the conceptual logic.

- Does the solution use the correct numbers from the problem?
- Does it use the correct mathematical operations (e.g., multiplication where required, not addition)?
- Is the overall logical flow of the steps correct?

If the conceptual logic is sound, your entire output must be the single word:
None

If you find a conceptual error, your entire output must be a single line in the format `L{n}: {a brief explanation of the error}`. Do not include any other text."""

SYSTEM_PROMPT_EXTRACTION = """[EXTRACT_CALCULATIONS]

You are a data extraction tool. Your task is to parse the provided solution, which is formatted as a JSON dictionary, and extract every mathematical calculation into a new JSON dictionary.

- The keys of the output dictionary should be the line numbers (e.g., "L1", "L2").
- The values should be a string containing the full equation as written (e.g., "10 + 5 = 15").
- If a line contains no calculation, it should be omitted from the output dictionary.
- If the entire solution contains no calculations, return an empty JSON dictionary: {}

Your output must be ONLY the JSON object and nothing else."""

print("System prompt templates defined.")

System prompt templates defined.


In [127]:
# --- 1. Pre-filter for Annotation Coverage (for Task B) ---
print("="*80)
print("Step 1: Pre-filtering all sources for complete annotation coverage...")

perfect_manual_indices = set()
if not manual_errors_df.empty:
    for _, row in tqdm(manual_errors_df.iterrows(), total=len(manual_errors_df), desc="Analyzing Manual CSV"):
        if check_annotation_coverage(row['answer']) and check_annotation_coverage(row['wrong_answer']):
            perfect_manual_indices.add(row['index'])
print(f"-> Found {len(perfect_manual_indices):,} manually generated problems with full annotation coverage.")

perfect_programmatic_indices = set()
if not programmatic_comp_df.empty:
    unique_indices = programmatic_comp_df['index'].unique()
    for idx in tqdm(unique_indices, desc="Analyzing Programmatic Catalog"):
        original_answer = gsm8k_problem_lookup.get(idx, {}).get('answer')
        if not original_answer or not check_annotation_coverage(original_answer): continue
        rows = programmatic_comp_df[programmatic_comp_df['index'] == idx]
        is_perfect = False
        for _, row in rows.iterrows():
            filepath_val = row.get('filepath')
            if pd.isna(filepath_val): continue
            try:
                filepath = PROJECT_ROOT / str(filepath_val)
                with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f)
                if check_annotation_coverage(data.get('flawed_nl_solution')):
                    is_perfect = True
                    break
            except Exception: continue
        if is_perfect: perfect_programmatic_indices.add(idx)
print(f"-> Found {len(perfect_programmatic_indices):,} programmatically generated problems with full annotation coverage.")
all_perfect_indices_for_B = perfect_manual_indices.union(perfect_programmatic_indices)

# --- 2. Define Quotas and Core Problems ---
print("\n" + "="*80)
print("Step 2: Defining dataset quotas and identifying core problems for Task A...")

conceptual_df = manual_errors_df[manual_errors_df['error_type'] == 'conceptual'].copy()
core_problem_indices = sorted(list(conceptual_df['index'].unique()))
N_CONCEPTUAL = len(core_problem_indices)
rng = random.Random(RANDOM_SEED)
rng.shuffle(core_problem_indices) # Shuffle the list of core problems for random assignment

QUOTA_A_CONCEPTUAL = N_CONCEPTUAL
QUOTA_A_COMPUTATIONAL = N_CONCEPTUAL // 2
QUOTA_A_CORRECT = N_CONCEPTUAL - QUOTA_A_COMPUTATIONAL # The remainder
QUOTA_B_PAIRS = (N_CONCEPTUAL * 3) // 4

print(f"Anchor: Found {N_CONCEPTUAL} unique manual conceptual problems (these are the 'core problems').")
print(f"Task A Quotas: {QUOTA_A_CONCEPTUAL} conceptual, {QUOTA_A_CORRECT} correct, {QUOTA_A_COMPUTATIONAL} computational.")
print(f"Task B Quota: {QUOTA_B_PAIRS} correct/flawed pairs (from non-core problems).")

# --- 3. Assemble Dataset Sequentially ---
print("\n" + "="*80)
print("Step 3: Assembling dataset...")

sft_samples, metadata_log = [], []
used_indices_A, used_indices_B = set(), set()

# --- TASK A ASSEMBLY (from Core Problems only) ---
# A1: Add all N conceptual flawed samples
for idx in tqdm(core_problem_indices, desc="Task A - Conceptual Flawed"):
    row = conceptual_df[conceptual_df['index'] == idx].sample(n=1, random_state=rng.randint(0, 10**9)).iloc[0]
    f_reasoning, f_fa = clean_and_split_solution(row['wrong_answer'])
    f_solution_json_str = convert_solution_to_json_str(f_reasoning, f_fa)
    explanation = f"{row['erroneous_line_number']}: {row['explanation']}"
    prompt = create_formatted_prompt(SYSTEM_PROMPT_CONCEPTUAL, row['question'], f_solution_json_str)
    sft_samples.append({"text": f"{prompt}<|assistant|>{explanation}<end>", "task": "conceptual_check"})
    metadata_log.append({"index": idx, "task": "conceptual_check", "type": "conceptual_flawed", "source": row['filepath']})
    used_indices_A.add(idx)

# A2: Add N/2 computational flawed samples, drawn from core problems
indices_for_comp = core_problem_indices[:QUOTA_A_COMPUTATIONAL]
indices_for_correct = core_problem_indices[QUOTA_A_COMPUTATIONAL:]

# Manual First
comp_manual_df = manual_errors_df[(manual_errors_df['error_type'] == 'computational') & (manual_errors_df['index'].isin(indices_for_comp))].copy()
manual_indices_found = list(comp_manual_df['index'].unique())
for idx in tqdm(manual_indices_found, desc=f"Task A - Computational (Manual)"):
    row = comp_manual_df[comp_manual_df['index'] == idx].sample(n=1, random_state=rng.randint(0, 10**9)).iloc[0]
    f_reasoning, f_fa = clean_and_split_solution(row['wrong_answer'])
    f_solution_json_str = convert_solution_to_json_str(f_reasoning, f_fa)
    prompt = create_formatted_prompt(SYSTEM_PROMPT_CONCEPTUAL, row['question'], f_solution_json_str)
    sft_samples.append({"text": f"{prompt}<|assistant|>None<end>", "task": "conceptual_check"})
    metadata_log.append({"index": idx, "task": "conceptual_check", "type": "computational_flawed", "source": row['filepath']})

# Programmatic to fill the rest
num_still_needed = QUOTA_A_COMPUTATIONAL - len(manual_indices_found)
prog_candidates = [idx for idx in indices_for_comp if idx not in manual_indices_found]
prog_indices_to_take = prog_candidates[:num_still_needed]

if num_still_needed > 0 and not programmatic_comp_df.empty:
    prog_comp_df_A = programmatic_comp_df[programmatic_comp_df['index'].isin(prog_indices_to_take)].copy()
    for idx in tqdm(prog_indices_to_take, desc=f"Task A - Computational (Prog)"):
        rows = prog_comp_df_A[prog_comp_df_A['index'] == idx]
        if rows.empty: continue
        row = rows.sample(n=1, random_state=rng.randint(0, 10**9)).iloc[0]
        filepath_val = row.get('filepath')
        if pd.isna(filepath_val): continue
        try:
            filepath = PROJECT_ROOT / str(filepath_val)
            with open(filepath, 'r') as f: data = json.load(f)
            problem_text = gsm8k_problem_lookup[idx]['question']
            f_reasoning, f_fa = clean_and_split_solution(data['flawed_nl_solution'])
            f_solution_json_str = convert_solution_to_json_str(f_reasoning, f_fa)
            prompt = create_formatted_prompt(SYSTEM_PROMPT_CONCEPTUAL, problem_text, f_solution_json_str)
            sft_samples.append({"text": f"{prompt}<|assistant|>None<end>", "task": "conceptual_check"})
            metadata_log.append({"index": idx, "task": "conceptual_check", "type": "computational_flawed", "source": row['filepath']})
        except Exception: continue

# A3: Add remaining N/2 correct original samples
for idx in tqdm(indices_for_correct, desc=f"Task A - Correct"):
    problem = gsm8k_problem_lookup[idx]
    c_reasoning, c_fa = clean_and_split_solution(problem['answer'])
    c_solution_json_str = convert_solution_to_json_str(c_reasoning, c_fa)
    prompt = create_formatted_prompt(SYSTEM_PROMPT_CONCEPTUAL, problem['question'], c_solution_json_str)
    sft_samples.append({"text": f"{prompt}<|assistant|>None<end>", "task": "conceptual_check"})
    metadata_log.append({"index": idx, "task": "conceptual_check", "type": "correct_original", "source": "gsm8k"})

# --- TASK B ASSEMBLY (from non-Core problems only) ---
# Manual First
b_manual_df = manual_errors_df[(manual_errors_df['error_type'] == 'computational') & (~manual_errors_df['index'].isin(used_indices_A)) & (manual_errors_df['index'].isin(all_perfect_indices_for_B))].copy()
b_manual_indices = list(b_manual_df['index'].unique())
rng.shuffle(b_manual_indices)
b_indices_to_take_manual = b_manual_indices[:QUOTA_B_PAIRS]

for idx in tqdm(b_indices_to_take_manual, desc=f"Task B - Extraction Pairs (Manual)"):
    row = b_manual_df[b_manual_df['index'] == idx].sample(n=1, random_state=rng.randint(0, 10**9)).iloc[0]
    # Correct
    c_reasoning, c_fa = clean_and_split_solution(row['answer'])
    c_solution_json_str = convert_solution_to_json_str(c_reasoning, c_fa)
    c_equations = extract_equations_from_annotations(row['answer'])
    c_prompt = create_formatted_prompt(SYSTEM_PROMPT_EXTRACTION, row['question'], c_solution_json_str)
    sft_samples.append({"text": f"{c_prompt}<|assistant|>{c_equations}<end>", "task": "equation_extraction"})
    metadata_log.append({"index": idx, "task": "equation_extraction", "type": "correct_paired", "source": "gsm8k"})
    # Flawed
    f_reasoning, f_fa = clean_and_split_solution(row['wrong_answer'])
    f_solution_json_str = convert_solution_to_json_str(f_reasoning, f_fa)
    f_equations = extract_equations_from_annotations(row['wrong_answer'])
    f_prompt = create_formatted_prompt(SYSTEM_PROMPT_EXTRACTION, row['question'], f_solution_json_str)
    sft_samples.append({"text": f"{f_prompt}<|assistant|>{f_equations}<end>", "task": "equation_extraction"})
    metadata_log.append({"index": idx, "task": "equation_extraction", "type": "computational_paired", "source": row['filepath']})
    used_indices_B.add(idx)

# Programmatic to fill the rest
num_still_needed_b = QUOTA_B_PAIRS - len(b_indices_to_take_manual)
if num_still_needed_b > 0:
    b_prog_df = programmatic_comp_df[(~programmatic_comp_df['index'].isin(used_indices_A)) & (~programmatic_comp_df['index'].isin(used_indices_B)) & (programmatic_comp_df['index'].isin(all_perfect_indices_for_B))].copy()
    b_prog_indices = list(b_prog_df['index'].unique())
    rng.shuffle(b_prog_indices)
    b_indices_to_take_prog = b_prog_indices[:num_still_needed_b]
    for idx in tqdm(b_indices_to_take_prog, desc=f"Task B - Extraction Pairs (Prog)"):
        row = b_prog_df[b_prog_df['index'] == idx].sample(n=1, random_state=rng.randint(0, 10**9)).iloc[0]
        filepath_val = row.get('filepath')
        if pd.isna(filepath_val): continue
        try:
            filepath = PROJECT_ROOT / str(filepath_val)
            with open(filepath, 'r') as f: data = json.load(f)
            problem = gsm8k_problem_lookup[idx]
            # Correct
            c_reasoning, c_fa = clean_and_split_solution(problem['answer'])
            c_solution_json_str = convert_solution_to_json_str(c_reasoning, c_fa)
            c_equations = extract_equations_from_annotations(problem['answer'])
            c_prompt = create_formatted_prompt(SYSTEM_PROMPT_EXTRACTION, problem['question'], c_solution_json_str)
            sft_samples.append({"text": f"{c_prompt}<|assistant|>{c_equations}<end>", "task": "equation_extraction"})
            metadata_log.append({"index": idx, "task": "equation_extraction", "type": "correct_paired", "source": "gsm8k"})
            # Flawed
            f_reasoning, f_fa = clean_and_split_solution(data['flawed_nl_solution'])
            f_solution_json_str = convert_solution_to_json_str(f_reasoning, f_fa)
            f_equations = extract_equations_from_annotations(data['flawed_nl_solution'])
            f_prompt = create_formatted_prompt(SYSTEM_PROMPT_EXTRACTION, problem['question'], f_solution_json_str)
            sft_samples.append({"text": f"{f_prompt}<|assistant|>{f_equations}<end>", "task": "equation_extraction"})
            metadata_log.append({"index": idx, "task": "equation_extraction", "type": "computational_paired", "source": row['filepath']})
            used_indices_B.add(idx)
        except Exception: continue

# --- Final Summary ---
final_metadata_df = pd.DataFrame(metadata_log)
print("\n" + "="*80)
print("--- Data Assembly Complete: Final Summary ---")
print(f"Total SFT samples generated: {len(sft_samples):,}")
print(f"Total unique problems used for Task A: {len(used_indices_A):,}")
print(f"Total unique problems used for Task B: {len(used_indices_B):,}")

print("\n--- Final Dataset Composition ---")
print(final_metadata_df.groupby(['task', 'type']).size().reset_index(name='count').to_string(index=False))

Step 1: Pre-filtering all sources for complete annotation coverage...


Analyzing Manual CSV:   0%|          | 0/1963 [00:00<?, ?it/s]

-> Found 1,196 manually generated problems with full annotation coverage.


Analyzing Programmatic Catalog:   0%|          | 0/7304 [00:00<?, ?it/s]

-> Found 5,585 programmatically generated problems with full annotation coverage.

Step 2: Defining dataset quotas and identifying core problems for Task A...
Anchor: Found 1030 unique manual conceptual problems (these are the 'core problems').
Task A Quotas: 1030 conceptual, 515 correct, 515 computational.
Task B Quota: 772 correct/flawed pairs (from non-core problems).

Step 3: Assembling dataset...


Task A - Conceptual Flawed:   0%|          | 0/1030 [00:00<?, ?it/s]

Task A - Computational (Manual):   0%|          | 0/252 [00:00<?, ?it/s]

Task A - Computational (Prog):   0%|          | 0/263 [00:00<?, ?it/s]

Task A - Correct:   0%|          | 0/515 [00:00<?, ?it/s]

Task B - Extraction Pairs (Manual):   0%|          | 0/349 [00:00<?, ?it/s]

Task B - Extraction Pairs (Prog):   0%|          | 0/423 [00:00<?, ?it/s]


--- Data Assembly Complete: Final Summary ---
Total SFT samples generated: 3,565
Total unique problems used for Task A: 1,030
Total unique problems used for Task B: 772

--- Final Dataset Composition ---
               task                 type  count
   conceptual_check computational_flawed    476
   conceptual_check    conceptual_flawed   1030
   conceptual_check     correct_original    515
equation_extraction computational_paired    772
equation_extraction       correct_paired    772


In [128]:
import pandas as pd
import random

# Ensure the metadata dataframe is created from the log, and sft_text is added.
# This should be run after Cell 5 completes.
final_metadata_df = pd.DataFrame(metadata_log)
if 'sft_text' not in final_metadata_df.columns:
    final_metadata_df['sft_text'] = [s['text'] for s in sft_samples]

# Create a new random number generator for this cell to ensure variety
rng = random.Random(RANDOM_SEED + 30)

def visualize_samples(df: pd.DataFrame, description: str, num_samples: int = 1, **kwargs):
    """
    Finds and prints a specified number of random samples from the dataframe
    based on filter criteria.
    """
    print("\n" + "="*80)
    print(f"{description} (Displaying {num_samples} sample(s))")
    print("="*80)
    
    query_parts = []
    for col, val in kwargs.items():
        if isinstance(val, str):
            query_parts.append(f'`{col}` == "{val}"')
        else:
            query_parts.append(f'`{col}` == {val}')
    query = ' & '.join(query_parts)
    
    try:
        # Select n random samples that match the query
        # Use min(num_samples, len(matching_df)) to avoid errors if fewer samples exist
        matching_df = df.query(query)
        n_to_sample = min(num_samples, len(matching_df))
        if n_to_sample == 0:
            raise ValueError("No samples found matching criteria.")
            
        samples = matching_df.sample(n=n_to_sample, random_state=rng.randint(0, 10**9))
        
        for i, (_, sample) in enumerate(samples.iterrows()):
            print(f"\n--- Sample {i+1}/{n_to_sample} ---")
            print(f"METADATA: Index={sample['index']}, Task='{sample['task']}', Type='{sample['type']}'")
            
            parts = sample['sft_text'].split('<|assistant|>')
            prompt_part = parts[0] + '<|assistant|>'
            response_part = parts[1]
            
            print("\nPROMPT (What the model sees):")
            print(prompt_part)
            
            print("\nRESPONSE (What the model must learn):")
            print(response_part)
            
        return samples['index'].tolist() # Return list of indices for paired lookups
        
    except (IndexError, ValueError) as e:
        print(f"\nERROR: Could not find enough samples. Details: {e}\nCriteria: {kwargs}")
        return None


In [129]:
# --- Visualization Requests ---

# 1. Visualize conceptual vs. correct samples for Task A
print("\n--- Request 1: Conceptual Flawed vs. Correct Original from Task A ---")
visualize_samples(
    final_metadata_df,
    "1a: 'conceptual_flawed' samples",
    num_samples=2,
    task="conceptual_check",
    type="conceptual_flawed"
)
visualize_samples(
    final_metadata_df,
    "1b: 'correct_original' samples",
    num_samples=2,
    task="conceptual_check",
    type="correct_original"
)

# 2. Visualize conceptual vs. computational samples for Task A
print("\n--- Request 2: Conceptual Flawed vs. Computational Flawed from Task A ---")
visualize_samples(
    final_metadata_df,
    "2a: 'conceptual_flawed' samples (target should be an error string)",
    num_samples=2,
    task="conceptual_check",
    type="conceptual_flawed"
)
visualize_samples(
    final_metadata_df,
    "2b: 'computational_flawed' samples (target should be 'None')",
    num_samples=2,
    task="conceptual_check",
    type="computational_flawed"
)

# 3. One computational error + original pair from Task B
print("\n--- Request 3: Paired 'correct' and 'flawed' samples for Task B ---")
try:
    pair_df = final_metadata_df[
        (final_metadata_df['task'] == 'equation_extraction') &
        (final_metadata_df['type'].isin(['correct_paired', 'computational_paired']))
    ]
    valid_indices = pair_df['index'].value_counts()
    valid_indices = valid_indices[valid_indices == 2].index
    
    # Select 2 random indices to visualize pairs for
    indices_to_show = rng.sample(list(valid_indices), k=min(2, len(valid_indices)))

    for i, idx in enumerate(indices_to_show):
        print(f"\n--- Pair {i+1}/{len(indices_to_show)} for Index {idx} ---")
        visualize_samples(
            final_metadata_df,
            f"3a: The 'correct_paired' sample",
            num_samples=1,
            task="equation_extraction",
            type="correct_paired",
            index=idx
        )
        visualize_samples(
            final_metadata_df,
            f"3b: The 'computational_paired' sample",
            num_samples=1,
            task="equation_extraction",
            type="computational_paired",
            index=idx
        )
except (IndexError, ValueError) as e:
    print(f"\nERROR: Could not find enough valid computational pairs to display. Details: {e}")


--- Request 1: Conceptual Flawed vs. Correct Original from Task A ---

1a: 'conceptual_flawed' samples (Displaying 2 sample(s))

--- Sample 1/2 ---
METADATA: Index=986, Task='conceptual_check', Type='conceptual_flawed'

PROMPT (What the model sees):
<|system|>[CONCEPTUAL_CHECK]

You are an expert mathematical reasoning analyst. Your task is to verify the logical soundness of the provided solution based on the problem statement. The solution is formatted as a JSON dictionary mapping line numbers to text. You must IGNORE any potential arithmetic errors in the final calculations and focus ONLY on the conceptual logic.

- Does the solution use the correct numbers from the problem?
- Does it use the correct mathematical operations (e.g., multiplication where required, not addition)?
- Is the overall logical flow of the steps correct?

If the conceptual logic is sound, your entire output must be the single word:
None

If you find a conceptual error, your entire output must be a single line 

In [130]:
import pandas as pd
from datasets import Dataset, DatasetDict
import random

# --- Configuration for Splitting ---
TRAIN_SPLIT_RATIO = 0.8  # 80% of unique problems for training
VALIDATION_SPLIT_RATIO = 0.1 # 10% for validation, 10% for test

def split_dataset_by_index(
    df: pd.DataFrame,
    train_size: float = 0.8,
    val_size: float = 0.1,
    seed: int = 42
) -> DatasetDict:
    """
    Splits a DataFrame into train, validation, and test sets based on a unique
    'index' column to prevent data leakage.
    """
    # Get all unique problem indices and shuffle them reproducibly
    all_indices = sorted(list(df['index'].unique()))
    random.Random(seed).shuffle(all_indices)
    
    # Calculate split points
    train_end = int(len(all_indices) * train_size)
    val_end = train_end + int(len(all_indices) * val_size)
    
    # Create sets of indices for each split
    train_indices = set(all_indices[:train_end])
    val_indices = set(all_indices[train_end:val_end])
    test_indices = set(all_indices[val_end:])
    
    # Filter the DataFrame for each split
    train_df = df[df['index'].isin(train_indices)]
    val_df = df[df['index'].isin(val_indices)]
    test_df = df[df['index'].isin(test_indices)]
    
    # Convert pandas DataFrames to Hugging Face Dataset objects
    return DatasetDict({
        "train": Dataset.from_pandas(train_df),
        "validation": Dataset.from_pandas(val_df),
        "test": Dataset.from_pandas(test_df)
    })

# --- 1. Create the final comprehensive DataFrame ---
# This is the same DataFrame used for visualization, ensuring consistency.
final_df = pd.DataFrame(metadata_log)
final_df['text'] = [s['text'] for s in sft_samples] # Add the full text for SFT

# Ensure 'text' is the last column for clarity, and drop the now-redundant 'task' column
final_df = final_df[['index', 'task', 'type', 'source', 'text']]

print("--- Final DataFrame created with all samples and metadata ---")
print(f"Total samples: {len(final_df):,}")
print(final_df.head())


# --- 2. Split the dataset to prevent data leakage ---
print("\n--- Splitting dataset into train, validation, and test sets ---")
sft_dataset_dict = split_dataset_by_index(
    final_df,
    train_size=TRAIN_SPLIT_RATIO,
    val_size=VALIDATION_SPLIT_RATIO,
    seed=RANDOM_SEED
)


# --- 3. Save all artifacts to the output directory ---
print(f"\n--- Saving all artifacts to: {OUTPUT_DIR} ---")

# a. Save the Hugging Face DatasetDict
dataset_path = OUTPUT_DIR / "sft_dataset"
sft_dataset_dict.save_to_disk(dataset_path)
print(f"  - SFT Dataset (train/validation/test splits) saved to: {dataset_path}")

# b. Save the full metadata catalog as a CSV
catalog_path = OUTPUT_DIR / "sft_catalog.csv"
# We save final_df as it contains all metadata plus the final text
final_df.to_csv(catalog_path, index=False)
print(f"  - Full metadata catalog saved to: {catalog_path}")

print("\n--- Process Complete ---")
print("Final Dataset Structure:")
print(sft_dataset_dict)

--- Final DataFrame created with all samples and metadata ---
Total samples: 3,565
   index              task               type  \
0   1280  conceptual_check  conceptual_flawed   
1    788  conceptual_check  conceptual_flawed   
2   1500  conceptual_check  conceptual_flawed   
3    840  conceptual_check  conceptual_flawed   
4    882  conceptual_check  conceptual_flawed   

                                              source  \
0  yewei/gsm8k_data/conceptual/gsm8k_augmented_12...   
1  data/manually_gen_incorrect_answers_gsm8k/gsm8...   
2                    data/1500_1599_conceptual.jsonl   
3  data/manually_gen_incorrect_answers_gsm8k/gsm8...   
4  data/manually_gen_incorrect_answers_gsm8k/gsm8...   

                                                text  
0  <|system|>[CONCEPTUAL_CHECK]\n\nYou are an exp...  
1  <|system|>[CONCEPTUAL_CHECK]\n\nYou are an exp...  
2  <|system|>[CONCEPTUAL_CHECK]\n\nYou are an exp...  
3  <|system|>[CONCEPTUAL_CHECK]\n\nYou are an exp...  
4  <|syste

Saving the dataset (0/1 shards):   0%|          | 0/2846 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/360 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/359 [00:00<?, ? examples/s]

  - SFT Dataset (train/validation/test splits) saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/sft-datasets/verifier-v2-two-task/sft_dataset
  - Full metadata catalog saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/sft-datasets/verifier-v2-two-task/sft_catalog.csv

--- Process Complete ---
Final Dataset Structure:
DatasetDict({
    train: Dataset({
        features: ['index', 'task', 'type', 'source', 'text', '__index_level_0__'],
        num_rows: 2846
    })
    validation: Dataset({
        features: ['index', 'task', 'type', 'source', 'text', '__index_level_0__'],
        num_rows: 360
    })
    test: Dataset({
        features: ['index', 'task', 'type', 'source', 'text', '__index_level_0__'],
        num_rows: 359
    })
})
