In [1]:
import json
import random
import re
from pathlib import Path
from typing import List, Dict, Set, Tuple

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

# --- Path and Directory Definitions ---
def find_project_root(marker: str = ".git") -> Path:
    """Traverse upwards to find the project root, marked by the git repository."""
    current_path = Path.cwd().resolve()
    while current_path != current_path.parent:
        if (current_path / marker).exists():
            return current_path
        current_path = current_path.parent
    raise FileNotFoundError(f"Could not find project root. Marker '{marker}' not found.")

# --- Global Constants and Paths ---
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / 'data'
OUTPUT_DIR = DATA_DIR / "line-classification"

# --- Ensure output directory exists ---
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Configuration ---
RANDOM_SEED = 42
LINE_SEP_TOKEN = "<|LINE_SEP|>"  # Special token for line separation

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print(f"Project root: {PROJECT_ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Random seed set to: {RANDOM_SEED}")
print(f"Line separator token: {LINE_SEP_TOKEN}")

Project root: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math
Data directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data
Output directory: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification
Random seed set to: 42
Line separator token: <|LINE_SEP|>


In [2]:
# Load the master catalog
master_catalog_path = DATA_DIR / "master_catalog_sanitized.csv"
if not master_catalog_path.exists():
    raise FileNotFoundError(f"Master catalog not found: {master_catalog_path}")

master_df = pd.read_csv(master_catalog_path)
print(f"Loaded master catalog with {len(master_df):,} records")
print(f"Columns: {list(master_df.columns)}")

print("\n=== Master Catalog Overview ===")
print(f"Total samples: {len(master_df):,}")
print(f"Unique indices: {master_df['index'].nunique():,}")

print("\n--- Error Type Distribution ---")
error_type_counts = master_df['error_type'].value_counts()
print(error_type_counts)

print("\n--- Source Distribution ---") 
source_counts = master_df['source'].value_counts()
print(source_counts)

Loaded master catalog with 24,652 records
Columns: ['index', 'tier', 'question', 'correct_answer', 'wrong_answer', 'error_type', 'erroneous_line_number', 'explanation', 'error_subtype', 'source', 'solution_length', 'relative_line_position']

=== Master Catalog Overview ===
Total samples: 24,652
Unique indices: 6,777

--- Error Type Distribution ---
error_type
computational_error    22542
conceptual_error        2110
Name: count, dtype: int64

--- Source Distribution ---
source
programmatic    22912
manual           1740
Name: count, dtype: int64


In [3]:
def select_balanced_computational_indices_simple(master_df, conceptual_indices, target_count):
    """
    Simplified computational error sampling:
    1. Form 3 groups based on relative line position (early, middle, late)
    2. From each group, take as many distinct tier 4 samples as possible
    3. For remaining, randomly sample distinct indices from other tiers
    """
    
    print("=== Simple Computational Error Sampling ===")
    
    # Get available computational errors (excluding conceptual indices)
    computational_df = master_df[
        (master_df['error_type'] == 'computational_error') & 
        (~master_df['index'].isin(conceptual_indices))
    ].copy()
    
    if len(computational_df) == 0:
        print("Warning: No computational errors available after excluding conceptual indices")
        return []
    
    print(f"Available computational error samples: {len(computational_df)}")
    print(f"Available unique indices: {computational_df['index'].nunique()}")
    
    # Step 1: Form 3 groups based on relative line position
    # Simple tertile split
    positions = computational_df['relative_line_position'].dropna()
    tertile_33 = positions.quantile(0.33)
    tertile_67 = positions.quantile(0.67)
    
    print(f"Position tertiles: 33%={tertile_33:.3f}, 67%={tertile_67:.3f}")
    
    # Create groups
    early_group = computational_df[computational_df['relative_line_position'] <= tertile_33]
    middle_group = computational_df[
        (computational_df['relative_line_position'] > tertile_33) & 
        (computational_df['relative_line_position'] <= tertile_67)
    ]
    late_group = computational_df[computational_df['relative_line_position'] > tertile_67]
    
    print(f"Early group: {len(early_group)} samples, {early_group['index'].nunique()} unique indices")
    print(f"Middle group: {len(middle_group)} samples, {middle_group['index'].nunique()} unique indices")
    print(f"Late group: {len(late_group)} samples, {late_group['index'].nunique()} unique indices")
    
    selected_indices = []
    
    # Step 2: From each group, take as many distinct tier 4 samples as possible
    for group_name, group_df in [("Early", early_group), ("Middle", middle_group), ("Late", late_group)]:
        tier4_indices = group_df[group_df['tier'] == 'tier4']['index'].unique()
        print(f"{group_name} group tier 4 unique indices: {len(tier4_indices)}")
        selected_indices.extend(tier4_indices.tolist())
    
    print(f"Total tier 4 indices selected: {len(selected_indices)}")
    
    # Step 3: For remaining, randomly sample distinct indices from other tiers
    remaining_needed = target_count - len(selected_indices)
    
    if remaining_needed > 0:
        print(f"Need {remaining_needed} more indices from other tiers")
        
        # Get all non-tier4 indices that aren't already selected
        non_tier4_df = computational_df[computational_df['tier'] != 'tier4']
        available_indices = non_tier4_df['index'].unique()
        available_indices = [idx for idx in available_indices if idx not in selected_indices]
        
        print(f"Available non-tier4 indices: {len(available_indices)}")
        
        if len(available_indices) >= remaining_needed:
            additional_indices = random.sample(available_indices, remaining_needed)
            selected_indices.extend(additional_indices)
        else:
            print(f"Warning: Only {len(available_indices)} available, adding all")
            selected_indices.extend(available_indices)
    
    print(f"Final selection: {len(selected_indices)} unique indices")
    
    # Verify uniqueness
    if len(selected_indices) != len(set(selected_indices)):
        print(f"ERROR: Duplicate indices detected!")
        selected_indices = list(set(selected_indices))
        print(f"After deduplication: {len(selected_indices)} unique indices")
    
    # Show final distribution
    selected_df = computational_df[computational_df['index'].isin(selected_indices)]
    selected_unique = selected_df.drop_duplicates('index')
    
    print(f"\n--- Final Distribution ---")
    tier_counts = selected_unique['tier'].value_counts().sort_index()
    print("Tier distribution:")
    for tier in ['tier1', 'tier2', 'tier3', 'tier4']:
        count = tier_counts.get(tier, 0)
        pct = (count / len(selected_unique)) * 100
        print(f"  {tier}: {count} indices ({pct:.1f}%)")
    
    # Position distribution of selected indices
    print("\nPosition distribution:")
    for group_name, (min_pos, max_pos) in [
        ("Early", (0, tertile_33)),
        ("Middle", (tertile_33, tertile_67)), 
        ("Late", (tertile_67, 1.0))
    ]:
        group_selected = selected_unique[
            (selected_unique['relative_line_position'] >= min_pos) & 
            (selected_unique['relative_line_position'] <= max_pos)
        ]
        count = len(group_selected)
        pct = (count / len(selected_unique)) * 100
        print(f"  {group_name}: {count} indices ({pct:.1f}%)")
    
    return selected_indices

def get_disjoint_sets_simple(master_df):
    """
    Create three disjoint sets with simplified computational sampling.
    """
    
    print("=== Creating Disjoint Sets with Simple Sampling ===")
    
    # Set A: All indices with conceptual errors
    conceptual_indices = master_df[
        master_df['error_type'] == 'conceptual_error'
    ]['index'].unique()
    conceptual_indices = sorted(conceptual_indices)
    N = len(conceptual_indices)
    
    print(f"Set A (Conceptual errors): {N} unique indices")
    
    # Set B: Simple computational error sampling
    computational_indices = select_balanced_computational_indices_simple(
        master_df, conceptual_indices, N
    )
    
    print(f"Set B (Computational errors): {len(computational_indices)} unique indices")
    
    # Set C: Correct samples from remaining disjoint indices
    used_indices = set(conceptual_indices) | set(computational_indices)
    all_indices = set(master_df['index'].unique())
    remaining_indices = list(all_indices - used_indices)
    
    print(f"Available indices for correct samples: {len(remaining_indices)}")
    
    if len(remaining_indices) < len(computational_indices):
        print(f"WARNING: Only {len(remaining_indices)} remaining indices available, need {len(computational_indices)}")
        correct_indices = remaining_indices
    else:
        correct_indices = random.sample(remaining_indices, len(computational_indices))
    
    print(f"Set C (Correct samples): {len(correct_indices)} unique indices")
    
    # Final verification that all sets are truly disjoint
    set_A = set(conceptual_indices)
    set_B = set(computational_indices)
    set_C = set(correct_indices)
    
    overlap_AB = len(set_A & set_B)
    overlap_AC = len(set_A & set_C)
    overlap_BC = len(set_B & set_C)
    
    print(f"\n--- Disjoint Verification ---")
    print(f"Conceptual ∩ Computational: {overlap_AB} indices (should be 0)")
    print(f"Conceptual ∩ Correct: {overlap_AC} indices (should be 0)")
    print(f"Computational ∩ Correct: {overlap_BC} indices (should be 0)")
    
    total_unique_indices = len(set_A | set_B | set_C)
    expected_total = len(conceptual_indices) + len(computational_indices) + len(correct_indices)
    print(f"Total unique indices: {total_unique_indices} (should equal {expected_total})")
    
    if total_unique_indices != expected_total:
        print("ERROR: Sets are not properly disjoint!")
    else:
        print("✓ All sets are properly disjoint")
    
    return conceptual_indices, computational_indices, correct_indices

conceptual_indices, computational_indices, correct_indices = get_disjoint_sets_simple(master_df)

=== Creating Disjoint Sets with Simple Sampling ===
Set A (Conceptual errors): 1881 unique indices
=== Simple Computational Error Sampling ===
Available computational error samples: 16287
Available unique indices: 4896
Position tertiles: 33%=0.250, 67%=0.600
Early group: 5599 samples, 3676 unique indices
Middle group: 5680 samples, 3654 unique indices
Late group: 5008 samples, 2866 unique indices
Early group tier 4 unique indices: 181
Middle group tier 4 unique indices: 183
Late group tier 4 unique indices: 180
Total tier 4 indices selected: 544
Need 1337 more indices from other tiers
Available non-tier4 indices: 4627
Final selection: 1881 unique indices
ERROR: Duplicate indices detected!
After deduplication: 1606 unique indices

--- Final Distribution ---
Tier distribution:
  tier1: 581 indices (36.2%)
  tier2: 164 indices (10.2%)
  tier3: 590 indices (36.7%)
  tier4: 269 indices (16.7%)

Position distribution:
  Early: 1156 indices (72.0%)
  Middle: 387 indices (24.1%)
  Late: 157 in

In [4]:
def add_line_separator_tokens(solution_text: str, line_sep_token: str = LINE_SEP_TOKEN) -> str:
    """
    Since preprocessing is already done in master catalog, just add line separator tokens.
    
    Args:
        solution_text: Already preprocessed solution text from sanitized master catalog
        line_sep_token: Special token to use for line separation
        
    Returns:
        Solution text with line separator tokens for tokenization
    """
    if not isinstance(solution_text, str):
        return ""
    
    # Split by newlines, filter empty lines, and join with special token
    lines = solution_text.split('\n')
    non_empty_lines = [line.strip() for line in lines if line.strip()]  # Filter empty lines
    
    processed = line_sep_token.join(non_empty_lines) + line_sep_token
    return processed

def create_user_prompt_simple(question: str, solution: str) -> str:
    """
    Creates user prompt from already sanitized question and solution.
    
    Args:
        question: Already sanitized GSM8K problem statement
        solution: Already preprocessed solution text with line separators
        
    Returns:
        Formatted user prompt for line-level error detection
    """
    system_instruction = "You are an expert in mathematical problem solving. Analyze the following mathematical problem and solution to identify whether there are any errors. At the end of every line, ask yourself- 'Is the solution correct so far? In particular, have there been any computational errors, or misinterpretations of the problem, or any flaws in the logic?' You must carefully examine it line-by-line because it is important to detect the VERY FIRST LINE WITH AN ERROR."
    
    user_prompt = f"""{system_instruction}

### Problem:
{question}

### Solution:
{solution}"""
    
    return user_prompt

In [5]:
def create_progressive_line_labels(solution_with_tokens, relative_line_position, solution_length, sample_index=None, is_correct=False):
    """
    Create PROGRESSIVE line labels for both error and correct samples.
    
    This answers the question: "Is the solution correct up to this point?"
    - 0 = Solution is correct up to and including this line
    - 1 = Solution becomes incorrect starting from this line (and remains incorrect)
    
    Args:
        solution_with_tokens (str): Solution text with line separator tokens added
        relative_line_position (float or None): Relative position of error (0.0-1.0), 
                                              or None for correct solutions
        solution_length (int): Pre-computed number of lines (from sanitized catalog)
        sample_index (int, optional): Sample index for debugging mismatch reports
        is_correct (bool): Whether this is a correct solution (all labels should be 0)
    
    Returns:
        list: Progressive labels [0, 0, 1, 1, 1, ...] where 1s start from first error
              For correct solutions, all labels are 0.
    """
    
    if pd.isna(solution_with_tokens) or solution_with_tokens is None or solution_with_tokens == "":
        return []
    
    # Use the pre-computed solution_length from sanitized catalog
    num_lines = int(solution_length)
    
    # Verify actual line count matches stored length (for debugging)
    if LINE_SEP_TOKEN in solution_with_tokens:
        actual_line_count = solution_with_tokens.count(LINE_SEP_TOKEN)
        if actual_line_count != num_lines:
            index_info = f" (Index: {sample_index})" if sample_index is not None else ""
            print(f"Warning: Line count mismatch{index_info}! Stored: {num_lines}, Actual: {actual_line_count}")
            # Use actual count to prevent index errors
            num_lines = actual_line_count
    
    # Handle edge case of empty solutions
    if num_lines <= 0:
        return []
    
    # Initialize all labels as 0 (correct)
    line_labels = [0] * num_lines
    
    # For correct samples, all labels stay 0
    if is_correct:
        return line_labels
    
    # For error samples, set error line and ALL SUBSEQUENT LINES to 1
    if relative_line_position is not None and not pd.isna(relative_line_position):
        # Convert relative position (0.0-1.0) to absolute line number (0-based)
        if num_lines == 1:
            error_line_number = 0  # Only one line, must be line 0
        else:
            error_line_number = int(relative_line_position * (num_lines - 1))
        
        # Ensure error line is within valid bounds
        error_line_number = max(0, min(error_line_number, num_lines - 1))
        
        # Set error line and ALL SUBSEQUENT LINES to 1
        for i in range(error_line_number, num_lines):
            line_labels[i] = 1
    
    return line_labels


def select_best_sample(idx_samples):
    """
    Select the best sample from a group of samples for the same index.
    Priority: manual > programmatic
    """
    # Prioritize manual over programmatic
    manual_samples = idx_samples[idx_samples['source'] == 'manual']
    if len(manual_samples) > 0:
        return manual_samples.iloc[0]  # Take first manual sample
    else:
        return idx_samples.iloc[0]  # Take first programmatic sample

In [6]:
def create_3n_progressive_line_classification_dataset(master_df, conceptual_indices, computational_indices, correct_indices):
    """
    Create 3N dataset with PROGRESSIVE labeling strategy (no tokenization).
    Labels answer: "Is the solution correct up to this point?"
    
    Args:
        master_df: Master catalog dataframe
        conceptual_indices: List of indices for conceptual errors
        computational_indices: List of indices for computational errors  
        correct_indices: List of indices for correct solutions
    """
    
    print("=== Creating 3N Progressive Line Classification Dataset ===")
    print("Strategy: 0s for correct lines, 1s from first error onwards")
    print(f"Dataset composition: {len(conceptual_indices)} conceptual + {len(computational_indices)} computational + {len(correct_indices)} correct")
    
    dataset = []
    
    # Process conceptual error samples
    print(f"Processing {len(conceptual_indices)} conceptual error samples...")
    for idx in tqdm(conceptual_indices, desc="Conceptual errors"):
        idx_samples = master_df[
            (master_df['index'] == idx) & 
            (master_df['error_type'] == 'conceptual_error')
        ]
        
        if len(idx_samples) == 0:
            continue
            
        selected_sample = select_best_sample(idx_samples)
        
        # Use preprocessed text from master catalog
        preprocessed_solution = selected_sample['wrong_answer']
        solution_with_tokens = add_line_separator_tokens(preprocessed_solution)
        preprocessed_question = selected_sample['question']
        user_prompt = create_user_prompt_simple(preprocessed_question, solution_with_tokens)
        
        # Create PROGRESSIVE labels for error sample
        line_labels = create_progressive_line_labels(
            solution_with_tokens, 
            selected_sample['relative_line_position'],
            selected_sample['solution_length'],
            sample_index=idx,
            is_correct=False
        )
        
        # Verify we have a valid progressive sequence
        if len(line_labels) > 0 and any(label == 1 for label in line_labels):
            first_error_idx = line_labels.index(1)
            labels_after_error = line_labels[first_error_idx:]
            
            if all(label == 1 for label in labels_after_error):
                sample = {
                    'text': user_prompt,
                    'correct_answer': selected_sample['correct_answer'],
                    'line_labels': line_labels,  # Progressive labels [0,0,1,1,1...]
                    'error_type': 'conceptual_error',
                    'index': idx,
                    'tier': selected_sample['tier'],
                    'source': selected_sample['source'],
                    'relative_line_position': selected_sample['relative_line_position'],
                    'solution_length': selected_sample['solution_length'],
                    'first_error_line': first_error_idx
                }
                dataset.append(sample)
    
    # Process computational error samples
    print(f"Processing {len(computational_indices)} computational error samples...")
    for idx in tqdm(computational_indices, desc="Computational errors"):
        idx_samples = master_df[
            (master_df['index'] == idx) & 
            (master_df['error_type'] == 'computational_error')
        ]
        
        if len(idx_samples) == 0:
            continue
            
        selected_sample = select_best_sample(idx_samples)
        
        preprocessed_solution = selected_sample['wrong_answer']
        solution_with_tokens = add_line_separator_tokens(preprocessed_solution)
        preprocessed_question = selected_sample['question']
        user_prompt = create_user_prompt_simple(preprocessed_question, solution_with_tokens)
        
        line_labels = create_progressive_line_labels(
            solution_with_tokens, 
            selected_sample['relative_line_position'],
            selected_sample['solution_length'],
            sample_index=idx,
            is_correct=False
        )
        
        if len(line_labels) > 0 and any(label == 1 for label in line_labels):
            first_error_idx = line_labels.index(1)
            labels_after_error = line_labels[first_error_idx:]
            
            if all(label == 1 for label in labels_after_error):
                sample = {
                    'text': user_prompt,
                    'correct_answer': selected_sample['correct_answer'],
                    'line_labels': line_labels,  # Progressive labels
                    'error_type': 'computational_error',
                    'index': idx,
                    'tier': selected_sample['tier'],
                    'source': selected_sample['source'],
                    'relative_line_position': selected_sample['relative_line_position'],
                    'solution_length': selected_sample['solution_length'],
                    'first_error_line': first_error_idx
                }
                dataset.append(sample)
    
    # Process correct samples
    print(f"Processing {len(correct_indices)} correct solution samples...")
    for idx in tqdm(correct_indices, desc="Correct solutions"):
        idx_samples = master_df[master_df['index'] == idx]
        
        if len(idx_samples) == 0:
            print(f"Warning: No data found for correct index {idx}")
            continue
        
        reference_sample = idx_samples.iloc[0]
        preprocessed_question = reference_sample['question']
        correct_solution = reference_sample['correct_answer']
        
        solution_with_tokens = add_line_separator_tokens(correct_solution)
        user_prompt = create_user_prompt_simple(preprocessed_question, solution_with_tokens)
        
        # Calculate solution length for correct answer
        lines = correct_solution.split('\n')
        non_empty_lines = [line.strip() for line in lines if line.strip()]
        solution_length = len(non_empty_lines)
        
        # Create all-zero labels for correct solution
        line_labels = create_progressive_line_labels(
            solution_with_tokens,
            relative_line_position=None,
            solution_length=solution_length,
            sample_index=idx,
            is_correct=True
        )
        
        if len(line_labels) > 0:
            sample = {
                'text': user_prompt,
                'correct_answer': correct_solution,
                'line_labels': line_labels,  # All zeros [0,0,0,0...]
                'error_type': 'correct',
                'index': idx,
                'tier': reference_sample['tier'],
                'source': 'original_gsm8k',
                'relative_line_position': None,
                'solution_length': solution_length,
                'first_error_line': -1  # No error
            }
            dataset.append(sample)
    
    dataset_df = pd.DataFrame(dataset)
    
    print(f"\nTotal samples created: {len(dataset_df)}")
    
    # Validation
    print("\n--- Progressive Label Validation ---")
    valid_progressive = 0
    for _, row in dataset_df.iterrows():
        line_labels = row['line_labels']
        if row['error_type'] == 'correct':
            # Should be all zeros
            if all(label == 0 for label in line_labels):
                valid_progressive += 1
        else:
            # Should have progressive pattern [0,0,1,1,1...]
            if any(label == 1 for label in line_labels):
                first_error = line_labels.index(1)
                if all(label == 1 for label in line_labels[first_error:]):
                    valid_progressive += 1
    
    print(f"Valid progressive patterns: {valid_progressive}/{len(dataset_df)} ({valid_progressive/len(dataset_df)*100:.1f}%)")
    
    return dataset_df

In [7]:
def save_3n_progressive_dataset(dataset_df, output_dir):
    """
    Save the 3N progressive line classification dataset.
    """
    
    print("=== Saving 3N Progressive Line Classification Dataset ===")
    
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Save main dataset
    dataset_path = output_dir / "3n_progressive_line_classification_dataset_2.csv"
    dataset_df.to_csv(dataset_path, index=False)
    print(f"✓ Dataset saved to: {dataset_path}")
    print(f"  Size: {len(dataset_df):,} samples")
    print(f"  Unique indices: {dataset_df['index'].nunique():,}")
    
    # Calculate class balance for metadata
    total_line_labels = sum(len(labels) for labels in dataset_df['line_labels'])
    total_error_lines = sum(sum(labels) for labels in dataset_df['line_labels'])
    total_correct_lines = total_line_labels - total_error_lines
    imbalance_ratio = total_correct_lines / total_error_lines if total_error_lines > 0 else float('inf')
    
    # Create metadata
    metadata = {
        "creation_info": {
            "creation_date": pd.Timestamp.now().isoformat(),
            "random_seed": RANDOM_SEED,
            "source_catalog": "master_catalog_sanitized.csv",
            "creator": "make-line-classification-dataset.ipynb",
            "dataset_strategy": "3n_progressive_labeling"
        },
        "labeling_strategy": {
            "approach": "progressive",
            "description": "Labels answer: 'Is the solution correct up to this point?'",
            "correct_samples": "All labels are 0 (solution is correct throughout)",
            "error_samples": "Progressive pattern [0,0,1,1,1...] starting from first error",
            "advantages": ["Better class balance", "More training signal", "Conceptually aligned"]
        },
        "dataset_info": {
            "description": "3N progressive line-level error detection dataset_2",
            "strategy": "3n_progressive_balanced", 
            "total_samples": len(dataset_df),
            "unique_problems": dataset_df['index'].nunique(),
            "composition": dict(dataset_df['error_type'].value_counts()),
            "source_distribution": dict(dataset_df['source'].value_counts()),
            "tier_distribution": dict(dataset_df['tier'].value_counts()),
            "class_balance": {
                "total_line_positions": total_line_labels,
                "error_line_positions": total_error_lines,
                "correct_line_positions": total_correct_lines,
                "imbalance_ratio": f"{imbalance_ratio:.1f}:1",
                "error_percentage": f"{total_error_lines/total_line_labels*100:.1f}%"
            }
        },
        "column_descriptions": {
            "text": "Complete user prompt: system instruction + problem + solution",
            "correct_answer": "Original GSM8K correct answer for reference",
            "line_labels": "Progressive labels [0,0,1,1,1...] indicating correctness up to each line",
            "error_type": "'conceptual_error', 'computational_error', or 'correct'",
            "index": "Original GSM8K problem index",
            "tier": "Problem difficulty tier (tier1-tier4)",
            "source": "'manual', 'programmatic', or 'original_gsm8k'"
        }
    }
    
    metadata_path = output_dir / "3n_progressive_dataset_metadata_2.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2, default=str)
    print(f"✓ Metadata saved to: {metadata_path}")
    
    return output_dir

In [8]:
# Execute the 3N progressive strategy
print("🚀 Creating 3N progressive line classification dataset...")
progressive_3n_dataset_df = create_3n_progressive_line_classification_dataset(
    master_df, conceptual_indices, computational_indices, correct_indices
)

# Save the 3N progressive dataset
saved_3n_dir = save_3n_progressive_dataset(progressive_3n_dataset_df, OUTPUT_DIR)
print(f"\n📁 3N progressive dataset saved to: {saved_3n_dir}")

# Quick validation
print(f"\n🔍 QUICK VALIDATION:")
print(f"Dataset length: {len(progressive_3n_dataset_df)}")
print(f"Error type distribution:")
print(progressive_3n_dataset_df['error_type'].value_counts())

# Show a few examples
print(f"\n📋 SAMPLE EXAMPLES:")
for error_type in ['correct', 'conceptual_error', 'computational_error']:
    sample = progressive_3n_dataset_df[progressive_3n_dataset_df['error_type'] == error_type].iloc[0]
    line_labels = sample['line_labels']
    first_one = line_labels.index(1) if 1 in line_labels else -1
    print(f"  {error_type}: {len(line_labels)} lines, first error at {first_one}, labels: {line_labels}")

🚀 Creating 3N progressive line classification dataset...
=== Creating 3N Progressive Line Classification Dataset ===
Strategy: 0s for correct lines, 1s from first error onwards
Dataset composition: 1881 conceptual + 1606 computational + 1606 correct
Processing 1881 conceptual error samples...


Conceptual errors:   0%|          | 0/1881 [00:00<?, ?it/s]

Processing 1606 computational error samples...


Computational errors:   0%|          | 0/1606 [00:00<?, ?it/s]

Processing 1606 correct solution samples...


Correct solutions:   0%|          | 0/1606 [00:00<?, ?it/s]


Total samples created: 5093

--- Progressive Label Validation ---
Valid progressive patterns: 5093/5093 (100.0%)
=== Saving 3N Progressive Line Classification Dataset ===
✓ Dataset saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification/3n_progressive_line_classification_dataset_2.csv
  Size: 5,093 samples
  Unique indices: 5,093
✓ Metadata saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification/3n_progressive_dataset_metadata_2.json

📁 3N progressive dataset saved to: /Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification

🔍 QUICK VALIDATION:
Dataset length: 5093
Error type distribution:
error_type
conceptual_error       1881
computational_error    1606
correct                1606
Name: count, dtype: int64

📋 SAMPLE EXAMPLES:
  correct: 5 lines, first error at -1, labels: [0, 0, 0, 0, 0]
  conceptual_error: 3 lines, first error at 1, labels: [0, 1, 1]
  computational_error: 4 lines, 

In [9]:
# ==============================================================================
# DEBUGGING: Line Count Mismatch Investigation
# ==============================================================================
def debug_line_count_mismatch(master_df, sample_indices=None, max_samples=20):
    """
    Debug line count mismatches between stored solution_length and actual preprocessed lines.
    """
    print("🔍 DEBUGGING LINE COUNT MISMATCH")
    print("=" * 70)
    
    if sample_indices is None:
        # Use the problematic indices we found earlier
        sample_indices = [9, 18, 25, 29, 30]  # From the debugging output
    
    problematic_cases = []
    
    for idx in sample_indices[:max_samples]:
        print(f"\n{'='*60}")
        print(f"DEBUGGING INDEX {idx}")
        print(f"{'='*60}")
        
        # Get the sample from master catalog
        sample_data = master_df[master_df['index'] == idx]
        if len(sample_data) == 0:
            print(f"No data found for index {idx}")
            continue
        
        # Take the first available sample (conceptual or computational)
        sample = sample_data.iloc[0]
        
        print(f"📍 Error type: {sample['error_type']}")
        print(f"📍 Source: {sample['source']}")
        print(f"📍 Tier: {sample['tier']}")
        print(f"📍 Stored solution_length: {sample['solution_length']}")
        print(f"📍 Stored relative_line_position: {sample['relative_line_position']}")
        
        # Get the raw solution
        raw_solution = sample['wrong_answer']
        print(f"\n--- RAW SOLUTION ---")
        print(repr(raw_solution))
        
        # Process the solution step by step
        print(f"\n--- PREPROCESSING STEPS ---")
        
        # Step 1: Convert literal \n to actual newlines
        step1 = raw_solution.replace('\\n', '\n')
        print(f"Step 1 - After \\n conversion:")
        print(repr(step1))
        
        # Step 2: Remove calculator annotations
        import re
        step2 = re.sub(r'<<.*?>>', '', step1)
        print(f"Step 2 - After calculator annotation removal:")
        print(repr(step2))
        
        # Step 3: Sanitize Unicode and commas (simplified)
        step3 = step2  # Skip for now to focus on line counting
        
        # Step 4: Split by newlines and count
        lines_before_filter = step3.split('\n')
        print(f"Step 4 - Lines before filtering:")
        for i, line in enumerate(lines_before_filter):
            print(f"  Line {i}: {repr(line)}")
        
        # Step 5: Filter empty lines
        non_empty_lines = [line.strip() for line in lines_before_filter if line.strip()]
        print(f"Step 5 - Non-empty lines after filtering:")
        for i, line in enumerate(non_empty_lines):
            print(f"  Line {i}: {repr(line)}")
        
        # Calculate actual vs expected
        actual_line_count = len(non_empty_lines)
        stored_line_count = int(sample['solution_length'])
        
        print(f"\n--- LINE COUNT COMPARISON ---")
        print(f"📊 Stored solution_length: {stored_line_count}")
        print(f"📊 Actual preprocessed lines: {actual_line_count}")
        print(f"📊 Difference: {actual_line_count - stored_line_count}")
        
        # Show what the original line counting logic would have done
        print(f"\n--- ORIGINAL LINE COUNTING SIMULATION ---")
        # The original logic probably counted lines in the raw text
        original_lines = raw_solution.replace('\\n', '\n').split('\n')
        original_non_empty = [line.strip() for line in original_lines if line.strip()]
        print(f"Original raw line count: {len(original_non_empty)}")
        
        # Check if calculator annotations affected line count
        with_calc_lines = step1.split('\n')
        without_calc_lines = step2.split('\n')
        calc_lines_before = len([line.strip() for line in with_calc_lines if line.strip()])
        calc_lines_after = len([line.strip() for line in without_calc_lines if line.strip()])
        
        print(f"Lines before calculator removal: {calc_lines_before}")
        print(f"Lines after calculator removal: {calc_lines_after}")
        print(f"Calculator annotations removed {calc_lines_before - calc_lines_after} lines")
        
        # Calculate what the error position should be
        if sample['relative_line_position'] is not None:
            # Original calculation (wrong)
            original_error_line = int(sample['relative_line_position'] * (stored_line_count - 1))
            
            # Corrected calculation
            corrected_error_line = int(sample['relative_line_position'] * (actual_line_count - 1))
            
            print(f"\n--- ERROR POSITION COMPARISON ---")
            print(f"Original error line (using stored length): {original_error_line}")
            print(f"Corrected error line (using actual length): {corrected_error_line}")
            print(f"Valid range for actual lines: 0 to {actual_line_count - 1}")
            
            # Check if original position is out of bounds
            if original_error_line >= actual_line_count:
                print(f"❌ Original position {original_error_line} is OUT OF BOUNDS for {actual_line_count} lines!")
            else:
                print(f"✅ Original position {original_error_line} is within bounds")
        
        problematic_cases.append({
            'index': idx,
            'stored_length': stored_line_count,
            'actual_length': actual_line_count,
            'difference': actual_line_count - stored_line_count,
            'error_type': sample['error_type'],
            'relative_position': sample['relative_line_position']
        })
    
    # Summary analysis
    print(f"\n{'='*70}")
    print("SUMMARY ANALYSIS")
    print(f"{'='*70}")
    
    differences = [case['difference'] for case in problematic_cases]
    
    print(f"📊 Analyzed {len(problematic_cases)} problematic cases")
    print(f"📊 Line count differences: {differences}")
    print(f"📊 Average difference: {np.mean(differences):.2f}")
    print(f"📊 Most common difference: {max(set(differences), key=differences.count)}")
    
    # Show patterns by error type
    conceptual_diffs = [case['difference'] for case in problematic_cases if case['error_type'] == 'conceptual_error']
    computational_diffs = [case['difference'] for case in problematic_cases if case['error_type'] == 'computational_error']
    
    print(f"\n--- PATTERNS BY ERROR TYPE ---")
    if conceptual_diffs:
        print(f"Conceptual errors: avg difference = {np.mean(conceptual_diffs):.2f}")
    if computational_diffs:
        print(f"Computational errors: avg difference = {np.mean(computational_diffs):.2f}")
    
    return problematic_cases

# Run the debugging
debug_results = debug_line_count_mismatch(master_df)

🔍 DEBUGGING LINE COUNT MISMATCH

DEBUGGING INDEX 9
📍 Error type: computational_error
📍 Source: manual
📍 Tier: tier2
📍 Stored solution_length: 9
📍 Stored relative_line_position: 0.75

--- RAW SOLUTION ---
'She works 8 hours a day for $18 per hour so she makes 8*18 = $144.00 per 8-hour shift\nShe works 10 hours a day and anything over 8 hours is eligible for overtime, so she gets 10-8 = 2 hours of overtime\nOvertime is calculated as time and a half so and she makes $18/hour so her overtime pay is 18*.5 = $9.00\nHer overtime pay is 18+9 = $27.00\nHer base pay is $144.00 per 8-hour shift and she works 5 days and makes 5 * $144 = $720.00\nHer overtime pay is $27.00 per hour and she works 2 hours of overtime per day and makes 27*2 = $54.00 in overtime pay\n2 hours of overtime pay for 5 days means she makes 54*5 = $250.00\nIn 5 days her base pay is $720.00 and she makes $250.00 in overtime pay so she makes $720 + $250 = $970.00\n#### 970'

--- PREPROCESSING STEPS ---
Step 1 - After \n convers

In [10]:
import pandas as pd
import ast

# Load both datasets
flawed_only_df = pd.read_csv('/Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification/flawed-only/flawed_only_line_classification_dataset.csv')
progressive_3n_df = pd.read_csv('/Users/arvindsuresh/Documents/Github/Erdos-DL-June25-Math/data/line-classification/3n_progressive_line_classification_dataset_2.csv')

print("=== Dataset Comparison ===")
print(f"Flawed-only dataset: {len(flawed_only_df)} samples")
print(f"3N progressive dataset: {len(progressive_3n_df)} samples")

# Filter out correct samples from 3N dataset for comparison
progressive_errors_df = progressive_3n_df[progressive_3n_df['error_type'] != 'correct'].copy()
print(f"3N progressive (errors only): {len(progressive_errors_df)} samples")

# Parse line_labels from string representation
def parse_labels(label_str):
    return ast.literal_eval(label_str) if isinstance(label_str, str) else label_str

flawed_only_df['parsed_labels'] = flawed_only_df['line_labels'].apply(parse_labels)
progressive_errors_df['parsed_labels'] = progressive_errors_df['line_labels'].apply(parse_labels)

# Match rows by index and error_type
matches = []
mismatches = []

for _, flawed_row in flawed_only_df.iterrows():
    # Find matching row in progressive dataset
    matching_rows = progressive_errors_df[
        (progressive_errors_df['index'] == flawed_row['index']) & 
        (progressive_errors_df['error_type'] == flawed_row['error_type'])
    ]
    
    if len(matching_rows) == 0:
        mismatches.append({
            'type': 'no_match',
            'index': flawed_row['index'],
            'error_type': flawed_row['error_type'],
            'issue': 'No matching row found in progressive dataset'
        })
        continue
    
    if len(matching_rows) > 1:
        mismatches.append({
            'type': 'multiple_matches',
            'index': flawed_row['index'],
            'error_type': flawed_row['error_type'],
            'issue': f'Multiple matches found: {len(matching_rows)}'
        })
        continue
    
    progressive_row = matching_rows.iloc[0]
    
    # Check text column exact match
    text_match = flawed_row['text'] == progressive_row['text']
    
    # Check label lengths match
    flawed_labels = flawed_row['parsed_labels']
    progressive_labels = progressive_row['parsed_labels']
    length_match = len(flawed_labels) == len(progressive_labels)
    
    match_info = {
        'index': flawed_row['index'],
        'error_type': flawed_row['error_type'],
        'text_match': text_match,
        'length_match': length_match,
        'flawed_labels': flawed_labels,
        'progressive_labels': progressive_labels,
        'label_length': len(flawed_labels)
    }
    
    if text_match and length_match:
        matches.append(match_info)
    else:
        match_info['issues'] = []
        if not text_match:
            match_info['issues'].append('text_mismatch')
        if not length_match:
            match_info['issues'].append(f'length_mismatch: {len(flawed_labels)} vs {len(progressive_labels)}')
        mismatches.append(match_info)

print(f"\n=== Comparison Results ===")
print(f"Perfect matches: {len(matches)}")
print(f"Mismatches: {len(mismatches)}")

if mismatches:
    print(f"\n=== Mismatch Details ===")
    for i, mismatch in enumerate(mismatches[:5]):  # Show first 5 mismatches
        print(f"\nMismatch {i+1}:")
        print(f"  Index: {mismatch['index']}, Error type: {mismatch['error_type']}")
        if 'issues' in mismatch:
            print(f"  Issues: {mismatch['issues']}")
        else:
            print(f"  Issue: {mismatch['issue']}")

# Show labeling pattern comparison for a few matches
if matches:
    print(f"\n=== Label Pattern Comparison (First 3 Matches) ===")
    for i, match in enumerate(matches[:3]):
        print(f"\nSample {i+1} (Index {match['index']}):")
        print(f"  Flawed-only labels:    {match['flawed_labels']}")
        print(f"  Progressive labels:    {match['progressive_labels']}")
        
        # Show the transformation
        flawed = match['flawed_labels']
        progressive = match['progressive_labels']
        if 1 in flawed:
            error_pos = flawed.index(1)
            print(f"  Error at position {error_pos}: {flawed} → {progressive}")

print(f"\n=== Summary ===")
print(f"Total comparisons: {len(flawed_only_df)}")
print(f"Perfect matches: {len(matches)} ({len(matches)/len(flawed_only_df)*100:.1f}%)")
print(f"Issues found: {len(mismatches)} ({len(mismatches)/len(flawed_only_df)*100:.1f}%)")

=== Dataset Comparison ===
Flawed-only dataset: 3487 samples
3N progressive dataset: 5093 samples
3N progressive (errors only): 3487 samples

=== Comparison Results ===
Perfect matches: 0
Mismatches: 3487

=== Mismatch Details ===

Mismatch 1:
  Index: 1, Error type: conceptual_error
  Issues: ['text_mismatch']

Mismatch 2:
  Index: 2, Error type: conceptual_error
  Issues: ['text_mismatch']

Mismatch 3:
  Index: 4, Error type: conceptual_error
  Issues: ['text_mismatch']

Mismatch 4:
  Index: 6, Error type: conceptual_error
  Issues: ['text_mismatch']

Mismatch 5:
  Index: 10, Error type: conceptual_error
  Issues: ['text_mismatch']

=== Summary ===
Total comparisons: 3487
Perfect matches: 0 (0.0%)
Issues found: 3487 (100.0%)
