In [1]:
import json
import random
import re
from pathlib import Path
from typing import List, Dict, Set, Tuple

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

# --- Path and Directory Definitions ---
def find_project_root(marker: str = ".git") -> Path:
    """Locate the project root by searching upward from the current working directory.

    The resolved working directory and each of its ancestors are inspected in
    order (nearest first); the first directory containing an entry named
    ``marker`` is returned.

    Args:
        marker: Name of the file or directory that identifies the project root.

    Returns:
        The nearest directory (the working directory or one of its ancestors)
        that contains ``marker``.

    Raises:
        FileNotFoundError: If no directory on the path up to and including the
            filesystem root contains ``marker``.
    """
    start = Path.cwd().resolve()
    # `parents` runs all the way up to the filesystem root, so the root itself
    # is examined as well (a `while path != path.parent` loop stops one short).
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# --- Global Constants and Paths ---
# All paths are derived from the repository root found via its marker.
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT.joinpath('data')
OUTPUT_DIR = DATA_DIR.joinpath("sft-datasets")

# --- Ensure output directory exists ---
# Idempotent: safe to re-run the cell, intermediate folders are created too.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# --- Configuration ---
# Seed both the stdlib and the NumPy global generators with the same value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Echo the resolved configuration so the run is self-describing.
print(f"Project root: {PROJECT_ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Random seed set to: {RANDOM_SEED}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Data directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data
Output directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/sft-datasets
Random seed set to: 42


In [2]:
# Load the master catalog (abort early when the sanitized CSV is absent)
master_catalog_path = DATA_DIR.joinpath("master_catalog_sanitized.csv")
if not master_catalog_path.exists():
    raise FileNotFoundError(f"Master catalog not found: {master_catalog_path}")

master_df = pd.read_csv(master_catalog_path)
print(f"Loaded master catalog with {len(master_df):,} records")
print(f"Columns: {list(master_df.columns)}")

# Headline numbers: total rows versus distinct problem indices
print("\n=== Master Catalog Overview ===")
print(f"Total samples: {len(master_df):,}")
print(f"Unique indices: {master_df['index'].nunique():,}")

# Frequency of each error label
print("\n--- Error Type Distribution ---")
error_type_counts = master_df['error_type'].value_counts()
print(error_type_counts)

# Frequency of each sample source
print("\n--- Source Distribution ---")
source_counts = master_df['source'].value_counts()
print(source_counts)

# Frequency of each tier, ordered by tier label
print("\n--- Tier Distribution ---")
tier_counts = master_df['tier'].value_counts().sort_index()
print(tier_counts)

# Report null counts for the columns the rest of the notebook relies on;
# columns that are not present in the catalog are skipped silently.
critical_columns = ['index', 'tier', 'question', 'correct_answer', 'wrong_answer', 'error_type', 'source']
print("\n--- Missing Values Check ---")
for col in critical_columns:
    if col not in master_df.columns:
        continue
    missing_count = master_df[col].isna().sum()
    print(f"{col}: {missing_count} missing ({missing_count/len(master_df)*100:.1f}%)")

# Peek at the first rows to show the record layout
print("\n--- Sample Records ---")
print(master_df.head(3))

Loaded master catalog with 24,652 records
Columns: ['index', 'tier', 'question', 'correct_answer', 'wrong_answer', 'error_type', 'erroneous_line_number', 'explanation', 'error_subtype', 'source', 'solution_length', 'relative_line_position']

=== Master Catalog Overview ===
Total samples: 24,652
Unique indices: 6,777

--- Error Type Distribution ---
error_type
computational_error    22542
conceptual_error        2110
Name: count, dtype: int64

--- Source Distribution ---
source
programmatic    22912
manual           1740
Name: count, dtype: int64

--- Tier Distribution ---
tier
tier1    11188
tier2     3458
tier3     8288
tier4     1662
tier5       56
Name: count, dtype: int64

--- Missing Values Check ---
index: 0 missing (0.0%)
tier: 0 missing (0.0%)
question: 0 missing (0.0%)
correct_answer: 0 missing (0.0%)
wrong_answer: 0 missing (0.0%)
error_type: 0 missing (0.0%)
source: 0 missing (0.0%)

--- Sample Records ---
   index   tier                                           question  \

In [3]:
def get_anchor_indices(master_df):
    """
    Collect the anchor indices (set A): every problem index that has at least
    one row labelled 'conceptual_error'.

    Prints the number of anchors and, using the first tier recorded for each
    anchor index, how many anchors fall in each tier.

    Returns:
        Sorted list of the unique anchor index values.
    """
    # Unique problem indices carrying a conceptual error, in ascending order
    is_conceptual = master_df['error_type'] == 'conceptual_error'
    anchor_indices = sorted(master_df.loc[is_conceptual, 'index'].unique())
    print(f"Found {len(anchor_indices)} anchor indices with conceptual errors")

    # One tier per anchor problem (first row wins), then tallied per tier
    rows_for_anchors = master_df[master_df['index'].isin(anchor_indices)]
    tier_per_problem = rows_for_anchors.groupby('index')['tier'].first()
    anchor_tier_dist = tier_per_problem.value_counts().sort_index()

    print("\n--- Anchor Indices by Tier ---")
    for tier_label, n_problems in anchor_tier_dist.items():
        print(f"{tier_label}: {n_problems} problems")

    return anchor_indices

# Set A: sorted problem indices that have at least one conceptual error
anchor_indices = get_anchor_indices(master_df=master_df)

Found 1881 anchor indices with conceptual errors

--- Anchor Indices by Tier ---
tier1: 635 problems
tier2: 279 problems
tier3: 722 problems
tier4: 217 problems
tier5: 28 problems


In [4]:
def analyze_anchor_coverage(master_df, anchor_indices):
    """
    Summarise which error types are available for every anchor index.

    Every index is treated as having a correct answer (it comes from the
    `correct_answer` column), so only the error types are actually inspected.

    Args:
        master_df: Catalog with at least the columns 'index', 'tier',
            'error_type' and 'source'.
        anchor_indices: Iterable of problem indices to analyse. May be empty.

    Returns:
        DataFrame with one row per anchor index (in the order given) and the
        columns 'index', 'tier', 'total_samples', 'has_correct',
        'has_conceptual', 'has_computational', 'available_types',
        'manual_sources', 'programmatic_sources'. The three `has_*` columns
        are boolean. For empty input the frame is empty but keeps these columns.

    Raises:
        IndexError: If an anchor index has no rows in `master_df`.
    """
    columns = [
        'index', 'tier', 'total_samples', 'has_correct', 'has_conceptual',
        'has_computational', 'available_types', 'manual_sources',
        'programmatic_sources',
    ]
    flag_columns = ['has_correct', 'has_conceptual', 'has_computational']

    # Split the catalog once instead of re-scanning it for every anchor index.
    anchor_rows = master_df[master_df['index'].isin(anchor_indices)]
    samples_by_index = {
        key: group for key, group in anchor_rows.groupby('index', sort=False)
    }

    coverage_analysis = []
    for idx in anchor_indices:
        idx_samples = samples_by_index.get(idx)
        if idx_samples is None:
            raise IndexError(f"Anchor index {idx!r} has no rows in master_df")

        # Error types present for this problem
        available_types = set(idx_samples['error_type'].unique())

        coverage_analysis.append({
            'index': idx,
            # Tier is taken from the first row (expected to be constant per index)
            'tier': idx_samples['tier'].iloc[0],
            'total_samples': len(idx_samples),
            'has_correct': True,  # Every index has correct answer via correct_answer column
            'has_conceptual': 'conceptual_error' in available_types,
            'has_computational': 'computational_error' in available_types,
            'available_types': available_types,
            'manual_sources': int((idx_samples['source'] == 'manual').sum()),
            'programmatic_sources': int((idx_samples['source'] == 'programmatic').sum()),
        })

    # Explicit columns + bool dtypes keep the summary below working even when
    # there are no anchors at all (a bare DataFrame([]) has no columns).
    coverage_df = pd.DataFrame(coverage_analysis, columns=columns)
    coverage_df = coverage_df.astype({name: bool for name in flag_columns})

    print("=== Coverage Analysis for Anchor Indices ===")
    print(f"Total anchor indices: {len(coverage_df)}")
    print(f"Indices with correct samples: {coverage_df['has_correct'].sum()} (all indices have correct answers)")
    print(f"Indices with conceptual errors: {coverage_df['has_conceptual'].sum()}")
    print(f"Indices with computational errors: {coverage_df['has_computational'].sum()}")

    has_correct_and_conceptual = coverage_df['has_correct'] & coverage_df['has_conceptual']

    # Indices that can form complete triplets (correct + conceptual + computational)
    complete_triplets = coverage_df[has_correct_and_conceptual & coverage_df['has_computational']]
    print(f"Indices with complete triplets: {len(complete_triplets)}")

    # Indices missing computational errors
    missing_computational = coverage_df[has_correct_and_conceptual & ~coverage_df['has_computational']]
    print(f"Indices missing computational errors: {len(missing_computational)}")

    if len(missing_computational) > 0:
        print("\n--- Missing Computational Errors by Tier ---")
        missing_by_tier = missing_computational['tier'].value_counts().sort_index()
        for tier, count in missing_by_tier.items():
            print(f"{tier}: {count}")

    return coverage_df

# Per-anchor summary of which error types exist in the catalog
coverage_df = analyze_anchor_coverage(master_df=master_df, anchor_indices=anchor_indices)

=== Coverage Analysis for Anchor Indices ===
Total anchor indices: 1881
Indices with correct samples: 1881 (all indices have correct answers)
Indices with conceptual errors: 1881
Indices with computational errors: 1773
Indices with complete triplets: 1773
Indices missing computational errors: 108

--- Missing Computational Errors by Tier ---
tier1: 13
tier2: 6
tier3: 58
tier4: 20
tier5: 11


In [5]:
def create_3N_dataset_simplified(master_df, anchor_indices, coverage_df):
    """
    Create the 3N dataset: for every anchor index that has BOTH a conceptual
    and a computational error, emit three rows - correct, conceptual_error and
    computational_error.

    The correct row is synthesised from the first catalog row of the problem:
    `wrong_answer` is overwritten with `correct_answer`, `error_type` becomes
    'correct', `source` becomes 'gsm8k', and the fields describing an error
    (`erroneous_line_number`, `explanation`, `error_subtype` and, when present,
    `relative_line_position`) are cleared. Error rows prefer source 'manual'
    and fall back to the first row of the requested type.

    Args:
        master_df: Full sample catalog.
        anchor_indices: Problem indices to consider.
        coverage_df: Frame with one row per anchor index and the columns
            'index', 'has_conceptual', 'has_computational'.

    Returns:
        (dataset_3N_df, successful_indices, failed_indices). The frame is
        sorted by ('index', 'error_type'); when no anchor qualifies it is an
        empty frame with the columns of `master_df`.
    """
    print("=== Creating 3N Dataset (Requires Complete Triplets) ===")

    def _pick_error_row(idx_samples, error_type):
        """Copy of the first 'manual' row of `error_type`, else of the first row of that type."""
        candidates = idx_samples[idx_samples['error_type'] == error_type]
        manual = candidates[candidates['source'] == 'manual']
        preferred = manual if len(manual) > 0 else candidates
        return preferred.iloc[0].copy()

    dataset_3N = []
    successful_indices = []
    failed_indices = []

    for idx in tqdm(anchor_indices, desc="Processing anchor indices"):
        idx_samples = master_df[master_df['index'] == idx]

        # Get the coverage info for this index
        idx_coverage = coverage_df[coverage_df['index'] == idx].iloc[0]

        # Skip if we don't have both conceptual and computational errors
        if not (idx_coverage['has_conceptual'] and idx_coverage['has_computational']):
            failed_indices.append(idx)
            continue

        # 1. Correct sample, derived from the first row of this problem
        correct_sample = idx_samples.iloc[0].copy()
        correct_sample['wrong_answer'] = correct_sample['correct_answer']
        correct_sample['error_type'] = 'correct'
        correct_sample['erroneous_line_number'] = None
        correct_sample['explanation'] = None
        correct_sample['error_subtype'] = None
        if 'relative_line_position' in correct_sample.index:
            # The copied value belongs to the source error row; a correct
            # solution has no error position, so clear it like the other
            # error-describing fields.
            correct_sample['relative_line_position'] = None
        # NOTE(review): `solution_length` is still inherited from the source
        # error row - confirm it also holds for the correct solution text.
        correct_sample['source'] = 'gsm8k'
        dataset_3N.append(correct_sample)

        # 2. Conceptual error sample, 3. computational error sample
        dataset_3N.append(_pick_error_row(idx_samples, 'conceptual_error'))
        dataset_3N.append(_pick_error_row(idx_samples, 'computational_error'))

        successful_indices.append(idx)

    if dataset_3N:
        # Sort by index first, then by error_type
        dataset_3N_df = (
            pd.DataFrame(dataset_3N)
            .sort_values(['index', 'error_type'])
            .reset_index(drop=True)
        )
    else:
        # Nothing qualified: keep the schema so callers (and the sort keys)
        # do not hit a KeyError on a column-less frame.
        dataset_3N_df = pd.DataFrame(columns=master_df.columns)

    print(f"Successfully processed: {len(successful_indices)} indices")
    print(f"Failed to process: {len(failed_indices)} indices (missing computational errors)")
    print(f"Total samples in 3N dataset: {len(dataset_3N_df)}")

    # Verify perfect balance
    if len(dataset_3N_df) > 0:
        print("\n--- 3N Dataset Composition ---")
        error_type_dist = dataset_3N_df['error_type'].value_counts()
        for error_type, count in error_type_dist.items():
            print(f"{error_type}: {count}")

        # Every successful index contributes exactly one row per label
        expected_per_type = len(successful_indices)
        is_balanced = all(count == expected_per_type for count in error_type_dist.values)
        print(f"✓ Perfect balance: {is_balanced} (expected {expected_per_type} per type)")

        print("\n--- 3N Dataset by Source ---")
        source_dist = dataset_3N_df['source'].value_counts()
        for source, count in source_dist.items():
            print(f"{source}: {count}")

        print("\n--- 3N Dataset by Tier ---")
        tier_dist = dataset_3N_df['tier'].value_counts().sort_index()
        for tier, count in tier_dist.items():
            print(f"{tier}: {count}")

    return dataset_3N_df, successful_indices, failed_indices

# Build the 3N dataset and keep track of which anchors did / did not qualify
dataset_3N_df, successful_3N_indices, failed_3N_indices = create_3N_dataset_simplified(
    master_df=master_df,
    anchor_indices=anchor_indices,
    coverage_df=coverage_df,
)

=== Creating 3N Dataset (Requires Complete Triplets) ===


Processing anchor indices:   0%|          | 0/1881 [00:00<?, ?it/s]

Successfully processed: 1773 indices
Failed to process: 108 indices (missing computational errors)
Total samples in 3N dataset: 5319

--- 3N Dataset Composition ---
computational_error: 1773
conceptual_error: 1773
correct: 1773
✓ Perfect balance: True (expected 1773 per type)

--- 3N Dataset by Source ---
programmatic: 2226
gsm8k: 1773
manual: 1320

--- 3N Dataset by Tier ---
tier1: 1866
tier2: 819
tier3: 1992
tier4: 591
tier5: 51


In [6]:
def select_disjoint_set_B(master_df, anchor_indices, target_size):
    """
    Select a set B of problem indices that
    1. are not in `anchor_indices` (set A),
    2. have at least one computational error row, and
    3. number `target_size` (or as many as exist, with a warning, if fewer
       candidates are available).

    Selection is stratified: for every tier present among the anchors, a share
    of `target_size` proportional to that tier's share of the anchors is drawn
    with `random.sample`; any shortfall is then filled by sampling from the
    remaining candidates regardless of tier. The module-level `random` state
    is used, so results depend on the seed set by the caller.

    Args:
        master_df: Catalog with columns 'index', 'tier', 'error_type'.
        anchor_indices: Indices of set A (excluded from the selection).
        target_size: Desired number of indices in set B.

    Returns:
        List of selected indices; empty when there are no candidates.
    """
    print(f"=== Selecting Disjoint Set B (target size: {target_size}) ===")

    # Set lookup instead of scanning the anchor list for every candidate
    anchor_lookup = set(anchor_indices)
    computational_rows = master_df[master_df['error_type'] == 'computational_error']
    candidate_B_indices = [
        idx for idx in computational_rows['index'].unique()
        if idx not in anchor_lookup
    ]

    print(f"Candidate B indices with computational errors: {len(candidate_B_indices)}")

    # Every candidate has a computational error by construction, so the only
    # extra information needed is its tier: the tier on the first catalog row
    # of that index.
    first_rows = master_df.drop_duplicates(subset='index', keep='first')
    tier_lookup = dict(zip(first_rows['index'], first_rows['tier']))
    # Explicit columns keep the tier filter below valid with zero candidates.
    B_coverage_df = pd.DataFrame(
        {
            'index': candidate_B_indices,
            'tier': [tier_lookup[idx] for idx in candidate_B_indices],
        },
        columns=['index', 'tier'],
    )

    print(f"Valid B candidates: {len(B_coverage_df)}")

    if len(B_coverage_df) < target_size:
        print(f"WARNING: Only {len(B_coverage_df)} valid B candidates, but need {target_size}")
        target_size = len(B_coverage_df)

    # Sample to maintain tier distribution similar to set A
    A_tier_dist = master_df[master_df['index'].isin(anchor_indices)].groupby('index')['tier'].first().value_counts(normalize=True)

    selected_B_indices = []

    # Try to maintain proportional tier distribution
    for tier in A_tier_dist.index:
        tier_candidates = B_coverage_df[B_coverage_df['tier'] == tier]['index'].tolist()
        tier_target = int(A_tier_dist[tier] * target_size)
        tier_selected = min(tier_target, len(tier_candidates))

        if tier_selected > 0:
            tier_sample = random.sample(tier_candidates, tier_selected)
            selected_B_indices.extend(tier_sample)
            print(f"Selected {tier_selected} from {tier} (target: {tier_target}, available: {len(tier_candidates)})")

    # Fill remaining slots if needed (tiers that ran short, or rounding losses)
    remaining_needed = target_size - len(selected_B_indices)
    if remaining_needed > 0:
        already_selected = set(selected_B_indices)
        remaining_candidates = [
            idx for idx in B_coverage_df['index'].tolist()
            if idx not in already_selected
        ]
        if len(remaining_candidates) >= remaining_needed:
            additional_selected = random.sample(remaining_candidates, remaining_needed)
            selected_B_indices.extend(additional_selected)
            print(f"Added {remaining_needed} additional indices to reach target size")

    print(f"Final set B size: {len(selected_B_indices)}")
    return selected_B_indices[:target_size]

def create_4N_dataset_larger(master_df, anchor_indices):
    """
    Create the 4N dataset using:
    - Set A (all anchor indices): one correct + one conceptual_error row each
    - Set B (disjoint indices chosen by `select_disjoint_set_B`, same target
      size as set A): one correct + one computational_error row each

    Correct rows are synthesised from the first catalog row of the problem:
    `wrong_answer` is overwritten with `correct_answer`, `error_type` becomes
    'correct', `source` becomes 'gsm8k', and the error-describing fields
    (`erroneous_line_number`, `explanation`, `error_subtype` and, when present,
    `relative_line_position`) are cleared. Error rows prefer source 'manual'.
    If an index has no row of the required error type a warning is printed and
    only its correct row is kept.

    Args:
        master_df: Full sample catalog.
        anchor_indices: Problem indices forming set A.

    Returns:
        (dataset_4N_df, set_A_indices, set_B_indices). The frame is sorted by
        ('index', 'error_type'); when both sets are empty it is an empty frame
        with the columns of `master_df`.
    """
    print("=== Creating 4N Dataset (Larger Version) ===")

    def _make_correct_row(idx_samples):
        """Turn the first row of a problem into its 'correct' sample."""
        correct_sample = idx_samples.iloc[0].copy()
        correct_sample['wrong_answer'] = correct_sample['correct_answer']
        correct_sample['error_type'] = 'correct'
        correct_sample['erroneous_line_number'] = None
        correct_sample['explanation'] = None
        correct_sample['error_subtype'] = None
        if 'relative_line_position' in correct_sample.index:
            # Inherited from the source error row; a correct solution has no
            # error position, so clear it with the other error fields.
            correct_sample['relative_line_position'] = None
        # NOTE(review): `solution_length` is still inherited from the source
        # error row - confirm it also holds for the correct solution text.
        correct_sample['source'] = 'gsm8k'
        return correct_sample

    def _pick_error_row(idx_samples, error_type):
        """First 'manual' row of `error_type`, else first row of that type, else None."""
        candidates = idx_samples[idx_samples['error_type'] == error_type]
        if len(candidates) == 0:
            return None
        manual = candidates[candidates['source'] == 'manual']
        preferred = manual if len(manual) > 0 else candidates
        return preferred.iloc[0].copy()

    # Set A: ALL anchor indices (since they all have conceptual errors by definition)
    set_A_indices = list(anchor_indices)
    print(f"Set A size: {len(set_A_indices)} (all anchor indices)")

    # Select disjoint set B with same size as set A
    set_B_indices = select_disjoint_set_B(master_df, anchor_indices, len(set_A_indices))

    dataset_4N = []

    def _append_pairs(indices, error_type, progress_label, warning_prefix):
        """Append the correct row and (if available) the error row for each index."""
        for idx in tqdm(indices, desc=progress_label):
            idx_samples = master_df[master_df['index'] == idx]
            dataset_4N.append(_make_correct_row(idx_samples))
            error_row = _pick_error_row(idx_samples, error_type)
            if error_row is None:
                print(f"{warning_prefix} {idx}")
            else:
                dataset_4N.append(error_row)

    # Add samples from Set A (correct + conceptual_error)
    print(f"\nProcessing Set A ({len(set_A_indices)} indices)...")
    _append_pairs(
        set_A_indices, 'conceptual_error', "Set A",
        "WARNING: No conceptual error found for anchor index",
    )

    # Add samples from Set B (correct + computational_error)
    print(f"\nProcessing Set B ({len(set_B_indices)} indices)...")
    _append_pairs(
        set_B_indices, 'computational_error', "Set B",
        "WARNING: No computational error found for set B index",
    )

    if dataset_4N:
        # Sort by index first, then by error_type
        dataset_4N_df = (
            pd.DataFrame(dataset_4N)
            .sort_values(['index', 'error_type'])
            .reset_index(drop=True)
        )
    else:
        # Nothing to emit: keep the schema rather than failing on the sort keys
        dataset_4N_df = pd.DataFrame(columns=master_df.columns)

    print(f"\nTotal samples in 4N dataset: {len(dataset_4N_df)}")

    # Verify composition
    if len(dataset_4N_df) > 0:
        print("\n--- 4N Dataset Composition ---")
        error_type_dist = dataset_4N_df['error_type'].value_counts()
        for error_type, count in error_type_dist.items():
            print(f"{error_type}: {count}")

        print("\n--- 4N Dataset by Source ---")
        source_dist = dataset_4N_df['source'].value_counts()
        for source, count in source_dist.items():
            print(f"{source}: {count}")

        print("\n--- 4N Dataset by Tier ---")
        tier_dist = dataset_4N_df['tier'].value_counts().sort_index()
        for tier, count in tier_dist.items():
            print(f"{tier}: {count}")

        # Check balance between sets
        print(f"\n--- Set Distribution ---")
        set_A_samples = dataset_4N_df[dataset_4N_df['index'].isin(set_A_indices)]
        set_B_samples = dataset_4N_df[dataset_4N_df['index'].isin(set_B_indices)]
        print(f"Set A samples: {len(set_A_samples)}")
        print(f"Set B samples: {len(set_B_samples)}")

        # Expected composition
        expected_correct = len(set_A_indices) + len(set_B_indices)
        expected_conceptual = len(set_A_indices)
        expected_computational = len(set_B_indices)

        print(f"\n--- Expected vs Actual ---")
        print(f"Expected: correct={expected_correct}, conceptual={expected_conceptual}, computational={expected_computational}")
        print(f"Actual: correct={error_type_dist.get('correct', 0)}, conceptual={error_type_dist.get('conceptual_error', 0)}, computational={error_type_dist.get('computational_error', 0)}")

    return dataset_4N_df, set_A_indices, set_B_indices

# Build the 4N dataset; set A is the anchors, set B a disjoint sample
dataset_4N_df, set_A_indices, set_B_indices = create_4N_dataset_larger(
    master_df=master_df,
    anchor_indices=anchor_indices,
)

=== Creating 4N Dataset (Larger Version) ===
Set A size: 1881 (all anchor indices)
=== Selecting Disjoint Set B (target size: 1881) ===
Candidate B indices with computational errors: 4896
Valid B candidates: 4896
Selected 722 from tier3 (target: 722, available: 2046)
Selected 635 from tier1 (target: 635, available: 2036)
Selected 279 from tier2 (target: 279, available: 535)
Selected 217 from tier4 (target: 217, available: 269)
Selected 10 from tier5 (target: 28, available: 10)
Added 18 additional indices to reach target size
Final set B size: 1881

Processing Set A (1881 indices)...


Set A:   0%|          | 0/1881 [00:00<?, ?it/s]


Processing Set B (1881 indices)...


Set B:   0%|          | 0/1881 [00:00<?, ?it/s]


Total samples in 4N dataset: 7524

--- 4N Dataset Composition ---
correct: 3762
conceptual_error: 1881
computational_error: 1881

--- 4N Dataset by Source ---
gsm8k: 3762
programmatic: 2680
manual: 1082

--- 4N Dataset by Tier ---
tier1: 2562
tier2: 1120
tier3: 2898
tier4: 868
tier5: 76

--- Set Distribution ---
Set A samples: 3762
Set B samples: 3762

--- Expected vs Actual ---
Expected: correct=3762, conceptual=1881, computational=1881
Actual: correct=3762, conceptual=1881, computational=1881


In [7]:
def save_base_datasets(dataset_3N_df, dataset_4N_df, output_dir):
    """
    Save the 3N and 4N base datasets to CSV files with metadata.

    Writes four files into `<output_dir>/base-datasets-sanitized` (created if
    needed): the two dataset CSVs, a JSON metadata file and a README.md.

    Besides its arguments, the function reads these notebook-level names:
    RANDOM_SEED, successful_3N_indices, failed_3N_indices, set_A_indices and
    set_B_indices. They must exist before it is called.

    Args:
        dataset_3N_df: 3N dataset; needs the columns 'index', 'error_type',
            'source' and 'tier'.
        dataset_4N_df: 4N dataset with the same columns.
        output_dir: Parent directory (path-like supporting `/`).

    Returns:
        Path of the directory the files were written to.
    """
    print("=== Saving Base Datasets ===")

    # Create base-datasets directory
    base_datasets_dir = output_dir / "base-datasets-sanitized"
    base_datasets_dir.mkdir(parents=True, exist_ok=True)

    # Save 3N dataset
    output_3N_path = base_datasets_dir / "base_3N_dataset_sanitized.csv"
    dataset_3N_df.to_csv(output_3N_path, index=False)
    print(f"✓ 3N dataset saved to: {output_3N_path}")
    print(f"  Size: {len(dataset_3N_df):,} samples")
    print(f"  Unique indices: {dataset_3N_df['index'].nunique():,}")

    # Save 4N dataset
    output_4N_path = base_datasets_dir / "base_4N_dataset_sanitized.csv"
    dataset_4N_df.to_csv(output_4N_path, index=False)
    print(f"✓ 4N dataset saved to: {output_4N_path}")
    print(f"  Size: {len(dataset_4N_df):,} samples")
    print(f"  Unique indices: {dataset_4N_df['index'].nunique():,}")

    # One timestamp so the metadata and the README agree
    created_at = pd.Timestamp.now()

    def _counts(frame, column):
        """Value counts as a plain dict of Python ints (JSON-serialisable as numbers)."""
        return frame[column].value_counts().to_dict()

    # Create comprehensive metadata. Dataset entries are keyed by the names of
    # the files actually written above.
    metadata = {
        "creation_info": {
            "creation_date": created_at.isoformat(),
            "random_seed": RANDOM_SEED,
            "source_catalog": "master_catalog_sanitized.csv",
            "creator": "make-base-datasets.ipynb"
        },
        "datasets": {
            output_3N_path.name: {
                "description": "3N strategy: Requires complete triplets (correct + conceptual_error + computational_error) for each problem",
                "strategy": "ternary_balanced",
                "total_samples": len(dataset_3N_df),
                "unique_problems": dataset_3N_df['index'].nunique(),
                "samples_per_problem": 3,
                "composition": _counts(dataset_3N_df, 'error_type'),
                "source_distribution": _counts(dataset_3N_df, 'source'),
                "tier_distribution": _counts(dataset_3N_df, 'tier'),
                "anchor_indices_used": len(successful_3N_indices),
                "anchor_indices_failed": len(failed_3N_indices)
            },
            output_4N_path.name: {
                "description": "4N strategy: Set A (conceptual errors) + Set B (computational errors), each with correct samples",
                "strategy": "binary_balanced",
                "total_samples": len(dataset_4N_df),
                "unique_problems": dataset_4N_df['index'].nunique(),
                "samples_per_problem": 2,
                "composition": _counts(dataset_4N_df, 'error_type'),
                "source_distribution": _counts(dataset_4N_df, 'source'),
                "tier_distribution": _counts(dataset_4N_df, 'tier'),
                "set_A_size": len(set_A_indices),
                "set_B_size": len(set_B_indices),
                "sets_disjoint": len(set(set_A_indices) & set(set_B_indices)) == 0
            }
        },
        "usage_notes": {
            "for_colab": "These datasets are designed for Google Colab experiments",
            "columns": list(dataset_3N_df.columns),
            "correct_samples": "Created from correct_answer column, marked with source='gsm8k'",
            "error_samples": "From master catalog, prioritizing manual > programmatic sources",
            "formatting": "Ready for direct use in experiment pipeline notebooks"
        }
    }

    # Save metadata (default=str remains as a fallback for any non-JSON value)
    metadata_path = base_datasets_dir / "base_datasets_sanitized_metadata.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, default=str)
    print(f"✓ Metadata saved to: {metadata_path}")

    # Create a README file for the datasets
    readme_content = f"""# Base Datasets for SFT Experiments

Created on: {created_at.strftime('%Y-%m-%d %H:%M:%S')}
Random seed: {RANDOM_SEED}

## Dataset Files

### base_3N_dataset_sanitized.csv
- **Strategy**: Ternary classification (3N)
- **Samples**: {len(dataset_3N_df):,} total ({dataset_3N_df['index'].nunique():,} unique problems × 3 samples each)
- **Structure**: Each problem has exactly 3 samples: correct, conceptual_error, computational_error
- **Use case**: Ternary classification experiments (correct vs conceptual_error vs computational_error)

### base_4N_dataset_sanitized.csv
- **Strategy**: Binary classification (4N)
- **Samples**: {len(dataset_4N_df):,} total ({dataset_4N_df['index'].nunique():,} unique problems × 2 samples each)
- **Structure**: 
  - Set A ({len(set_A_indices):,} problems): correct + conceptual_error
  - Set B ({len(set_B_indices):,} problems): correct + computational_error
- **Use case**: Binary classification experiments (correct vs flawed)

## Column Descriptions

- `index`: GSM8K problem index
- `tier`: Problem difficulty tier (tier1-tier5)
- `question`: Problem statement
- `correct_answer`: Correct solution from GSM8K
- `wrong_answer`: Solution text (correct for 'correct' samples, flawed for error samples)
- `error_type`: 'correct', 'conceptual_error', or 'computational_error'
- `erroneous_line_number`: Line number where error occurs (null for correct samples)
- `explanation`: Error explanation (null for correct samples)
- `error_subtype`: Specific error subtype (null for correct samples)
- `source`: 'gsm8k' (correct samples), 'manual', or 'programmatic'
- `solution_length`: Number of solution lines
- `relative_line_position`: Position of error relative to solution length

## Usage in Colab

```python
import pandas as pd

# Load datasets
df_3N = pd.read_csv('base_3N_dataset_sanitized.csv')
df_4N = pd.read_csv('base_4N_dataset_sanitized.csv')

# Verify structure
print("3N dataset composition:")
print(df_3N['error_type'].value_counts())

print("4N dataset composition:")  
print(df_4N['error_type'].value_counts())
```
"""

    # The README contains non-ASCII characters, so pin the encoding explicitly
    readme_path = base_datasets_dir / "README.md"
    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)
    print(f"✓ README saved to: {readme_path}")

    return base_datasets_dir

# Write both datasets, the metadata JSON and the README to disk
# NOTE(review): OUTPUT_DIR is created in the setup cell, yet DATA_DIR is what
# gets passed here - confirm which parent directory is intended.
base_datasets_dir = save_base_datasets(
    dataset_3N_df=dataset_3N_df,
    dataset_4N_df=dataset_4N_df,
    output_dir=DATA_DIR,
)
print(f"\n📁 All files saved to: {base_datasets_dir}")

=== Saving Base Datasets ===
✓ 3N dataset saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/base-datasets-sanitized/base_3N_dataset_sanitized.csv
  Size: 5,319 samples
  Unique indices: 1,773
✓ 4N dataset saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/base-datasets-sanitized/base_4N_dataset_sanitized.csv
  Size: 7,524 samples
  Unique indices: 3,762
✓ Metadata saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/base-datasets-sanitized/base_datasets_sanitized_metadata.json
✓ README saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/base-datasets-sanitized/README.md

📁 All files saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/base-datasets-sanitized
