### Cell 1: Imports and Path Definitions

In [14]:
import json
import random
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset

# --- Path and Directory Definitions ---
def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by the git repository."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# All data paths hang off the repository root, so the notebook works from any subdirectory.
PROJECT_ROOT: Path = find_project_root()
DATA_DIR: Path = PROJECT_ROOT.joinpath('data')
GENERATED_ERRORS_DIR: Path = DATA_DIR.joinpath("computational-errors-generated")
CATALOG_FILE: Path = GENERATED_ERRORS_DIR.joinpath("computational_error_catalog.csv")

# The unmodified GSM8K training split is the source of the 'Correct' examples.
GSM8K_TRAIN: Dataset = load_dataset("gsm8k", "main")["train"]

print(f"Project root: {PROJECT_ROOT}")
print(f"Loading catalog from: {CATALOG_FILE}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Loading catalog from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/computational-errors-generated/computational_error_catalog.csv


### Cell 2: Load and Inspect the Error Catalog

In [15]:
# --- Load the master catalog into a pandas DataFrame ---
# --- Load the master catalog into a pandas DataFrame ---
# A missing catalog is reported, not raised: an empty frame is substituted so
# later cells can detect it via `.empty` and short-circuit.
try:
    catalog_df = pd.read_csv(CATALOG_FILE)
except FileNotFoundError:
    print(f"ERROR: Catalog file not found at {CATALOG_FILE}")
    catalog_df = pd.DataFrame()
else:
    print("Successfully loaded the error catalog.")
    print(f"Total available flawed examples: {len(catalog_df)}")

if not catalog_df.empty:
    print("\n--- Catalog Schema ---")
    catalog_df.info()

Successfully loaded the error catalog.
Total available flawed examples: 15529

--- Catalog Schema ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15529 entries, 0 to 15528
Data columns (total 12 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   index                    15529 non-null  int64 
 1   tier                     15529 non-null  object
 2   model                    15529 non-null  object
 3   correct_trace_generated  15529 non-null  bool  
 4   target_variable          15179 non-null  object
 5   error_type               15525 non-null  object
 6   correct_value            15179 non-null  object
 7   flawed_value             15179 non-null  object
 8   repro_seed               15529 non-null  object
 9   date_utc                 15179 non-null  object
 10  time_utc                 15179 non-null  object
 11  filepath                 15179 non-null  object
dtypes: bool(1), int64(1), object(10)
memory us

### Cell 3: Core Utility Functions

In [16]:
def load_error_artifact(filepath_str: str) -> dict | None:
    """Loads a single JSON error artifact from disk, handling potential errors."""
    if not isinstance(filepath_str, str):
        return None # Guard against non-string filepaths (e.g., NaN)
    
    full_path = PROJECT_ROOT / filepath_str
    try:
        with open(full_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Warning: Could not load artifact at {full_path}. Reason: {e}")
        return None

def construct_sft_input(problem_text: str, solution_text: str) -> str:
    """Render a problem/solution pair in the fixed prompt layout used for SFT.

    The result is a ``### Problem:`` section followed by a ``### Solution:``
    section, the two separated by one blank line.
    """
    problem_section = f"### Problem:\n{problem_text}"
    solution_section = f"### Solution:\n{solution_text}"
    return "\n\n".join((problem_section, solution_section))

def get_original_problem_and_solution(index: int) -> tuple[str, str] | None:
    """Look up one GSM8K training example by row position.

    Args:
        index: Row position in ``GSM8K_TRAIN``; it is coerced with ``int()``
            before use.

    Returns:
        ``(question, answer)`` for that row, or None when the position is
        outside ``[0, len(GSM8K_TRAIN))``.
    """
    position = int(index)
    if not 0 <= position < len(GSM8K_TRAIN):
        return None
    record = GSM8K_TRAIN[position]
    return record['question'], record['answer']

print("Core utility functions defined.")

Core utility functions defined.


### Cell 4: SQL-Style Query Functions

In [17]:
def query_by_index(df: pd.DataFrame, indices: list[int]) -> pd.DataFrame:
    """Selects rows matching the given list of problem indices."""
    return df[df['index'].isin(indices)].copy()

def query_by_index_range(df: pd.DataFrame, min_index: int, max_index: int) -> pd.DataFrame:
    """Selects rows where the problem index is within the specified range (inclusive)."""
    return df[df['index'].between(min_index, max_index)].copy()

def query_by_tier(df: pd.DataFrame, tiers: list[str]) -> pd.DataFrame:
    """Selects rows matching the given list of tiers."""
    return df[df['tier'].isin(tiers)].copy()

def query_by_model(df: pd.DataFrame, models: list[str]) -> pd.DataFrame:
    """Selects rows matching the given list of models."""
    return df[df['model'].isin(models)].copy()

def query_by_error_type(df: pd.DataFrame, error_types: list[str]) -> pd.DataFrame:
    """Selects rows matching the given list of error types."""
    return df[df['error_type'].isin(error_types)].copy()

print("SQL-style query functions defined.")

SQL-style query functions defined.


### Cell 5: The Dataset Assembly Function

In [18]:
def create_sft_dataset(
    catalog_df: pd.DataFrame,
    total_samples: int,
    correct_ratio: float = 0.5,
    sampling_strategy: dict[str, float] | None = None,
    sampling_column: str = 'error_type',
    random_seed: int = 42
) -> list[dict]:
    """
    Assembles a labeled dataset for SFT using stratified sampling.

    Args:
        catalog_df: The DataFrame containing the error catalog. It can be pre-filtered.
        total_samples: The total number of samples desired in the final dataset.
        correct_ratio: The desired proportion of "Correct" examples.
        sampling_strategy: A dictionary mapping values in `sampling_column` to a
                           probability. If None, a uniform distribution is used.
        sampling_column: The column to stratify on (e.g., 'error_type').
        random_seed: A seed for reproducibility.

    Returns:
        A list of dictionaries, each a complete (input, target) pair.
    """
    if catalog_df.empty:
        print("Error: Input catalog is empty. Cannot generate dataset.")
        return []
        
    num_correct = int(total_samples * correct_ratio)
    num_flawed = total_samples - num_correct
    sft_dataset = []

    # --- 1. Generate Flawed Samples via Stratified Sampling ---
    if num_flawed > 0:
        # Filter to only valid, usable rows from the catalog.
        flawed_catalog = catalog_df.dropna(subset=['filepath', sampling_column]).copy()
        
        # Determine the sampling strategy.
        if sampling_strategy is None:
            available_types = flawed_catalog[sampling_column].unique()
            strategy = {t: 1.0 / len(available_types) for t in available_types}
        else:
            if not np.isclose(sum(sampling_strategy.values()), 1.0):
                total_prob = sum(sampling_strategy.values())
                strategy = {k: v / total_prob for k, v in sampling_strategy.items()}
            else:
                strategy = sampling_strategy
        
        # Calculate the exact number of samples needed for each category.
        counts = {cat: int(prob * num_flawed) for cat, prob in strategy.items()}
        remainder = num_flawed - sum(counts.values())
        for i in range(remainder):
            cat_to_increment = max(strategy, key=lambda k: strategy.get(k, 0) if k in counts else 0)
            counts[cat_to_increment] += 1

        # Sample from each category.
        sampled_rows_list = []
        for category, count in counts.items():
            if count == 0: continue
            pool = flawed_catalog[flawed_catalog[sampling_column] == category]
            if len(pool) < count:
                print(f"Warning: Not enough samples for '{category}'. Requested {count}, have {len(pool)}. Taking all available.")
                count = len(pool)
            if count > 0:
                sampled_rows_list.append(pool.sample(n=count, random_state=random_seed))
        
        if sampled_rows_list:
            sampled_errors = pd.concat(sampled_rows_list)
            for _, row in sampled_errors.iterrows():
                artifact = load_error_artifact(row['filepath'])
                problem_text, _ = get_original_problem_and_solution(row['index'])
                if artifact and problem_text:
                    sft_input = construct_sft_input(problem_text, artifact['flawed_nl_solution'])
                    sft_dataset.append({"input": sft_input, "target": artifact['target_json']})

    # --- 2. Generate Correct Samples ---
    if num_correct > 0:
        available_indices = catalog_df['index'].unique()
        num_to_sample = min(num_correct, len(available_indices))
        if num_to_sample < num_correct:
            print(f"Warning: Requested {num_correct} correct samples, but only {num_to_sample} unique problems are available.")
        
        rng = random.Random(random_seed)
        sampled_indices = rng.sample(list(available_indices), num_to_sample)
        
        for index in sampled_indices:
            problem_text, solution_text = get_original_problem_and_solution(index)
            if problem_text and solution_text:
                sft_input = construct_sft_input(problem_text, solution_text)
                sft_target = {"verdict": "Correct", "error_details": None}
                sft_dataset.append({"input": sft_input, "target": sft_target})

    # --- 3. Finalize and Shuffle ---
    rng = random.Random(random_seed)
    rng.shuffle(sft_dataset)
    print(f"\nGenerated a dataset with {len(sft_dataset)} total samples.")
    return sft_dataset

print("Dataset assembly function defined.")

Dataset assembly function defined.


### Cell 6: Example Usage and Queries

In [19]:
# --- Example 1: 200 samples, half correct, flawed ones spread uniformly over error types ---
print("--- [Example 1: Balanced Dataset, Uniform Sampling] ---")

# The full catalog is passed as-is; create_sft_dataset drops unusable rows itself.
uniform_dataset = create_sft_dataset(
    catalog_df,
    sampling_strategy=None,  # None => uniform over every error type present
    correct_ratio=0.5,
    total_samples=200,
)

if uniform_dataset:
    first_record = uniform_dataset[0]
    print("\n--- Sample Record ---")
    print(json.dumps(first_record, indent=2))

# --- Example 2: narrow the catalog first, then apply custom proportions ---
print("\n" + "=" * 80)
print("--- [Example 2: Pre-filtered Dataset with Custom Proportions] ---")

# Keep only OpenAI-generated errors from tiers 3 and 4.
openai_rows = query_by_model(catalog_df, models=['openai_gpt-4.1'])
query_filtered_df = query_by_tier(openai_rows, tiers=['tier3', 'tier4'])
del openai_rows
print(f"Pre-filtered catalog contains {len(query_filtered_df)} potential samples.")

# Weights for the error types within that subset.
custom_strategy = dict(
    generate_digit_transposition_error=0.60,
    generate_reciprocal_error=0.40,
)

custom_dataset = create_sft_dataset(
    query_filtered_df,
    sampling_column='error_type',
    sampling_strategy=custom_strategy,
    correct_ratio=0.25,
    total_samples=100,
)

# Report how many records ended up in each class.
if custom_dataset:
    correct_count = sum(record['target']['verdict'] == 'Correct' for record in custom_dataset)
    flawed_count = len(custom_dataset) - correct_count
    print(f"\nFinal dataset composition: {correct_count} correct, {flawed_count} flawed.")

--- [Example 1: Balanced Dataset, Uniform Sampling] ---

Generated a dataset with 197 total samples.

--- Sample Record ---
{
  "input": "### Problem:\nTo make a cherry pie, Veronica needs 3 pounds of pitted cherries.  There are 80 single cherries in one pound of cherries.   It takes 10 minutes to pit 20 cherries.  How many hours will it take Veronica to pit all the cherries?\n\n### Solution:\nThere are 80 cherries in a pound and she needs 3 pounds to make a pie so she needs 80*3 = <<80*3=240>>240 cherries\nIt takes her 10 minutes to pit a unit of 20 cherries. She has 240/20 = <<240/20=12>>12 units of cherries to pit\nIt takes 10 minutes to pit a unit of cherries and she has 12 units so it will take her 10*12 = <<10*12=120>>120 minutes\n60 minutes are in 1 hour and it takes her 120 minutes so that\u2019s 120/60 = <<120/60=2>>2 hours\n#### 2",
  "target": {
    "verdict": "Correct",
    "error_details": null
  }
}

--- [Example 2: Pre-filtered Dataset with Custom Proportions] ---
Pre-fi

In [21]:
# Preview only the first few records. The original loop's comment promised
# five, but it printed every record and flooded the output.
N_PREVIEW = 5
for sample in custom_dataset[:N_PREVIEW]:
    print("\n--- Sample Record ---")
    print(f"Input: {sample['input']}")
    print(f"Target: {json.dumps(sample['target'], indent=2)}")


--- Sample Record ---
Input: ### Problem:
Joe buys 3 oranges, 7 juices, 3 jars of honey, and 4 plants at the market. The fruit costs $4.50 each, the juice was 50 cents each, the jars of honey were $5, and the plants were 2 for $18. How much does Joe spend at the market?

### Solution:
Joe spends 4.5*3 = <<4.5*3=13.5>>13.5 on oranges.
Joe spends 7*0.5 = <<7*0.5=3.5>>3.5 on juice.
Joe spends 3*5 = <<3*5=15>>15 on honey.
Each plant costs 18/2 = <<18/2=9.0>>9.0 dollars.
Joe spends 9.0*4 = <<9.0*4=36.0>>36.0 on plants.
Joe spends a total of 13.5+3.5+15+36.0 = <<13.5+3.5+15+36.0=86.0>>86.0 dollars at the market.
Target: {
  "verdict": "Flawed",
  "error_details": {
    "error_type": "computational_error",
    "erroneous_line_number": "L6",
    "explanation": "The result of this computation should be 68.0, not 86.0. It appears two adjacent digits were swapped."
  }
}

--- Sample Record ---
Input: ### Problem:
Janele wants to figure out the average weight of her cats. She has 4 of them. The f