# Artifact Analysis

This notebook analyzes the baseline model to identify dataset artifacts and spurious correlations.

## Setup

In [None]:
import sys
from pathlib import Path

sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from analysis import ErrorAnalyzer, ContrastSetEvaluator
from analysis.error_analysis import find_lexical_overlaps
from analysis.visualization import (
    plot_error_types,
    plot_confidence_distribution,
    plot_contrast_set_results
)

sns.set_style('whitegrid')
%matplotlib inline

## Load Baseline Model

In [None]:
# Fine-tuned baseline checkpoint to analyze (path is relative to this notebook).
model_path = '../models/baseline_snli'
analyzer = ErrorAnalyzer(model_path)
print("Model loaded from:", model_path)

## Error Analysis

Analyze model errors on up to 1,000 examples from the SNLI validation set.

In [None]:
# Run the baseline over (at most) 1,000 SNLI validation examples and
# summarise its accuracy and error counts from the returned `results` dict.
results = analyzer.analyze_dataset(
    dataset_name='snli', split='validation', max_samples=1000,
)

overall_acc = results['accuracy']
print(f"\nOverall Accuracy: {overall_acc:.2%}")
print("Total Errors: {}".format(results['errors']))
print(f"Error Rate: {results['error_rate']:.2%}")

## Visualize Error Patterns

In [None]:
# Plot error distribution: visualise the errors summarised in `results` (from the
# analyze_dataset call above) using the project helper `plot_error_types`.
# The exact breakdown shown is defined by that helper.
plot_error_types(results)
plt.show()

## Confidence Analysis

In [None]:
# Analyze confidence distribution over `analyzer.predictions`.
# NOTE(review): this attribute is presumably populated by analyze_dataset in the
# error-analysis cell, so that cell must run first — confirm in ErrorAnalyzer.
plot_confidence_distribution(analyzer.predictions)
plt.show()

## Hypothesis-Only Baseline

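Test whether the labels can be predicted from the hypothesis alone. Accuracy well above chance (33.3% for 3-class NLI) indicates dataset artifacts that the full model could exploit as shortcuts.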

In [None]:
# Hypothesis-only baseline: how accurately can labels be predicted from the
# hypothesis alone? Accuracy well above chance points to dataset artifacts.
NUM_CLASSES = 3                      # NLI label count used for the chance rate
CHANCE_ACCURACY = 1 / NUM_CLASSES    # accuracy of uniform random guessing
ARTIFACT_WARNING_THRESHOLD = 0.4     # heuristic margin above chance that triggers a warning
# Keep in sync with max_samples in the error-analysis cell so the two
# accuracies below are measured on comparably sized samples.
HYP_ONLY_MAX_SAMPLES = 1000

hyp_results = analyzer.analyze_hypothesis_only(max_samples=HYP_ONLY_MAX_SAMPLES)

hyp_only_acc = hyp_results['hypothesis_only_accuracy']
full_model_acc = results['accuracy']

print(f"Hypothesis-only accuracy: {hyp_only_acc:.2%}")
print(f"Full model accuracy: {full_model_acc:.2%}")
print(f"\nGap: {(full_model_acc - hyp_only_acc):.2%}")

if hyp_only_acc > ARTIFACT_WARNING_THRESHOLD:
    print("\n⚠️ Warning: High hypothesis-only accuracy suggests dataset artifacts!")
    # Derived from NUM_CLASSES so the message cannot drift from the class count.
    print(f"Random baseline would be {CHANCE_ACCURACY:.1%} for {NUM_CLASSES}-class NLI.")

## Examine High-Confidence Errors

In [None]:
# High-confidence errors: cases where the model is very sure of itself but
# still wrong — candidates for systematic biases worth inspecting by hand.
error_patterns = results['error_patterns']
high_conf_errors = error_patterns['high_confidence_errors']

print(f"High-confidence errors: {len(high_conf_errors)}")
print("\nExamples of high-confidence errors:\n")

# Only the first five are printed to keep the cell output short.
for example_no, hc_error in enumerate(high_conf_errors[:5], start=1):
    print(f"Example {example_no}:")
    print(f"  Premise: {hc_error['premise']}")
    print(f"  Hypothesis: {hc_error['hypothesis']}")
    print(f"  True label: {hc_error['true_label']}, Predicted: {hc_error['pred_label']}")
    print(f"  Confidence: {hc_error['confidence']:.2%}")
    print()

## Lexical Overlap Analysis

In [None]:
# Lexical overlap between premise and hypothesis among misclassified examples.
# Heuristic: if errors cluster at high overlap, the model may be using word
# overlap as a shortcut (compare against correct predictions to confirm).
MAX_OVERLAP_ERRORS = 100  # number of errors analyzed; set to None to use all of them

errors_subset = analyzer.errors[:MAX_OVERLAP_ERRORS]

overlap_records = [
    {
        'overlap_ratio': find_lexical_overlaps(err['premise'], err['hypothesis'])['overlap_ratio'],
        'true_label': err['true_label'],
        'pred_label': err['pred_label'],
    }
    for err in errors_subset
]
# Explicit columns keep the schema even when there are no errors, so the
# column lookups below don't raise KeyError on an empty frame.
overlap_df = pd.DataFrame(
    overlap_records, columns=['overlap_ratio', 'true_label', 'pred_label']
)

if overlap_df.empty:
    print("No errors available for lexical overlap analysis.")
else:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(overlap_df['overlap_ratio'], bins=20, edgecolor='black')
    ax.set(
        xlabel='Lexical Overlap Ratio',
        ylabel='Frequency',
        # State the sample size: by default only a subset of errors is plotted.
        title=f'Lexical Overlap Distribution in Errors (n={len(overlap_df)})',
    )
    plt.show()

    print(f"Mean overlap ratio: {overlap_df['overlap_ratio'].mean():.2%}")

## Export Error Analysis

In [None]:
# Export errors to CSV for further analysis outside the notebook.
ERRORS_CSV = Path('../results/analysis/errors.csv')
# Create the output folder first: on a fresh checkout it may not exist, and
# CSV writers generally do not create parent directories (NOTE(review): confirm
# export_errors doesn't already handle this).
ERRORS_CSV.parent.mkdir(parents=True, exist_ok=True)
analyzer.export_errors(str(ERRORS_CSV))
print(f"Errors exported to {ERRORS_CSV}")

## Summary

Analyses covered (record the observed results for each after running the notebook):
1. Overall accuracy and error rate
2. Hypothesis-only baseline performance (indicates artifacts)
3. Error patterns by label type
4. High-confidence errors (potential systematic biases)
5. Lexical overlap patterns in errors