# Mitigation Evaluation

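This notebook trains and evaluates debiasing techniques for mitigating dataset artifacts (such as hypothesis-only cues) in SNLI. Mitigated models are compared against a baseline model that must already exist at `../models/baseline_snli`; this notebook does not train the baseline.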

## Setup

In [None]:
import random
import sys
from pathlib import Path

sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from mitigation import EnsembleDebiaser, DatasetCartographer, AdversarialTrainer
from analysis import ErrorAnalyzer
from analysis.visualization import create_comparison_table

sns.set_style('whitegrid')
%matplotlib inline

## Choose Mitigation Strategy

Options:
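1. **Dataset Cartography**: Identify hard examples and focus training on them
2. **Ensemble Debiasing**: Train a hypothesis-only (biased) model, then use it to reweight the main model's training examples
3. **Adversarial Training**: Augment the training data with adversarial examples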

## Option 1: Dataset Cartography

In [None]:
# Only the cartographer object is constructed here. Running the actual
# cartography analysis requires hooking it into a full training loop;
# see mitigation/dataset_cartography.py for that implementation.
cartographer = DatasetCartographer(
    dataset_name='snli',
    model_path='google/electra-small-discriminator',
)
print("Dataset Cartography analysis ready")

## Option 2: Ensemble Debiasing

In [None]:
# Load SNLI train and drop pairs without a gold label (label == -1).
# Draw a seeded random subset for the demo instead of the first N rows.
# NOTE(review): the train split appears to be stored in order (consecutive
# rows share premises), so a head slice is not a representative sample.
N_DEMO_EXAMPLES = 1000  # increase for a more realistic run
SUBSET_SEED = 42

dataset = load_dataset('snli')['train']
dataset = dataset.filter(lambda x: x['label'] != -1)
dataset = dataset.shuffle(seed=SUBSET_SEED).select(range(min(N_DEMO_EXAMPLES, len(dataset))))

# Initialize debiaser
debiaser = EnsembleDebiaser(
    main_model_path='google/electra-small-discriminator',
    num_labels=3
)

print(f"Dataset size: {len(dataset)}")
print("Ensemble debiaser initialized")

In [None]:
# Step 1: fit the bias-only model. It is trained on the hypothesis alone, so
# anything it learns comes from hypothesis-side cues.
biased_train_config = {
    'output_dir': '../models/biased_snli',
    'num_epochs': 2,
    'batch_size': 32,
}
debiaser.train_biased_model(dataset=dataset, **biased_train_config)

print("Biased model training complete!")

In [None]:
# Step 2: compute one weight per training example; Step 3 passes these
# weights in as `example_weights` when training the debiased model.
weights = debiaser.compute_example_weights(dataset, temperature=1.0)
mean_weight = weights.mean()

print(f"Computed weights for {len(weights)} examples")
print(f"Mean weight: {mean_weight:.3f}")
print(f"Std weight: {weights.std():.3f}")

# Histogram of the weights, with the mean marked as a dashed reference line.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(weights, bins=30, edgecolor='black')
ax.set_xlabel('Example Weight')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Example Weights')
ax.axvline(mean_weight, color='red', linestyle='--', label='Mean')
ax.legend()
plt.show()

In [None]:
# Step 3: retrain the main model, weighting each example's contribution by
# the weights computed in Step 2.
debiased_train_config = {
    'output_dir': '../models/debiased_snli',
    'num_epochs': 3,
    'batch_size': 32,
}
debiaser.train_debiased_model(
    dataset=dataset,
    example_weights=weights,
    **debiased_train_config,
)

print("Debiased model training complete!")

## Option 3: Adversarial Training

In [None]:
# Initialize adversarial trainer
adv_trainer = AdversarialTrainer(
    model_path='google/electra-small-discriminator',
    num_labels=3
)

# Load SNLI train, drop pairs without a gold label (label == -1), and draw a
# seeded random subset for the demo instead of the first N rows.
# NOTE(review): the train split appears to be stored in order (consecutive
# rows share premises), so a head slice is not a representative sample.
N_DEMO_EXAMPLES = 1000  # increase for a more realistic run
SUBSET_SEED = 42

dataset = load_dataset('snli')['train']
dataset = dataset.filter(lambda x: x['label'] != -1)
dataset = dataset.shuffle(seed=SUBSET_SEED).select(range(min(N_DEMO_EXAMPLES, len(dataset))))

print(f"Dataset size: {len(dataset)}")

In [None]:
# Train on the original data augmented with adversarial examples.
# augmentation_ratio=0.3 adds adversarial examples amounting to 30% of the data.
adversarial_train_config = {
    'output_dir': '../models/adversarial_snli',
    'augmentation_ratio': 0.3,
    'num_epochs': 3,
    'batch_size': 32,
}
adv_trainer.train_with_adversarial_examples(dataset=dataset, **adversarial_train_config)

print("Adversarial training complete!")

## Evaluate Mitigated Model

In [None]:
# Evaluate the baseline on (up to) 1000 SNLI validation examples.
# This notebook does not train the baseline; it expects a checkpoint
# produced by an earlier run at BASELINE_MODEL_DIR.
BASELINE_MODEL_DIR = Path('../models/baseline_snli')
if not BASELINE_MODEL_DIR.is_dir():
    # Fail early with an actionable message instead of an opaque error
    # from inside the model-loading code.
    raise FileNotFoundError(
        f"Baseline model not found at {BASELINE_MODEL_DIR}. "
        "Train the baseline model first; this notebook does not train it."
    )

baseline_analyzer = ErrorAnalyzer(str(BASELINE_MODEL_DIR))
baseline_results = baseline_analyzer.analyze_dataset(
    dataset_name='snli',
    split='validation',
    max_samples=1000
)

print(f"Baseline accuracy: {baseline_results['accuracy']:.2%}")

In [None]:
# Evaluate the mitigated model on (up to) 1000 SNLI validation examples.
# Point MITIGATED_MODEL_DIR at the mitigation you actually trained above:
# '../models/debiased_snli' (Option 2) or '../models/adversarial_snli' (Option 3).
MITIGATED_MODEL_DIR = Path('../models/debiased_snli')
if not MITIGATED_MODEL_DIR.is_dir():
    # Fail early with an actionable message instead of an opaque error
    # from inside the model-loading code.
    raise FileNotFoundError(
        f"Mitigated model not found at {MITIGATED_MODEL_DIR}. "
        "Run Option 2 or Option 3 above, or set MITIGATED_MODEL_DIR to an existing checkpoint."
    )

mitigated_analyzer = ErrorAnalyzer(str(MITIGATED_MODEL_DIR))
mitigated_results = mitigated_analyzer.analyze_dataset(
    dataset_name='snli',
    split='validation',
    max_samples=1000
)

print(f"Mitigated accuracy: {mitigated_results['accuracy']:.2%}")

## Compare Models

In [None]:
# Side-by-side metrics for the baseline and mitigated runs, rendered as plain
# text without the DataFrame index.
comparison = create_comparison_table(baseline_results, mitigated_results)
comparison_text = comparison.to_string(index=False)
print("\nModel Comparison:", comparison_text, sep="\n")

In [None]:
# Hypothesis-only accuracy for both models: lower values suggest less
# reliance on cues present in the hypothesis alone.
baseline_hyp = baseline_analyzer.analyze_hypothesis_only(max_samples=1000)
mitigated_hyp = mitigated_analyzer.analyze_hypothesis_only(max_samples=1000)

print("\nHypothesis-only Accuracy:")
print(f"  Baseline: {baseline_hyp['hypothesis_only_accuracy']:.2%}")
print(f"  Mitigated: {mitigated_hyp['hypothesis_only_accuracy']:.2%}")

# Grouped bar chart: full-input vs hypothesis-only accuracy per model.
models = ['Baseline', 'Mitigated']
full_acc = [baseline_results['accuracy'], mitigated_results['accuracy']]
hyp_acc = [baseline_hyp['hypothesis_only_accuracy'], mitigated_hyp['hypothesis_only_accuracy']]

NUM_LABELS = 3  # matches num_labels=3 used when building the models above
chance_acc = 1 / NUM_LABELS
label_offset = 0.02  # vertical gap between a bar's top and its value label

x = np.arange(len(models))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width/2, full_acc, width, label='Full Model', color='skyblue')
ax.bar(x + width/2, hyp_acc, width, label='Hypothesis-Only', color='lightcoral')
# Reference for reading the hypothesis-only bars: uniform guessing over the
# label set (a rough baseline that assumes roughly balanced labels).
ax.axhline(chance_acc, color='gray', linestyle=':', label=f'Chance ({chance_acc:.0%})')

ax.set_ylabel('Accuracy')
ax.set_title('Model Comparison: Full vs Hypothesis-Only')
ax.set_xticks(x)
ax.set_xticklabels(models)
ax.legend()
# Keep at least [0, 1], but leave headroom above the tallest bar so its value
# label is not clipped or drawn into the title when accuracy is near 1.
ax.set_ylim(0, max(1.0, max(full_acc + hyp_acc) + 4 * label_offset))

# Add value labels
for i, (full_val, hyp_val) in enumerate(zip(full_acc, hyp_acc)):
    ax.text(i - width/2, full_val + label_offset, f'{full_val:.2%}', ha='center', fontweight='bold')
    ax.text(i + width/2, hyp_val + label_offset, f'{hyp_val:.2%}', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## Summary

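Key metrics to evaluate:
1. **Overall accuracy**: How did in-distribution accuracy change? Debiasing may cost a small amount here.
2. **Hypothesis-only accuracy**: Did it decrease (i.e., is the model exploiting fewer hypothesis-only artifacts)?
3. **Error patterns**: Are errors more evenly distributed across labels and example types?
4. **Robustness**: How does the model perform on hard or adversarial examples?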

A successful mitigation should:
- Maintain (or only slightly reduce) overall in-distribution accuracy
- Reduce hypothesis-only baseline performance, moving it toward chance level
- Improve performance on hard and out-of-distribution examples