# Error Correction Analysis for Classification Noise

This notebook corrects for classification noise in the alignment and formatting feature labels. It solves a system of linear equations to estimate the "pure" loss per feature (LPF) for each category.

## Approach:
1. Load model classifications (DeepSeek) and human labels
2. Compute confusion matrices to estimate contamination rates
3. Solve system of equations to get corrected LPF values
4. Compare naive vs corrected results

## Data:
- **Models**: Gemma 2 2B and Gemma 2 9B
- **Categories**: Alignment and Formatting (Style)
- **Losses**: Hardcoded from experimental results (the Gemma 2 9B values are placeholders that still need to be updated)

In [1]:
import json
import numpy as np
import pandas as pd
from pathlib import Path
import re

# Helper function to load JSON records
def load_json_records(p: Path) -> list[dict]:
    """Load JSON data, handling both list and dict formats."""
    with p.open('r', encoding='utf-8') as f:
        data = json.load(f)
    if isinstance(data, dict):
        for k in ("items", "records", "data"):
            if k in data and isinstance(data[k], list):
                return data[k]
        return [dict(feature_id=k, **(v if isinstance(v, dict) else {"value": v})) for k, v in data.items()]
    assert isinstance(data, list), "Expected list[dict] JSON structure"
    return data

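def extract_index_from_row(row: dict) -> int | None:
    """Extract feature index from various possible fields."""
    # Prefer an explicit 'index' field when it holds a usable integer
    if 'index' in row and row['index'] is not None:
        try:
            return int(row['index'])
        except (TypeError, ValueError, OverflowError):
            pass  # e.g. NaN or a non-numeric string; fall back to feature_id
    # Otherwise parse the trailing number of an ID that ends in '-<digits>'
    fid = row.get('feature_id')
    if isinstance(fid, str):
        m = re.search(r'-(\d+)$', fid)
        if m:
            return int(m.group(1))
    return None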

def compute_confusion_matrix(model_labels: list, human_labels: list) -> dict:
    """
    Compute confusion matrix given model and human labels.
    Returns dict with TP, TN, FP, FN.
    Assumes 'related' = positive class, 'not-related' = negative class
    """
    assert len(model_labels) == len(human_labels), "Label lists must be same length"
    
    tp = sum(1 for m, h in zip(model_labels, human_labels) if m == 'related' and h == 'related')
    tn = sum(1 for m, h in zip(model_labels, human_labels) if m == 'not-related' and h == 'not-related')
    fp = sum(1 for m, h in zip(model_labels, human_labels) if m == 'related' and h == 'not-related')
    fn = sum(1 for m, h in zip(model_labels, human_labels) if m == 'not-related' and h == 'related')
    
    return {'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn}

print("Helper functions loaded successfully.")

Helper functions loaded successfully.
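
The `feature_id` fallback in `extract_index_from_row` relies on the regular expression `-(\d+)$`, which only matches an ID that ends in a hyphen followed by digits. The sketch below isolates that step; the function name `index_from_feature_id` and the example IDs are hypothetical and were made up for illustration.

```python
import re
from typing import Optional

def index_from_feature_id(feature_id: str) -> Optional[int]:
    """Return the trailing integer of an ID ending in '-<digits>', else None."""
    match = re.search(r'-(\d+)$', feature_id)
    return int(match.group(1)) if match else None

# Hypothetical IDs, invented for this example only.
with_suffix = index_from_feature_id('example-sae-53844')
without_suffix = index_from_feature_id('example-sae')
digits_in_middle = index_from_feature_id('example-12-sae')

print(with_suffix, without_suffix, digits_in_middle)
```

Digits that are not at the very end of the ID are ignored, so rows whose IDs follow a different pattern end up with no index and are dropped later by the `notna()` filter.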


## Configuration: Model and Loss Data

Configure which model to analyze and provide the experimental loss data.

In [2]:
# ============================================================
# CONFIGURATION: Select model and provide loss data
# ============================================================

# Choose model: 'gemma-2-2b' or 'gemma-2-9b'
MODEL = 'gemma-2-2b'

# Loss data (hardcoded from experimental results)
# Format: {model: {'baseline_loss': X, 'alignment_loss': Y, 'formatting_loss': Z}}
LOSS_DATA = {
    'gemma-2-2b': {
        'baseline_loss': 2.60,      # Baseline SimPO loss
        'alignment_loss': 2.88,     # Loss after ablating alignment features
        'formatting_loss': 4.21     # Loss after ablating formatting features
    },
    'gemma-2-9b': {
        'baseline_loss': 2.50,      # Example values - UPDATE THESE
        'alignment_loss': 2.70,     # Example values - UPDATE THESE
        'formatting_loss': 3.80     # Example values - UPDATE THESE
    }
}

# Get loss data for selected model
baseline_loss = LOSS_DATA[MODEL]['baseline_loss']
alignment_loss = LOSS_DATA[MODEL]['alignment_loss']
formatting_loss = LOSS_DATA[MODEL]['formatting_loss']

# Calculate loss increases
loss_increase_A = alignment_loss - baseline_loss
loss_increase_S = formatting_loss - baseline_loss

print(f"Model: {MODEL}")
print(f"Baseline loss: {baseline_loss}")
print(f"Loss increase from ablating Alignment: {loss_increase_A:.4f}")
print(f"Loss increase from ablating Formatting (Style): {loss_increase_S:.4f}")

Model: gemma-2-2b
Baseline loss: 2.6
Loss increase from ablating Alignment: 0.2800
Loss increase from ablating Formatting (Style): 1.6100


## Step 1: Load Model Classifications and Count Features

Load the DeepSeek classifications for alignment and formatting, and count the number of features predicted in each category.

In [3]:
# Define file paths based on model selection
base_path = Path('../../outputs/feature_classification')

# Map model names to their file patterns
MODEL_FILES = {
    'gemma-2-2b': {
        'alignment': base_path / 'gemma-2-2b' / '12-gemmascope-res-65k__l0-21_alignment_classified_deepseek-v3-0324.json',
        'formatting': base_path / 'gemma-2-2b' / '12-gemmascope-res-65k__l0-21_formatting_classified_deepseek-v3-0324.json',
        'human_alignment': base_path / 'human_labels' / '12-gemmascope-res-65k__l0-21_human_labels_alignment.json',
        'human_formatting': base_path / 'human_labels' / '12-gemmascope-res-65k__l0-21_human_labels_formatting.json'
    },
    'gemma-2-9b': {
        'alignment': base_path / 'gemma-2-9b' / '12-gemmascope-res-16k_canonical_alignment_classified_deepseek-deepseek-chat-v3-0324.json',
        'formatting': base_path / 'gemma-2-9b' / '12-gemmascope-res-16k_canonical_formatting_classified_deepseek-deepseek-chat-v3-0324.json',
        'human_alignment': base_path / 'human_labels' / '12-gemmascope-res-16k_canonical_human_labels_alignment.json',
        'human_formatting': base_path / 'human_labels' / '12-gemmascope-res-16k_canonical_human_labels_formatting.json'
    }
}

# Load model classifications
alignment_file = MODEL_FILES[MODEL]['alignment']
formatting_file = MODEL_FILES[MODEL]['formatting']

print(f"Loading model classifications for {MODEL}...")
print(f"Alignment file: {alignment_file.name}")
print(f"Formatting file: {formatting_file.name}")

# Load and process alignment classifications
alignment_records = load_json_records(alignment_file)
alignment_df = pd.DataFrame(alignment_records)
alignment_df['feature_index'] = alignment_df.apply(extract_index_from_row, axis=1)

# Load and process formatting classifications
formatting_records = load_json_records(formatting_file)
formatting_df = pd.DataFrame(formatting_records)
formatting_df['feature_index'] = formatting_df.apply(extract_index_from_row, axis=1)

# Count features predicted as positive in each category
N_predicted_A = (alignment_df['label'] == 'related').sum()
N_predicted_S = (formatting_df['label'] == 'related').sum()

print(f"\nFeature counts from model predictions:")
print(f"  Predicted as Alignment-related: {N_predicted_A}")
print(f"  Predicted as Formatting-related (Style): {N_predicted_S}")
print(f"  Total features: {len(alignment_df)}")

Loading model classifications for gemma-2-2b...
Alignment file: 12-gemmascope-res-65k__l0-21_alignment_classified_deepseek-v3-0324.json
Formatting file: 12-gemmascope-res-65k__l0-21_formatting_classified_deepseek-v3-0324.json

Feature counts from model predictions:
  Predicted as Alignment-related: 11143
  Predicted as Formatting-related (Style): 15391
  Total features: 65344


## Step 2: Load Human Labels and Compute Confusion Matrices

Load human validation labels and compute confusion matrices to estimate classification accuracy and contamination rates.
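
As a cross-check on the hand-written counting in `compute_confusion_matrix`, the same four counts can be read off a `pd.crosstab` of the merged labels. The sketch below uses a toy frame whose rows are invented for illustration; only the column names `label_human` and `label_model` and the label strings follow this notebook.

```python
import pandas as pd

# Toy labels, invented for illustration; not taken from the validation set.
toy = pd.DataFrame({
    'label_human': ['related', 'related', 'not-related', 'not-related', 'not-related'],
    'label_model': ['related', 'not-related', 'related', 'not-related', 'not-related'],
})

# Rows are human labels, columns are model labels. Reindexing guarantees
# that both classes appear even if one of them is absent from the sample.
labels = ['related', 'not-related']
table = (
    pd.crosstab(toy['label_human'], toy['label_model'])
    .reindex(index=labels, columns=labels, fill_value=0)
)

TP = int(table.loc['related', 'related'])
FN = int(table.loc['related', 'not-related'])
FP = int(table.loc['not-related', 'related'])
TN = int(table.loc['not-related', 'not-related'])

print(table)
print({'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN})
```

Any label string other than `related` or `not-related` is dropped by the reindex here, just as it is silently skipped by the `sum(...)` counters in the helper, so comparing the total against the merged sample size is a cheap way to spot unexpected labels.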

In [5]:
# Load human labels
human_alignment_file = MODEL_FILES[MODEL]['human_alignment']
human_formatting_file = MODEL_FILES[MODEL]['human_formatting']

print(f"Loading human validation labels...")
print(f"Human alignment file: {human_alignment_file.name}")
print(f"Human formatting file: {human_formatting_file.name}")

# Load human alignment labels
human_alignment_records = load_json_records(human_alignment_file)
human_alignment_df = pd.DataFrame(human_alignment_records)

# Load human formatting labels
human_formatting_records = load_json_records(human_formatting_file)
human_formatting_df = pd.DataFrame(human_formatting_records)

# Clean human data - use the 'feature_index' field directly if it exists
if 'feature_index' in human_alignment_df.columns:
    human_alignment_df_clean = human_alignment_df[['feature_index', 'label']].copy()
    human_alignment_df_clean = human_alignment_df_clean.dropna(subset=['feature_index'])
    human_alignment_df_clean['feature_index'] = human_alignment_df_clean['feature_index'].astype(int)
else:
    raise ValueError("Human alignment data missing 'feature_index' column")

if 'feature_index' in human_formatting_df.columns:
    human_formatting_df_clean = human_formatting_df[['feature_index', 'label']].copy()
    human_formatting_df_clean = human_formatting_df_clean.dropna(subset=['feature_index'])
    human_formatting_df_clean['feature_index'] = human_formatting_df_clean['feature_index'].astype(int)
else:
    raise ValueError("Human formatting data missing 'feature_index' column")

print(f"\nHuman labeled samples:")
print(f"  Alignment: {len(human_alignment_df_clean)} samples")
print(f"  Formatting: {len(human_formatting_df_clean)} samples")

# Debug: Show sample of human indices
if len(human_alignment_df_clean) > 0:
    print(f"  Sample alignment indices: {sorted(human_alignment_df_clean['feature_index'].head(5).tolist())}")
if len(human_formatting_df_clean) > 0:
    print(f"  Sample formatting indices: {sorted(human_formatting_df_clean['feature_index'].head(5).tolist())}")

# Get the feature indices that were human-labeled
human_labeled_indices_A = set(human_alignment_df_clean['feature_index'].tolist())
human_labeled_indices_S = set(human_formatting_df_clean['feature_index'].tolist())

# Re-extract feature indices from the model data (cast to int after filtering below)
alignment_df['feature_index'] = alignment_df.apply(extract_index_from_row, axis=1)
formatting_df['feature_index'] = formatting_df.apply(extract_index_from_row, axis=1)

# Debug: Show sample of model indices
alignment_sample_indices = alignment_df['feature_index'].dropna().head(10).astype(int).tolist()
formatting_sample_indices = formatting_df['feature_index'].dropna().head(10).astype(int).tolist()
print(f"\nModel data sample indices:")
print(f"  Alignment: {alignment_sample_indices}")
print(f"  Formatting: {formatting_sample_indices}")

# Filter model predictions to only include human-labeled features
# For alignment
alignment_df_filtered = alignment_df[alignment_df['feature_index'].notna()].copy()
alignment_df_filtered['feature_index'] = alignment_df_filtered['feature_index'].astype(int)
alignment_df_filtered = alignment_df_filtered[alignment_df_filtered['feature_index'].isin(human_labeled_indices_A)]
alignment_df_filtered = alignment_df_filtered[['feature_index', 'label']]

# For formatting
formatting_df_filtered = formatting_df[formatting_df['feature_index'].notna()].copy()
formatting_df_filtered['feature_index'] = formatting_df_filtered['feature_index'].astype(int)
formatting_df_filtered = formatting_df_filtered[formatting_df_filtered['feature_index'].isin(human_labeled_indices_S)]
formatting_df_filtered = formatting_df_filtered[['feature_index', 'label']]

print(f"\nModel predictions for human-labeled features:")
print(f"  Alignment: {len(alignment_df_filtered)} samples")
print(f"  Formatting: {len(formatting_df_filtered)} samples")

# Join model predictions with human labels for alignment
merged_alignment = alignment_df_filtered.merge(
    human_alignment_df_clean,
    on='feature_index',
    how='inner',
    suffixes=('_model', '_human')
)

# Join model predictions with human labels for formatting
merged_formatting = formatting_df_filtered.merge(
    human_formatting_df_clean,
    on='feature_index',
    how='inner',
    suffixes=('_model', '_human')
)

print(f"\nValidation sample sizes (after merge):")
print(f"  Alignment: {len(merged_alignment)} samples")
print(f"  Formatting: {len(merged_formatting)} samples")

# Compute confusion matrices
cm_A = compute_confusion_matrix(
    merged_alignment['label_model'].tolist(),
    merged_alignment['label_human'].tolist()
)

cm_S = compute_confusion_matrix(
    merged_formatting['label_model'].tolist(),
    merged_formatting['label_human'].tolist()
)

# Extract values
TP_A, TN_A, FP_A, FN_A = cm_A['TP'], cm_A['TN'], cm_A['FP'], cm_A['FN']
TP_S, TN_S, FP_S, FN_S = cm_S['TP'], cm_S['TN'], cm_S['FP'], cm_S['FN']

print(f"\nConfusion Matrix for Alignment Classifier:")
print(f"  TP (Correct Positive): {TP_A}")
print(f"  TN (Correct Negative): {TN_A}")
print(f"  FP (False Positive): {FP_A}")
print(f"  FN (False Negative): {FN_A}")
total_A = TP_A + TN_A + FP_A + FN_A
if total_A > 0:
    print(f"  Accuracy: {(TP_A + TN_A) / total_A:.2%}")
else:
    print(f"  Accuracy: N/A (no samples)")

print(f"\nConfusion Matrix for Formatting Classifier:")
print(f"  TP (Correct Positive): {TP_S}")
print(f"  TN (Correct Negative): {TN_S}")
print(f"  FP (False Positive): {FP_S}")
print(f"  FN (False Negative): {FN_S}")
total_S = TP_S + TN_S + FP_S + FN_S
if total_S > 0:
    print(f"  Accuracy: {(TP_S + TN_S) / total_S:.2%}")
else:
    print(f"  Accuracy: N/A (no samples)")

Loading human validation labels...
Human alignment file: 12-gemmascope-res-65k__l0-21_human_labels_alignment.json
Human formatting file: 12-gemmascope-res-65k__l0-21_human_labels_formatting.json

Human labeled samples:
  Alignment: 300 samples
  Formatting: 300 samples
  Sample alignment indices: [5475, 6903, 14974, 47733, 59708]
  Sample formatting indices: [5475, 6903, 14974, 47733, 59708]

Model data sample indices:
  Alignment: [53844, 53847, 53890, 53924, 53896, 53848, 53925, 53859, 53886, 53885]
  Formatting: [53844, 53847, 53890, 53924, 53896, 53848, 53925, 53859, 53886, 53885]

Model predictions for human-labeled features:
  Alignment: 300 samples
  Formatting: 300 samples

Validation sample sizes (after merge):
  Alignment: 300 samples
  Formatting: 300 samples

Confusion Matrix for Alignment Classifier:
  TP (Correct Positive): 27
  TN (Correct Negative): 227
  FP (False Positive): 22
  FN (False Negative): 24
  Accuracy: 84.67%

Confusion Matrix for Formatting Classifier:
  

## Step 3: Calculate Naive Loss Per Feature (LPF)

Calculate the naive (uncorrected) loss per feature by simply dividing the total loss increase by the number of predicted features in each category.
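
Written as a formula, with symbols that are shorthand introduced here rather than names from the notebook ($L_c$ is the loss after ablating category $c$, $L_{\text{base}}$ is the baseline loss, and $N_c$ is the number of features predicted as category $c$):

$$
\mathrm{LPF}^{\text{naive}}_c = \frac{L_c - L_{\text{base}}}{N_c}
$$

For the alignment bucket of `gemma-2-2b`, using the loss increase and feature count printed in the earlier cells, this gives $0.28 / 11143 \approx 2.51 \times 10^{-5}$.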

In [10]:
# Calculate naive LPF (no correction for contamination)
lpf_A_naive = loss_increase_A / N_predicted_A
lpf_S_naive = loss_increase_S / N_predicted_S

print("Naive LPF for Alignment: {:.2e}".format(lpf_A_naive))
print("Naive LPF for Formatting: {:.2e}".format(lpf_S_naive))
print("Naive Ratio: {:.2f}x".format(lpf_S_naive / lpf_A_naive))

Naive LPF for Alignment: 2.51e-05
Naive LPF for Formatting: 1.05e-04
Naive Ratio: 4.16x


## Step 4: Estimate Contamination in Each Bucket

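Use the fraction of positive predictions that humans marked as wrong, FP / (TP + FP), to estimate how many features in each predicted category are contaminants. This quantity is the false discovery rate (1 − precision); the code's variables and printed output label it `fp_rate` / `FPR`. The correction assumes that every such false positive actually belongs to the other category.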
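
The ratio FP / (TP + FP) and the conventional false positive rate FP / (FP + TN) can differ substantially, so it helps to compute both side by side. The sketch below does that for the alignment confusion-matrix counts printed in Step 2; the helper name `classifier_rates` is an assumption introduced for this example.

```python
def classifier_rates(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Standard rates derived from binary confusion-matrix counts."""
    return {
        'precision': tp / (tp + fp),             # correct share of positive predictions
        'false_discovery_rate': fp / (tp + fp),  # wrong share of positive predictions
        'false_positive_rate': fp / (fp + tn),   # share of true negatives flagged positive
        'recall': tp / (tp + fn),                # share of true positives recovered
    }

# Alignment counts as reported in the Step 2 output of this notebook.
rates = classifier_rates(tp=27, tn=227, fp=22, fn=24)
for name, value in rates.items():
    print(f"{name}: {value:.2%}")
```

The contamination estimate in this step corresponds to `false_discovery_rate`, because the question is what share of the predicted bucket is mislabeled, not what share of all negatives was flagged.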

In [11]:
# Calculate contamination rates from the confusion matrices
# FP / (TP + FP) is the fraction of positive predictions that are wrong,
# i.e. the false discovery rate (1 - precision); it is printed below as "FPR"

# For Alignment bucket
fp_rate_A = FP_A / (TP_A + FP_A) if (TP_A + FP_A) > 0 else 0
num_contaminants_in_A = N_predicted_A * fp_rate_A
num_pure_A = N_predicted_A - num_contaminants_in_A

# For Formatting bucket
fp_rate_S = FP_S / (TP_S + FP_S) if (TP_S + FP_S) > 0 else 0
num_contaminants_in_S = N_predicted_S * fp_rate_S
num_pure_S = N_predicted_S - num_contaminants_in_S

print(f"\nAlignment FPR: {fp_rate_A:.2%} | Pure: {num_pure_A:.0f}, Contaminants: {num_contaminants_in_A:.0f}")
print(f"Formatting FPR: {fp_rate_S:.2%} | Pure: {num_pure_S:.0f}, Contaminants: {num_contaminants_in_S:.0f}")


Alignment FPR: 44.90% | Pure: 6140, Contaminants: 5003
Formatting FPR: 6.76% | Pure: 14351, Contaminants: 1040


## Step 5: Solve System of Equations for Corrected LPF

Set up and solve a system of two linear equations to estimate the "pure" LPF for each category:

**Equation 1** (Alignment bucket):  
`loss_increase_A = (num_pure_A × lpf_pure_A) + (num_contaminants_in_A × lpf_pure_S)`

**Equation 2** (Formatting bucket):  
`loss_increase_S = (num_pure_S × lpf_pure_S) + (num_contaminants_in_S × lpf_pure_A)`

Solve for `lpf_pure_A` and `lpf_pure_S`.
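
In matrix form, using shorthand that is introduced here and is not part of the notebook ($p_A, c_A$ for `num_pure_A`, `num_contaminants_in_A`; $p_S, c_S$ for the formatting counterparts; $\ell_A, \ell_S$ for the pure LPFs; $\Delta_A, \Delta_S$ for the loss increases), the two equations read:

$$
\begin{pmatrix} p_A & c_A \\ c_S & p_S \end{pmatrix}
\begin{pmatrix} \ell_A \\ \ell_S \end{pmatrix}
=
\begin{pmatrix} \Delta_A \\ \Delta_S \end{pmatrix}
$$

For a 2×2 system, Cramer's rule gives the solution in closed form, provided the determinant $D = p_A p_S - c_A c_S$ is nonzero:

$$
\ell_A = \frac{\Delta_A \, p_S - c_A \, \Delta_S}{D},
\qquad
\ell_S = \frac{p_A \, \Delta_S - c_S \, \Delta_A}{D}
$$

When $D > 0$, the sign of $\ell_A$ is the sign of $\Delta_A \, p_S - c_A \, \Delta_S$, so the corrected alignment LPF turns negative whenever the contaminant count $c_A$ and the formatting loss increase $\Delta_S$ are large relative to $\Delta_A$.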

In [12]:
# Set up the system of equations: Ax = b
# where x = [lpf_pure_A, lpf_pure_S]

# Coefficient matrix A
A = np.array([
    [num_pure_A, num_contaminants_in_A],      # Equation 1
    [num_contaminants_in_S, num_pure_S]        # Equation 2
])

# Right-hand side vector b
b = np.array([loss_increase_A, loss_increase_S])

# Solve the system
try:
    lpf_corrected = np.linalg.solve(A, b)
    lpf_pure_A = lpf_corrected[0]
    lpf_pure_S = lpf_corrected[1]
    
    print(f"\nCorrected LPF for Alignment: {lpf_pure_A:.2e}")
    print(f"Corrected LPF for Formatting: {lpf_pure_S:.2e}")
    if lpf_pure_A > 0:
        print(f"Corrected Ratio: {lpf_pure_S / lpf_pure_A:.2f}x")
    else:
        print(f"Corrected Ratio: N/A (alignment LPF is negative)")
    
except np.linalg.LinAlgError as e:
    print(f"\nERROR: Could not solve the system. {e}")
    lpf_pure_A = None
    lpf_pure_S = None


Corrected LPF for Alignment: -4.87e-05
Corrected LPF for Formatting: 1.16e-04
Corrected Ratio: N/A (alignment LPF is negative)
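
A quick sanity check on this result is to substitute the solution back into the equations and look at how much loss Equation 1 attributes to contaminants alone. The sketch below re-solves the system from the rounded counts and loss increases printed in Steps 4 and 2, so its values may differ from the notebook's in the last digits; the variable names are local to this example.

```python
import numpy as np

# Rounded values copied from the outputs printed earlier in this notebook.
A = np.array([[6140.0, 5003.0],     # [num_pure_A, num_contaminants_in_A]
              [1040.0, 14351.0]])   # [num_contaminants_in_S, num_pure_S]
b = np.array([0.28, 1.61])          # [loss_increase_A, loss_increase_S]

lpf_A, lpf_S = np.linalg.solve(A, b)

# Substituting the solution back should reproduce both loss increases.
assert np.allclose(A @ np.array([lpf_A, lpf_S]), b)

# Loss that Equation 1 attributes to contaminants in the alignment bucket.
contaminant_loss_A = A[0, 1] * lpf_S

print(f"lpf_A = {lpf_A:.2e}, lpf_S = {lpf_S:.2e}")
print(f"contaminant term in Eq. 1: {contaminant_loss_A:.3f} (observed increase: {b[0]:.3f})")
print(f"condition number of A: {np.linalg.cond(A):.2f}")
```

With these inputs the contaminant term alone is larger than the observed alignment loss increase, which is what forces the solver to assign a negative LPF to the pure alignment features. One possible reading, offered here as an assumption rather than a finding of the notebook, is that the false positives in the alignment bucket do not all carry the full pure formatting LPF.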


## Step 6: Summary Comparison

Compare the naive (uncorrected) results with the corrected results side-by-side.

In [14]:
if lpf_pure_A is not None and lpf_pure_S is not None:
    # Create comparison table
    comparison = pd.DataFrame({
        'Metric': [
            'LPF Alignment',
            'LPF Formatting (Style)',
            'Ratio (Style/Alignment)'
        ],
        'Naive (Uncorrected)': [
            f'{lpf_A_naive:.2e}',
            f'{lpf_S_naive:.2e}',
            f'{lpf_S_naive / lpf_A_naive:.2f}x'
        ],
        'Corrected (Pure)': [
            f'{lpf_pure_A:.2e}',
            f'{lpf_pure_S:.2e}',
            f'{lpf_pure_S / lpf_pure_A:.2f}x' if lpf_pure_A > 0 else 'N/A (negative denom)'
        ]
    })
    
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(comparison.to_string(index=False))
    
    # Interpretation based on actual results
    print("\n" + "=" * 60)
    print("INTERPRETATION")
    print("=" * 60)
    
    if lpf_pure_A < 0 and lpf_pure_S > 0:
        print("After correcting for classification noise:")
        print(f"• Formatting features: {lpf_pure_S:.2e} LPF (positive impact on loss)")
        print(f"• Alignment features: {lpf_pure_A:.2e} LPF (negative - suggests little/no impact)")
        print("\nThe correction reveals that contamination by high-impact formatting features")
        print("was inflating the apparent impact of alignment features. Pure alignment features")
        print("have minimal or negative impact, while formatting features remain highly impactful.")
    elif lpf_pure_A > 0 and lpf_pure_S > 0:
        ratio_change = (lpf_pure_S / lpf_pure_A) / (lpf_S_naive / lpf_A_naive)
        if ratio_change > 1.1:
            print("After correction, the formatting/alignment ratio increased,")
            print("suggesting the gap is larger than initially estimated.")
        elif ratio_change < 0.9:
            print("After correction, the formatting/alignment ratio decreased,")
            print("suggesting the gap was overestimated due to classification noise.")
        else:
            print("After correction, the formatting/alignment ratio remains similar.")
    else:
        print("Unexpected result pattern - manual interpretation required.")
else:
    print("Could not compute comparison due to solver error.")


SUMMARY
                 Metric Naive (Uncorrected)     Corrected (Pure)
          LPF Alignment            2.51e-05            -4.87e-05
 LPF Formatting (Style)            1.05e-04             1.16e-04
Ratio (Style/Alignment)               4.16x N/A (negative denom)

INTERPRETATION
After correcting for classification noise:
• Formatting features: 1.16e-04 LPF (positive impact on loss)
• Alignment features: -4.87e-05 LPF (negative - suggests little/no impact)

The correction reveals that contamination by high-impact formatting features
was inflating the apparent impact of alignment features. Pure alignment features
have minimal or negative impact, while formatting features remain highly impactful.
