In [None]:
import json
import re
import sys

import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset

sys.path.append('..')

from dataset_cartography import (
    analyze_cartography_by_question_type,
    categorize_examples,
    get_examples_by_category,
    load_cartography_metrics,
)

## 1. Load Cartography Metrics

First, load the cartography metrics generated during training.
Make sure you've run training with `--enable_cartography` first!

In [None]:
# Read the cartography metrics written out by the training run
CARTOGRAPHY_DIR = "../cartography_output_full"
cartography_df = load_cartography_metrics(CARTOGRAPHY_DIR)

num_examples = len(cartography_df)
print(f"Loaded metrics for {num_examples} examples")
print("\nFirst few rows:")
cartography_df.head()

## 2. Categorize Examples

In [None]:
# Label every example with a category derived from confidence and variability
cartography_df = categorize_examples(cartography_df)

# Summarize the labels twice: raw counts, then share of the dataset in percent
category_counts = cartography_df['category'].value_counts()
category_share = cartography_df['category'].value_counts(normalize=True) * 100

print("Category Distribution:")
print(category_counts)
print("\nPercentages:")
print(category_share)

## 3. Visualize the Data Map

In [None]:
# Data map: one point per example, variability on x and confidence on y,
# with a separate scatter series (and legend entry) per category
fig, ax = plt.subplots(figsize=(12, 8))

# Categories missing from this mapping fall back to gray below
colors = dict(
    easy='green',
    hard='red',
    ambiguous='orange',
    easy_variable='lightgreen',
)

for cat_name in cartography_df['category'].unique():
    in_category = cartography_df['category'] == cat_name
    points = cartography_df.loc[in_category]
    ax.scatter(
        points['variability'],
        points['confidence'],
        c=colors.get(cat_name, 'gray'),
        label=cat_name,
        alpha=0.6,
        s=30,
    )

# Dashed reference lines at the median of each axis
median_line_style = dict(color='black', linestyle='--', alpha=0.5)
ax.axhline(cartography_df['confidence'].median(),
           label='Median confidence', **median_line_style)
ax.axvline(cartography_df['variability'].median(),
           label='Median variability', **median_line_style)

ax.set_xlabel('Variability', fontsize=14)
ax.set_ylabel('Confidence', fontsize=14)
ax.set_title('Dataset Cartography Map', fontsize=16, fontweight='bold')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Get Example Samples from Each Category

In [None]:
# Fetch the original dataset and materialize its train split as a DataFrame
dataset = load_dataset("Eladio/emrqa-msquad")
train_split = dataset['train']
train_df = pd.DataFrame(train_split)

# Same fallback as helpers.py: when the dataset ships without an 'id' column,
# derive one per row with generate_hash_ids
if "id" not in train_df.columns:
    print("Generating IDs for dataset examples...")
    from helpers import generate_hash_ids

    def _row_id(row):
        """Return the 'id' entry that generate_hash_ids produces for one row."""
        return generate_hash_ids(row)['id']

    train_df['id'] = train_df.apply(_row_id, axis=1)

In [None]:
# Ask for five example IDs from the 'hard' category and print each one
hard_ids = get_examples_by_category(cartography_df, 'hard', n=5)

banner = "=" * 80
print(banner)
print("HARDEST EXAMPLES (Low confidence, low variability)")
print(banner)

for rank, example_id in enumerate(hard_ids, start=1):
    # First dataset row carrying this ID, plus the metrics row indexed by it
    record = train_df.loc[train_df['id'] == example_id].iloc[0]
    stats = cartography_df.loc[example_id]

    print(f"\n{rank}. Example ID: {example_id}")
    print(
        f"   Confidence: {stats['confidence']:.3f}"
        f" | Variability: {stats['variability']:.3f}"
        f" | Correctness: {stats['correctness']:.3f}"
    )
    print(f"   Question: {record['question']}")
    print(f"   Answer: {record['answers']['text'][0]}")
    print(f"   Context snippet: {record['context'][:200]}...")
    print("-" * 80)

In [None]:
# Ask for five example IDs from the 'ambiguous' category and print each one
ambiguous_ids = get_examples_by_category(cartography_df, 'ambiguous', n=5)

banner = "=" * 80
print(banner)
print("MOST AMBIGUOUS EXAMPLES (Low confidence, high variability)")
print(banner)

for rank, example_id in enumerate(ambiguous_ids, start=1):
    # First dataset row carrying this ID, plus the metrics row indexed by it
    record = train_df.loc[train_df['id'] == example_id].iloc[0]
    stats = cartography_df.loc[example_id]

    print(f"\n{rank}. Example ID: {example_id}")
    print(
        f"   Confidence: {stats['confidence']:.3f}"
        f" | Variability: {stats['variability']:.3f}"
        f" | Correctness: {stats['correctness']:.3f}"
    )
    print(f"   Question: {record['question']}")
    print(f"   Answer: {record['answers']['text'][0]}")
    print(f"   Context snippet: {record['context'][:200]}...")
    print("-" * 80)
     

In [None]:
# Ask for five example IDs from the 'easy' category and print each one
easy_ids = get_examples_by_category(cartography_df, 'easy', n=5)

banner = "=" * 80
print(banner)
print("EASIEST EXAMPLES (High confidence, low variability)")
print(banner)

for rank, example_id in enumerate(easy_ids, start=1):
    # First dataset row carrying this ID, plus the metrics row indexed by it
    record = train_df.loc[train_df['id'] == example_id].iloc[0]
    stats = cartography_df.loc[example_id]

    print(f"\n{rank}. Example ID: {example_id}")
    print(
        f"   Confidence: {stats['confidence']:.3f}"
        f" | Variability: {stats['variability']:.3f}"
        f" | Correctness: {stats['correctness']:.3f}"
    )
    print(f"   Question: {record['question']}")
    print(f"   Answer: {record['answers']['text'][0]}")
    print(f"   Context snippet: {record['context'][:200]}...")
    print("-" * 80)

## 5. Analyze by Question Type

In [None]:
# Break the cartography metrics down by question type; the resulting table is
# left as the cell's last expression so the notebook renders it
qtype_analysis = analyze_cartography_by_question_type(cartography_df, train_df)

print("Cartography Metrics by Question Type:", "=" * 80, sep="\n")
qtype_analysis

In [None]:
# Visualize by question type
# Step 1: attach each example's question text to its cartography metrics.
# The inner join keeps only IDs that appear on both sides.
question_by_id = train_df[['id', 'question']].set_index('id')
merged = cartography_df.join(question_by_id, how='inner')

# Ordered (label, pattern) rules; the first pattern that matches decides the type.
# The \b anchors restrict each keyword to whole words (optionally plural), so
# "is" no longer fires on "history"/"diagnosis"/"this", "date" on "updated",
# "time" on "sometimes", or "has" on "phase".
_QUESTION_TYPE_RULES = (
    ("temporal", re.compile(r"\b(?:when|dates?|times?|years?)\b")),
    ("numerical", re.compile(r"\b(?:how\s+many|how\s+much|doses?|dosages?)\b")),
    ("what/which", re.compile(r"\b(?:what|which)\b")),
    ("yes/no", re.compile(r"\b(?:has|does|is|was|did)\b")),
)


def classify_question(q):
    """Assign a coarse question type using keyword rules.

    Rules are tried in order (temporal, numerical, what/which, yes/no) and the
    first match wins, so "How much time ..." is classified as "temporal".
    Matching is case-insensitive and on whole words only.

    NOTE(review): analyze_cartography_by_question_type may apply its own
    classification rules -- confirm the two agree before comparing outputs.

    Args:
        q: Question text. Anything that is not a string (e.g. None or NaN from
            a missing value) is classified as "other" instead of raising.

    Returns:
        One of "temporal", "numerical", "what/which", "yes/no" or "other".
    """
    if not isinstance(q, str):
        return "other"

    q_low = q.lower()
    for label, pattern in _QUESTION_TYPE_RULES:
        if pattern.search(q_low):
            return label
    return "other"

merged['question_type'] = merged['question'].map(classify_question)

# Box plots: one panel per metric, one box per question type
question_types = merged['question_type'].unique()
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, metric in zip(axes, ['confidence', 'variability', 'correctness']):
    # Metric values split out per question type, in the same order as the tick labels
    per_type_values = [
        merged.loc[merged['question_type'] == qt, metric].values
        for qt in question_types
    ]
    metric_title = metric.capitalize()

    ax.boxplot(per_type_values, tick_labels=question_types)
    ax.set_ylabel(metric_title, fontsize=12)
    ax.set_title(f'{metric_title} by Question Type', fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

## 6. Combine with Error Analysis

In [None]:
# Cross-reference evaluation errors with cartography categories.
# The predictions file is optional: without it this cell only prints a hint.
EVAL_PREDICTIONS_PATH = "../eval_baseline_emrqa/eval_predictions.jsonl"

try:
    # JSON Lines: one record per line. Read as UTF-8 regardless of platform
    # default, and skip blank lines (e.g. a trailing newline) that would make
    # json.loads raise.
    with open(EVAL_PREDICTIONS_PATH, "r", encoding="utf-8") as f:
        predictions = [json.loads(line) for line in f if line.strip()]
except FileNotFoundError:
    predictions = None
    print("No evaluation predictions found. Run evaluation first to see error analysis.")

if predictions:
    pred_df = pd.DataFrame(predictions)

    # A prediction is an error only when it matches none of the reference
    # answers (comparing against just the first one would penalize predictions
    # that equal an alternative reference).
    pred_df['is_wrong'] = pred_df.apply(
        lambda row: row['predicted_answer'] not in row['answers']['text'],
        axis=1
    )

    # Keep only predictions whose ID also has cartography metrics
    combined = pred_df.merge(
        cartography_df,
        left_on='id',
        right_index=True,
        how='inner'
    )

    if combined.empty:
        # Nothing to aggregate or plot; the bar chart below would fail on empty data
        print("No evaluation predictions share an ID with the cartography metrics; "
              "skipping error analysis.")
    else:
        # Error rate by category
        error_by_category = combined.groupby('category')['is_wrong'].agg(['mean', 'sum', 'count'])
        error_by_category.columns = ['error_rate', 'num_errors', 'total']

        print("Error Analysis by Cartography Category:")
        print("="*80)
        print(error_by_category)

        # Visualize
        fig, ax = plt.subplots(figsize=(10, 6))
        error_by_category['error_rate'].plot(kind='bar', ax=ax, color='crimson', alpha=0.7)
        ax.set_ylabel('Error Rate', fontsize=12)
        ax.set_xlabel('Category', fontsize=12)
        ax.set_title('Error Rate by Cartography Category', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
elif predictions is not None:
    print("Evaluation predictions file contains no records; skipping error analysis.")

## 7. Filter Dataset Based on Cartography

You can use cartography to create filtered training sets:

In [None]:
# Strategy 1: Remove ambiguous examples (potential label noise)
clean_ids = cartography_df[cartography_df['category'] != 'ambiguous'].index.tolist()
print(f"Clean dataset (no ambiguous): {len(clean_ids)} examples")

# Strategy 2: Focus on hard examples for curriculum learning
hard_ids = cartography_df[cartography_df['category'] == 'hard'].index.tolist()
print(f"Hard examples only: {len(hard_ids)} examples")

# Strategy 3: Balanced sampling
# Draw up to n_per_category examples at random from each category. Taking the
# first rows in index order instead would not be a sample and would inherit
# whatever ordering the metrics frame happens to have. The fixed seed keeps the
# selection identical across re-runs.
n_per_category = 1000
sampling_seed = 42
balanced_ids = []
for cat in ['easy', 'hard', 'ambiguous']:
    cat_rows = cartography_df[cartography_df['category'] == cat]
    # sample() rejects n larger than the population, so cap at what is available
    n_take = min(n_per_category, len(cat_rows))
    cat_ids = cat_rows.sample(n=n_take, random_state=sampling_seed).index.tolist()
    balanced_ids.extend(cat_ids)

print(f"Balanced dataset: {len(balanced_ids)} examples")

# Save filtered IDs for future use
with open('filtered_example_ids.json', 'w', encoding='utf-8') as f:
    json.dump({
        'clean': clean_ids,
        'hard': hard_ids,
        'balanced': balanced_ids
    }, f, indent=2)

print("\nSaved filtered IDs to filtered_example_ids.json")

## 8. Summary Statistics

In [None]:
# Closing report: dataset size, range and mean of each metric, category breakdown
rule = "=" * 80
print(rule)
print("DATASET CARTOGRAPHY SUMMARY")
print(rule)

total_examples = len(cartography_df)
print(f"\nTotal examples analyzed: {total_examples}")

print("\nMetric ranges:")
for metric in ('confidence', 'variability', 'correctness'):
    values = cartography_df[metric]
    # Pad "<Metric>:" to a fixed width so the numbers line up in one column
    heading = f"{metric.capitalize()}:"
    print(f"  {heading:<14}{values.min():.3f} - {values.max():.3f} (mean: {values.mean():.3f})")

print("\nCategory distribution:")
for cat, count in cartography_df['category'].value_counts().items():
    pct = 100 * count / total_examples
    print(f"  {cat:15s}: {count:6d} ({pct:5.1f}%)")

print("\n" + rule)