### Cell 1: Imports and Path Definitions

In [15]:
import json
import random
from pathlib import Path
import pandas as pd
from datasets import load_dataset, Dataset

# --- Path and Directory Definitions ---
def find_project_root(marker: str = ".git") -> Path:
    """
    Locate the project root by walking upwards from the current working directory.

    Args:
        marker: Name of a file or directory whose presence identifies the root.

    Returns:
        The first directory (starting at the resolved working directory and
        moving towards the filesystem root) that contains `marker`.

    Raises:
        FileNotFoundError: If no directory on the way up contains `marker`.
    """
    start_dir = Path.cwd().resolve()
    # `parents` ends with the filesystem root, so the root itself is also
    # checked (the previous while-loop stopped one level short of it).
    for candidate in (start_dir, *start_dir.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# Resolve the repository root once; every other path is derived from it.
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT.joinpath('data')

# --- INPUT: directory holding the generated error artifacts and their catalog ---
GENERATED_ERRORS_DIR = DATA_DIR.joinpath("computational-errors-generated")
CATALOG_FILE = GENERATED_ERRORS_DIR.joinpath("computational_error_catalog.csv")

# --- Original GSM8K training split, used as the source of 'Correct' examples ---
# Loaded a single time at module level so later lookups reuse the same object.
GSM8K_TRAIN: Dataset = load_dataset("gsm8k", "main", split="train")

print(f"Project root: {PROJECT_ROOT}")
print(f"Loading catalog from: {CATALOG_FILE}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Loading catalog from: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/computational-errors-generated/computational_error_catalog.csv


### Cell 2: Load and Inspect the Error Catalog

In [16]:
# --- Read the master catalog CSV into a pandas DataFrame ---
try:
    catalog_df = pd.read_csv(CATALOG_FILE)
except FileNotFoundError:
    print(f"ERROR: Catalog file not found at {CATALOG_FILE}")
    print("Please ensure you have run the error generation notebook first.")
    # Fall back to an empty frame so later cells can detect the failure via `.empty`.
    catalog_df = pd.DataFrame()
else:
    print("Successfully loaded the error catalog.")
    print(f"Total available flawed examples: {len(catalog_df)}")

# Only inspect the catalog when there is something to show.
if not catalog_df.empty:
    print("\n--- Catalog Schema ---")
    catalog_df.info()

    print("\n--- Catalog Head ---")
    display(catalog_df.head())

Successfully loaded the error catalog.
Total available flawed examples: 15775

--- Catalog Schema ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15775 entries, 0 to 15774
Data columns (total 11 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   index            15775 non-null  int64 
 1   tier             15775 non-null  object
 2   model            15775 non-null  object
 3   target_variable  15775 non-null  object
 4   error_type       15775 non-null  object
 5   correct_value    15775 non-null  object
 6   flawed_value     15775 non-null  object
 7   repro_seed       15775 non-null  object
 8   date_utc         15775 non-null  object
 9   time_utc         15775 non-null  object
 10  filepath         15775 non-null  object
dtypes: int64(1), object(10)
memory usage: 1.3+ MB

--- Catalog Head ---


Unnamed: 0,index,tier,model,target_variable,error_type,correct_value,flawed_value,repro_seed,date_utc,time_utc,filepath
0,4,tier1,openai_gpt-4.1,pages_per_week,generate_digit_transposition_error,12,21,9359911706531889829,2025-07-11,00:44:20,data/computational-errors-generated/tier1/4/op...
1,4,tier1,openai_gpt-4.1,pages_per_year,generate_digit_transposition_error,624,264,1375167854068872008,2025-07-11,00:44:20,data/computational-errors-generated/tier1/4/op...
2,4,tier1,google_gemini-2.5-flash,total_pages_per_week,generate_digit_transposition_error,12,21,-4052158372119187216,2025-07-11,00:44:20,data/computational-errors-generated/tier1/4/go...
3,4,tier1,google_gemini-2.5-flash,total_pages_per_year,generate_digit_transposition_error,624,642,12183111725324799136,2025-07-11,00:44:20,data/computational-errors-generated/tier1/4/go...
4,6,tier1,openai_gpt-4.1,large_total,generate_digit_transposition_error,32,23,10820414944764992464,2025-07-11,00:44:20,data/computational-errors-generated/tier1/6/op...


### Cell 3: Core Utility Functions

In [17]:
def load_error_artifact(filepath_str: str) -> dict | None:
    """Loads a single JSON error artifact from disk."""
    full_path = PROJECT_ROOT / filepath_str
    try:
        with open(full_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Warning: Could not load artifact at {full_path}. Reason: {e}")
        return None


def construct_sft_input(problem_text: str, solution_text: str) -> str:
    """
    Build the model input string from a problem statement and a solution.

    The result is a "### Problem:" section followed by a "### Solution:"
    section, separated by one blank line.
    """
    problem_section = f"### Problem:\n{problem_text}"
    solution_section = f"### Solution:\n{solution_text}"
    return "\n\n".join((problem_section, solution_section))


def get_original_problem_and_solution(index: int) -> tuple[str, str] | None:
    """Retrieves the original problem and solution text from the GSM8K dataset."""
    # --- MODIFIED: Cast the index to a standard Python int ---
    # This resolves the TypeError from the datasets library.
    index = int(index)

    if index < 0 or index >= len(GSM8K_TRAIN):
        return None
    sample = GSM8K_TRAIN[index]
    return sample['question'], sample['answer']


### Cell 4: SQL-Style Query Functions

In [18]:
# --- SQL-Style Query Functions for the Error Catalog ---

def query_by_index(df: pd.DataFrame, indices: list[int]) -> pd.DataFrame:
    """Return a copy of the rows whose 'index' value appears in `indices`."""
    wanted = df['index'].isin(indices)
    return df.loc[wanted].copy()

def query_by_index_range(df: pd.DataFrame, min_index: int, max_index: int) -> pd.DataFrame:
    """Return a copy of the rows with min_index <= 'index' <= max_index (both ends inclusive)."""
    problem_index = df['index']
    in_range = (problem_index >= min_index) & (problem_index <= max_index)
    return df.loc[in_range].copy()

def query_by_tier(df: pd.DataFrame, tiers: list[str]) -> pd.DataFrame:
    """Return a copy of the rows whose 'tier' is one of `tiers` (e.g., ['tier1', 'tier4'])."""
    tier_matches = df['tier'].isin(tiers)
    return df.loc[tier_matches].copy()

def query_by_model(df: pd.DataFrame, models: list[str]) -> pd.DataFrame:
    """Return a copy of the rows whose 'model' is one of `models`."""
    model_matches = df['model'].isin(models)
    return df.loc[model_matches].copy()

def query_by_error_type(df: pd.DataFrame, error_types: list[str]) -> pd.DataFrame:
    """Return a copy of the rows whose 'error_type' is one of `error_types`."""
    type_matches = df['error_type'].isin(error_types)
    return df.loc[type_matches].copy()

def query_by_target_variable(df: pd.DataFrame, variables: list[str]) -> pd.DataFrame:
    """Return a copy of the rows whose 'target_variable' is one of `variables`."""
    variable_matches = df['target_variable'].isin(variables)
    return df.loc[variable_matches].copy()

# Confirmation message emitted after the query helper definitions in this cell have run.
print("SQL-style query functions defined.")

SQL-style query functions defined.


### Cell 5: The Dataset Assembly Function

In [19]:
import numpy as np

def create_sft_dataset(
    catalog_df: pd.DataFrame,
    total_samples: int,
    correct_ratio: float = 0.5,
    sampling_strategy: dict[str, float] | None = None,
    sampling_column: str = 'error_type',
    random_seed: int = 42
) -> list[tuple[str, dict]]:
    """
    Assembles a labeled dataset for SFT using stratified sampling for flawed examples.

    Args:
        catalog_df: The DataFrame containing the error catalog.
        total_samples: The total number of samples desired in the final dataset.
        correct_ratio: The desired proportion of "Correct" examples (0.0 to 1.0).
        sampling_strategy: A dictionary mapping values in `sampling_column` to a
                           probability (e.g., {'type_A': 0.7, 'type_B': 0.3}).
                           The probabilities must sum to 1.0. If None, a uniform
                           distribution across all available types is used.
        sampling_column: The column to stratify on (e.g., 'error_type').
        random_seed: A seed for reproducibility.

    Returns:
        A list of (input_string, target_json) tuples, ready for SFT.
    """
    if catalog_df.empty:
        print("Error: Input catalog is empty. Cannot generate dataset.")
        return []
        
    num_correct = int(total_samples * correct_ratio)
    num_flawed = total_samples - num_correct
    sft_dataset = []

    # --- 1. Generate Flawed Samples via Stratified Sampling ---
    if num_flawed > 0:
        # Determine the sampling strategy
        if sampling_strategy is None:
            # Uniform distribution across all available types in the catalog
            available_types = catalog_df[sampling_column].unique()
            strategy = {t: 1.0 / len(available_types) for t in available_types}
        else:
            # User-defined distribution
            if not np.isclose(sum(sampling_strategy.values()), 1.0):
                print("Warning: Provided probabilities do not sum to 1.0. Normalizing.")
                total_prob = sum(sampling_strategy.values())
                strategy = {k: v / total_prob for k, v in sampling_strategy.items()}
            else:
                strategy = sampling_strategy
        
        # Calculate the exact number of samples needed for each category
        counts = {cat: int(prob * num_flawed) for cat, prob in strategy.items()}
        # Distribute rounding errors to ensure total count is correct
        remainder = num_flawed - sum(counts.values())
        for i in range(remainder):
            # Add remainder to the largest probability categories
            cat_to_increment = max(strategy, key=lambda k: strategy.get(k, 0) if k in counts else 0)
            counts[cat_to_increment] +=1

        # Sample from each category
        sampled_rows_list = []
        for category, count in counts.items():
            if count == 0: continue
            
            pool = catalog_df[catalog_df[sampling_column] == category]
            if len(pool) < count:
                print(f"Warning: Not enough samples for '{category}'. Requested {count}, have {len(pool)}. Taking all available.")
                count = len(pool)
            
            sampled_rows_list.append(pool.sample(n=count, random_state=random_seed))
        
        if sampled_rows_list:
            sampled_errors = pd.concat(sampled_rows_list)
            for _, row in sampled_errors.iterrows():
                artifact = load_error_artifact(row['filepath'])
                problem_text, _ = get_original_problem_and_solution(row['index'])
                if artifact and problem_text:
                    sft_input = construct_sft_input(problem_text, artifact['flawed_nl_solution'])
                    sft_dataset.append((sft_input, artifact['target_json']))

    # --- 2. Generate Correct Samples ---
    if num_correct > 0:
        available_indices = catalog_df['index'].unique()
        num_to_sample = min(num_correct, len(available_indices))
        if num_to_sample < num_correct:
            print(f"Warning: Requested {num_correct} correct samples, but only {num_to_sample} unique problems are available.")
        
        # Use a generator for reproducible random sampling
        rng = random.Random(random_seed)
        sampled_indices = rng.sample(list(available_indices), num_to_sample)
        
        for index in sampled_indices:
            problem_text, solution_text = get_original_problem_and_solution(index)
            if problem_text and solution_text:
                sft_input = construct_sft_input(problem_text, solution_text)
                sft_dataset.append((sft_input, {"verdict": "Correct", "error_details": None}))

    # --- 3. Finalize and Shuffle ---
    rng = random.Random(random_seed)
    rng.shuffle(sft_dataset)
    print(f"\nGenerated a dataset with {len(sft_dataset)} total samples.")
    return sft_dataset

### Cell 6: Example Usage and Queries

In [20]:
# Helper function to analyze the generated dataset composition
def analyze_dataset(dataset: list):
    """
    Print the composition of a generated SFT dataset.

    Args:
        dataset: A list of (input_string, target_json) tuples; only the target
                 dictionaries are inspected.

    Prints the normalized verdict distribution and, for targets whose verdict
    is 'Flawed', the normalized distribution of `error_details['error_type']`.
    Prints nothing for an empty dataset.
    """
    if not dataset:
        return

    def _error_type_of(details):
        # Targets without a details dictionary (e.g. correct examples) get 'N/A'.
        return details.get('error_type') if isinstance(details, dict) else 'N/A'

    targets = [item[1] for item in dataset]
    # Build the summary columns directly so a missing 'error_details' key or a
    # missing 'explanation' cannot raise, and no slice of a frame is mutated.
    summary = pd.DataFrame({
        'verdict': [target.get('verdict') for target in targets],
        'error_type': [_error_type_of(target.get('error_details')) for target in targets],
    })

    print("\n--- Dataset Composition ---")
    print("Verdict Distribution:")
    print(summary['verdict'].value_counts(normalize=True).round(2))

    print("\nFlawed Example Breakdown (by error_type):")
    flawed = summary[summary['verdict'] == 'Flawed']
    if not flawed.empty:
        # Only the high-level error_type is available in the target JSON.
        print(flawed['error_type'].value_counts(normalize=True).round(2))
    else:
        print("No flawed examples in this dataset.")

# --- Example 1: default behavior (no strategy given -> uniform over error types) ---
print("--- [Example 1: Uniform Distribution] ---")
uniform_dataset = create_sft_dataset(catalog_df, total_samples=200, correct_ratio=0.5)
# The per-type split is not checked here; only the total size is reported by the call.

# --- Example 2: explicit 70/20/10 split for the flawed part of the dataset ---
print("\n" + "=" * 80)
print("--- [Example 2: Custom Probability Distribution] ---")

# Proportions for the flawed examples; create_sft_dataset turns them into counts.
custom_strategy = dict(
    generate_digit_transposition_error=0.70,
    generate_off_by_factor_of_10_error=0.20,
    generate_decimal_shift_error=0.10,
)

# 20% correct (40 samples), 80% flawed (160 samples)
custom_dataset = create_sft_dataset(
    catalog_df,
    total_samples=200,
    correct_ratio=0.2,
    sampling_strategy=custom_strategy,
)

# --- Verification for Example 2 ---
# The target JSON does not carry the generator name, so only the number of
# flawed samples is checked here, not the per-generator split.
if custom_dataset:
    flawed_targets = [target for _, target in custom_dataset if target['verdict'] == 'Flawed']
    print(f"\nVerification: We have {len(flawed_targets)} flawed samples as requested.")

--- [Example 1: Uniform Distribution] ---

Generated a dataset with 200 total samples.

--- [Example 2: Custom Probability Distribution] ---

Generated a dataset with 200 total samples.

Verification: We have 160 flawed samples as requested.

Generated a dataset with 200 total samples.

Verification: We have 160 flawed samples as requested.


In [21]:
# Show only a short preview: printing every sample floods the cell output
# and bloats the saved notebook.
PREVIEW_LIMIT = 3
for inp, tgt in uniform_dataset[:PREVIEW_LIMIT]:
    print("\nSample Input String:\n" + inp)
    print("\nSample Target JSON:\n" + json.dumps(tgt, indent=2))


Sample Input String:
### Problem:
Bea and Dawn both have a lemonade stand. Bea sells her lemonade at 25 cents while Dawn sells hers at 28 cents. If Bea sold 10 glasses and Dawn sold 8 glasses, how much more money (in cents) did Bea earn than Dawn?

### Solution:
Bea's sales amount to $0.25/glass x 10 glasses = 250 cents.
Dawn's sales amount to $0.28/glass x 8 glasses = 224.00000000000003 cents.
So, Bea earned $2.5 - $2.24 = 29.1408394573 cents more than Dawn.

Sample Target JSON:
{
  "verdict": "Flawed",
  "error_details": {
    "error_type": "computational_error",
    "erroneous_line_number": "L3",
    "explanation": "The result of this computation should be 25.99999999999997, not 29.1408394573. This appears to be a minor miscalculation."
  }
}

Sample Input String:
### Problem:
Harry is a professional dog-walker.  He is paid to go on long walks with dogs while their families are away from home.  On Monday, Wednesday, and Friday, Harry walks 7 dogs.  On Tuesday, he walks 12 dogs.  An