### Basic Setup and Loading Data

In [None]:
# ==================================================================================
# SETUP AND CONFIGURATION
# ==================================================================================

# Standard libraries
import json
import os
import re
import sys
import ast
from collections import Counter

# Data science libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Custom modules
sys.path.append('..')
from errors_analysis import *
from datasets import load_dataset

# Pandas display: show full cell text (no column-width truncation) and cap
# printed frames at 50 rows. set_option accepts several option/value pairs.
pd.set_option('display.max_colwidth', None, 'display.max_rows', 50)

# Plot styling applied to every figure drawn later in this notebook.
plt.style.use(style='seaborn-v0_8-darkgrid')
sns.set_palette(palette="husl")

In [None]:
# ==================================================================================
# DATA LOADING AND PREPROCESSING
# ==================================================================================

def load_predictions(file_path):
    """
    Load prediction results from a JSONL file.

    Each non-blank line is parsed as one JSON object. Blank lines (for example
    an extra trailing newline) are skipped instead of crashing ``json.loads``.

    Parameters:
    -----------
    file_path : str or os.PathLike
        Path to the JSONL file containing predictions

    Returns:
    --------
    list : List of prediction dictionaries, in file order

    Raises:
    -------
    FileNotFoundError
        If ``file_path`` does not exist.
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; the platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode any non-ASCII characters.
    with open(file_path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def extract_failed_predictions(predictions):
    """
    Extract and format failed predictions into a DataFrame.

    A prediction counts as failed when ``predicted_answer`` is not exactly
    equal to the FIRST gold answer text (``answers['text'][0]``). Any further
    gold answers in the list are not consulted.

    Parameters:
    -----------
    predictions : list
        List of prediction dictionaries, each with at least
        ``predicted_answer`` and ``answers`` (a dict holding a ``text`` list).

    Returns:
    --------
    pd.DataFrame : One row per failed prediction, with every input key as a
        column plus an ``answer`` column holding the first gold text. When
        nothing failed, an empty frame with those same columns is returned
        (instead of a column-less one), so column access keeps working.
    """
    failed_examples = [
        pred for pred in predictions
        if pred['predicted_answer'] != pred['answers']['text'][0]
    ]

    if not failed_examples:
        # pd.DataFrame([]) has no columns, so indexing 'answers' would raise
        # KeyError; return an empty frame that keeps the input schema instead.
        schema = list(predictions[0]) if predictions else ['predicted_answer', 'answers']
        return pd.DataFrame(columns=list(dict.fromkeys([*schema, 'answer'])))

    failed_df = pd.DataFrame(failed_examples)

    # Flatten the nested gold structure into a plain string column.
    failed_df['answer'] = [v['text'][0] for v in failed_df['answers'].values]

    return failed_df


def print_dataset_summary(predictions, failed_df):
    """
    Print summary statistics for the dataset.

    Parameters:
    -----------
    predictions : list
        Full list of predictions
    failed_df : pd.DataFrame
        DataFrame of failed predictions

    Notes:
    ------
    Accuracy is reported as 0.00% for an empty ``predictions`` list instead of
    raising ZeroDivisionError.
    """
    total = len(predictions)
    failed = len(failed_df)
    succeeded = total - failed
    accuracy = succeeded / total * 100 if total else 0.0

    print("=" * 80)
    print("DATASET SUMMARY")
    print("=" * 80)
    print(f"Total predictions:  {total:,}")
    print(f"Failed predictions: {failed:,}")
    # This value is a count of correct predictions, so it is labelled as such;
    # the rate is printed on the Accuracy line.
    print(f"Successful:         {succeeded:,}")
    print(f"Accuracy:           {accuracy:.2f}%")
    print(f"\nDataFrame shape:    {failed_df.shape}")
    print(f"Columns:            {list(failed_df.columns)}")
    print("=" * 80)


# ==================================================================================
# MAIN EXECUTION
# ==================================================================================

# Load predictions from trained model (JSONL: one JSON object per line).
# `predictions` and `failed` are notebook-level names reused by later cells.
predictions_file = os.path.join("../trained_models/train_rule_cluster_val", "eval_predictions.jsonl")
predictions = load_predictions(predictions_file)

# Extract failed predictions: rows whose predicted_answer differs from the
# first gold answer text, with that gold text copied into an `answer` column.
failed = extract_failed_predictions(predictions)

# Display summary: totals, failures, accuracy and the failed frame's schema.
print_dataset_summary(predictions, failed)

### Rule-Based Analysis Using Span and Character Matches

The following analysis uses the updated `errors_analysis.py` module.

In [None]:
# ==================================================================================
# RULE-BASED ERROR DETECTION
# ==================================================================================

# Apply rule-based error detection to identify dataset quality issues.
# `detect_dataset_errors` comes from the star-import of errors_analysis.
df_with_flags, summary = detect_dataset_errors(
    df=failed,
    error_threshold=2,           # minimum rules to trigger for dataset error
    length_percentile=95         # percentile for length anomaly detection
)

# Display rule trigger summary
print("=" * 80)
print("RULE TRIGGER SUMMARY")
print("=" * 80)
print(summary)
print()

# Extract and display dataset errors
dataset_errors = df_with_flags[df_with_flags['is_dataset_error']]
# Guard against an empty `failed` frame (every prediction correct).
dataset_error_rate = len(dataset_errors) / len(failed) * 100 if len(failed) else 0.0
print(f"Found {len(dataset_errors)} potential dataset errors")
print(f"Dataset error rate: {dataset_error_rate:.1f}%")
print("=" * 80)

In [None]:
# Spot-check rows flagged by rule 13 (model picked a wrong entity of the same type).
# NOTE(review): the rule-12 column is shown next to rule-13 rows, presumably to
# eyeball the overlap between the two rules — confirm intent.
pd.set_option('display.max_colwidth', None)
rule13_rows = df_with_flags[df_with_flags['rule13_model_wrong_entity_same_type']]
# DataFrame.sample(n) raises ValueError when n exceeds the row count, so cap it.
# A fixed seed keeps the displayed rows stable across reruns.
rule13_rows[['question', 'answer', 'predicted_answer', 'rule12_model_underextraction']].sample(
    n=min(5, len(rule13_rows)), random_state=0
)

In [None]:
# List the columns of the rule-flagged frame (handy for looking up rule flag names).
df_with_flags.columns

### Deeper Analysis based on combining the rules for a better story

We combine the individual rules into composite categories to better understand the errors and to generalize the categories.

In [None]:
# ==================================================================================
# ENHANCED ERROR ANALYSIS - COMPOSITE CATEGORIES
# ==================================================================================

# Apply enhanced error analysis to combine rules into meaningful categories.
# `df_enhanced` and the three summaries are reused by later cells.
df_enhanced, composite_summary, medical_summary, model_behavior_summary = enhanced_error_analysis(df_with_flags)


def _print_banner(title):
    """Print `title` between two 80-character '=' rules."""
    print("=" * 80)
    print(title)
    print("=" * 80)


def _percent_of(count, total):
    """Return count/total as a percentage, or 0.0 when total is 0 (no failed examples)."""
    return count / total * 100 if total else 0.0


def _clip(text, width=70):
    """Truncate `text` to `width` characters, appending '...' only when something was cut."""
    text = str(text)
    return text if len(text) <= width else text[:width] + "..."


# Display all error category summaries
for section_title, section_table in (
    ("COMPOSITE ERROR CATEGORIES", composite_summary),
    ("MEDICAL DOMAIN-SPECIFIC ERRORS", medical_summary),
    ("MODEL BEHAVIORAL PATTERNS", model_behavior_summary),
):
    _print_banner(section_title)
    print(section_table)
    print()

# Summary statistics ("medical" = any of the three medical flags is set)
total_examples = len(df_enhanced)
dataset_errors = df_enhanced['is_dataset_error'].sum()
medical_errors = df_enhanced[['med_semantic_error', 'vague_question_error', 'frequency_error']].any(axis=1).sum()
model_errors = df_enhanced['has_model_behavior_error'].sum()

_print_banner("OVERALL ERROR DISTRIBUTION")
print(f"Total examples analyzed:        {total_examples:,}")
print(f"Dataset quality issues:         {dataset_errors:,} ({_percent_of(dataset_errors, total_examples):.1f}%)")
print(f"Medical domain errors:          {medical_errors:,} ({_percent_of(medical_errors, total_examples):.1f}%)")
print(f"Model behavioral errors:        {model_errors:,} ({_percent_of(model_errors, total_examples):.1f}%)")
print("=" * 80)
print()

# Show representative examples from each major category:
# (heading, boolean flag column, optional extra line built from the row)
_print_banner("REPRESENTATIVE EXAMPLES")
example_sections = [
    ("1. SEMANTIC MISMATCH (Most Common Dataset Issue)", 'semantic_mismatch',
     lambda row: f"Error Score: {row['dataset_error_score']}"),
    ("2. MEDICAL SEMANTIC ERRORS", 'med_semantic_error', None),
    ("3. MODEL UNDEREXTRACTION (Stopped Too Early)", 'model_underextraction_pattern',
     lambda row: f"Length:     {len(row['predicted_answer'])} chars vs {len(row['answer'])} chars (gold)"),
    ("4. MODEL ENTITY CONFUSION (Wrong Entity, Same Type)", 'model_entity_confusion_pattern', None),
    ("5. MODEL POSITION BIAS (Wrong Location in Context)", 'model_position_bias_pattern', None),
    ("6. MODEL BOUNDARY FAILURE (Stopped at Punctuation)", 'model_boundary_failure_pattern', None),
]
for section_idx, (heading, flag_col, extra_line) in enumerate(example_sections):
    # The first heading is preceded by one blank line, the rest by two.
    print(("\n" if section_idx == 0 else "\n\n") + heading)
    print("-" * 80)
    for _, row in df_enhanced[df_enhanced[flag_col]].head(2).iterrows():
        print(f"\nQuestion:   {_clip(row['question'])}")
        print(f"Gold:       {_clip(row['answer'])}")
        print(f"Predicted:  {_clip(row['predicted_answer'])}")
        if extra_line is not None:
            print(extra_line(row))

print("\n" + "=" * 80)
print("\n" + "=" * 80)

### Visualization of Composite Error Categories

In [None]:
# Create comprehensive visualizations for composite error categories:
# top row = counts, bottom row = error precision with a 50% reference line.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

panel_specs = [
    # (axis, summary frame, column, bar colour, title, y label, x label, draw 50% line)
    (ax1, composite_summary, 'count', 'steelblue',
     'Composite Error Categories - Count', 'Number of Examples', 'Category', False),
    (ax2, medical_summary, 'count', 'coral',
     'Medical Domain-Specific Errors - Count', 'Number of Examples', 'Error Type', False),
    (ax3, composite_summary, 'error_precision', 'green',
     'Composite Categories - Error Precision', 'Precision (%)', 'Category', True),
    (ax4, medical_summary, 'error_precision', 'purple',
     'Medical Errors - Error Precision', 'Precision (%)', 'Error Type', True),
]

for panel_ax, frame, column, colour, title, ylabel, xlabel, with_threshold in panel_specs:
    frame.plot(kind='bar', y=column, ax=panel_ax, color=colour, legend=False)
    panel_ax.set_title(title, fontsize=14, fontweight='bold')
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_xlabel(xlabel)
    panel_ax.tick_params(axis='x', rotation=45)
    if with_threshold:
        panel_ax.axhline(y=50, color='red', linestyle='--', alpha=0.7, label='50% threshold')
        panel_ax.legend()
    panel_ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()


def _count_and_pct(flag_col):
    """Format how many rows have `flag_col` set and what share of all rows that is."""
    flags = df_enhanced[flag_col]
    return f"{flags.sum()} ({flags.mean()*100:.1f}%)"


# Print key statistics
print("=== KEY STATISTICS ===")
print(f"Total examples: {len(df_enhanced)}")
print(f"Flagged as dataset errors: {_count_and_pct('is_dataset_error')}")
print(f"\nSemantic mismatch (primary issue): {_count_and_pct('semantic_mismatch')}")
print(f"Medical-specific errors: {_count_and_pct('med_semantic_error')}")
print(f"Template generation issues: {_count_and_pct('template_artifacts')}")
semantic_rows = df_enhanced[df_enhanced['semantic_mismatch']]
print(f"\nSemantic mismatch precision for dataset errors: {semantic_rows['is_dataset_error'].mean()*100:.1f}%")


### Medical Domain Pattern Analysis

In [None]:
# Analyze medical domain-specific patterns in detail: a 2x2 grid showing
# medication flag counts, the vague-question share, keyword-based question
# types, and the mean dataset_error_score per medical flag.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

# 1. Medication error breakdown
# NOTE(review): the first bar is labelled 'All Med Errors' but plots only the
# med_semantic_error column — confirm that label matches the flag's meaning.
medication_cols = ['med_semantic_error', 'frequency_error']
med_counts = df_enhanced[medication_cols].sum()
ax1.bar(range(len(med_counts)), med_counts.values, color=['#FF6B6B', '#FFA07A'])
ax1.set_title('Medication-Related Errors', fontsize=14, fontweight='bold')
ax1.set_ylabel('Number of Examples')
ax1.set_xticks(range(len(med_counts)))
ax1.set_xticklabels(['All Med Errors', 'Frequency Errors'], rotation=0)
ax1.grid(axis='y', alpha=0.3)

# 2. Vague question analysis: share of rows with vague_question_error set
# versus all remaining rows.
vague_count = df_enhanced['vague_question_error'].sum()
non_vague_count = len(df_enhanced) - vague_count
ax2.pie([non_vague_count, vague_count], 
        labels=['Specific Questions', 'Vague Questions'],
        autopct='%1.1f%%', colors=['lightgreen', 'lightcoral'])
ax2.set_title('Vague vs Specific Questions', fontsize=14, fontweight='bold')

# 3. Question type distribution for errors
# Case-insensitive regex keyword matches on the question text (missing text
# counts as no match). Buckets can overlap, so the bars need not sum to
# len(df_enhanced); matching is by substring, e.g. 'when' also hits longer words.
question_patterns = {
    'Frequency': df_enhanced['question'].str.contains('how often|frequency', case=False, na=False).sum(),
    'Dosage': df_enhanced['question'].str.contains('dose|dosage', case=False, na=False).sum(),
    'Medication': df_enhanced['question'].str.contains('medication|drug', case=False, na=False).sum(),
    'Temporal': df_enhanced['question'].str.contains('when|previous|current', case=False, na=False).sum(),
}
ax3.barh(list(question_patterns.keys()), list(question_patterns.values()), color='skyblue')
ax3.set_title('Question Type Distribution in Errors', fontsize=14, fontweight='bold')
ax3.set_xlabel('Number of Examples')
ax3.grid(axis='x', alpha=0.3)

# 4. Error severity by medical pattern
# Mean dataset_error_score over rows where each flag is set. Flags that never
# fire are skipped, so no mean is taken over an empty selection. Keys are
# prettified column names, e.g. 'med_semantic_error' -> 'Med Semantic'.
medical_pattern_cols = ['med_semantic_error', 'diagnostic_error', 'temporal_error', 'frequency_error', 'vague_question_error']
severity_by_pattern = {}
for col in medical_pattern_cols:
    if df_enhanced[col].sum() > 0:
        avg_severity = df_enhanced[df_enhanced[col]]['dataset_error_score'].mean()
        severity_by_pattern[col.replace('_error', '').replace('_', ' ').title()] = avg_severity

ax4.bar(range(len(severity_by_pattern)), list(severity_by_pattern.values()), color='orange')
ax4.set_title('Average Error Severity by Medical Pattern', fontsize=14, fontweight='bold')
ax4.set_ylabel('Average Error Score')
ax4.set_xticks(range(len(severity_by_pattern)))
ax4.set_xticklabels(list(severity_by_pattern.keys()), rotation=45, ha='right')
# NOTE(review): this reference line sits at 3, while the rule-based cells pass
# error_threshold=2 — confirm which value dataset_error_score is compared to.
ax4.axhline(y=3, color='red', linestyle='--', alpha=0.7, label='Error Threshold')
ax4.legend()
ax4.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Print detailed medical pattern statistics: count, share of all rows, and the
# fraction of flagged rows that are also marked is_dataset_error.
print("=== MEDICAL DOMAIN PATTERN STATISTICS ===")
for col in medical_pattern_cols:
    count = df_enhanced[col].sum()
    percent = (count / len(df_enhanced)) * 100
    if count > 0:
        is_error_rate = df_enhanced[df_enhanced[col]]['is_dataset_error'].mean() * 100
        print(f"{col}: {count} ({percent:.1f}%) - Dataset Error Rate: {is_error_rate:.1f}%")


### Complete Analysis Using Convenience Function

In [None]:
# Run the whole pipeline in one call (rule detection + composite categories)
# with a verbose report, mirroring the thresholds used in the cells above.
results = analyze_medical_qa_errors(failed, error_threshold=2, length_percentile=95, verbose=True)

# Keep the pipeline's enhanced dataframe around for further drill-down.
df_complete = results['enhanced_dataframe']

# Print up to three examples for each headline category, each under a rule.
for example_category in ('semantic_mismatch', 'med_semantic_error', 'vague_question_error'):
    print("\n" + "="*80)
    print_error_examples(df_complete, category=example_category, max_examples=3)


### Unified Analysis of all the data

## Model Behavioral Pattern Analysis

Analysis of specific model error behaviors (rules 12-15)

In [None]:
# Analyze model behavior patterns using the complete pipeline's output
# (`results`), so the summary table and the counts come from the same frame
# rather than mixing in `df_enhanced`, which a separate run produced.
model_behavior_summary = results['model_behavior_summary']
print("Model Behavioral Patterns Summary:")
print(model_behavior_summary)
print()

# Count of examples with any model behavior error (guarded against an empty frame)
pipeline_df = results['enhanced_dataframe']
model_errors = pipeline_df['has_model_behavior_error'].sum()
model_error_pct = model_errors / len(pipeline_df) * 100 if len(pipeline_df) else 0.0
print(f"\nTotal examples with model behavior errors: {model_errors} ({model_error_pct:.1f}%)")

### Visualize Model Behavior Patterns

In [None]:
# Visualize model behavior patterns: one pie per pattern flag showing the share
# of analyzed examples with and without that flag.
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

model_cols = ['model_underextraction_pattern', 'model_entity_confusion_pattern',
              'model_position_bias_pattern', 'model_boundary_failure_pattern']
labels = ['Underextraction\n(Model stopped early)', 
          'Entity Confusion\n(Wrong entity, same type)',
          'Position Bias\n(Wrong location)', 
          'Boundary Failure\n(Stopped at punctuation)']

for idx, (col, label) in enumerate(zip(model_cols, labels)):
    ax = axes[idx // 2, idx % 2]

    # value_counts() orders by frequency, so its order need not match the fixed
    # ['No Error', 'Has Error'] labels (they swap when errors are the majority).
    # It also returns one entry for an all-False column, which pie() rejects
    # against two labels. Pin the order to [False, True], zero-filling gaps.
    counts = df_enhanced[col].value_counts().reindex([False, True], fill_value=0)
    ax.pie(counts.values, labels=['No Error', 'Has Error'], autopct='%1.1f%%',
           colors=['#90EE90', '#FF6B6B'])
    ax.set_title(label, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('model_behavior_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Total examples with any model behavior error: {df_enhanced['has_model_behavior_error'].sum()}")

### Specific Examples of Model Errors

In [None]:
# Show up to three examples for each model behavior pattern.
print("=" * 100)
print("SPECIFIC EXAMPLES OF MODEL BEHAVIORAL ERRORS")
print("=" * 100)

# (heading, boolean flag column, fixed explanation line or None for the
# underextraction case, whose explanation is computed from the row)
_behaviour_sections = (
    ("\n1. UNDEREXTRACTION (Model stopped too early):",
     'model_underextraction_pattern', None),
    ("\n\n2. ENTITY CONFUSION (Wrong entity, same semantic type):",
     'model_entity_confusion_pattern',
     "Error:    Model confused entities of same type (e.g., different medications)"),
    ("\n\n3. POSITION BIAS (Extracted from wrong location in context):",
     'model_position_bias_pattern',
     "Error:    Model selected answer from different position in document"),
    ("\n\n4. BOUNDARY FAILURE (Stopped at punctuation unnecessarily):",
     'model_boundary_failure_pattern',
     "Error:    Model stopped at punctuation but answer continues"),
)

for heading, flag_col, fixed_note in _behaviour_sections:
    print(heading)
    print("-" * 100)
    for _, example in df_enhanced[df_enhanced[flag_col]].head(3).iterrows():
        print(f"\nQuestion: {example['question']}")
        print(f"Gold:     {example['answer'][:100]}...")
        print(f"Pred:     {example['predicted_answer'][:100]}...")
        if fixed_note is None:
            # Underextraction: report how much shorter the prediction is than gold.
            print(f"Error:    Model extracted only {len(example['predicted_answer'])} chars vs {len(example['answer'])} chars in gold")
        else:
            print(fixed_note)

### Model vs Dataset Error Comparison

In [None]:
# Compare model behavior errors vs dataset quality errors
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(14, 6))

# Left: Venn-like comparison. The three masks are mutually exclusive, so the
# "Neither" share below is whatever is left over.
categories = ['Model Behavior\nErrors Only', 'Both Model &\nDataset Issues', 'Dataset Quality\nErrors Only']
model_only = df_enhanced['has_model_behavior_error'] & ~df_enhanced['is_dataset_error']
both = df_enhanced['has_model_behavior_error'] & df_enhanced['is_dataset_error']
dataset_only = ~df_enhanced['has_model_behavior_error'] & df_enhanced['is_dataset_error']

counts = [model_only.sum(), both.sum(), dataset_only.sum()]
colors = ['#FF6B6B', '#FFD93D', '#6BCB77']

ax[0].bar(categories, counts, color=colors)
ax[0].set_ylabel('Number of Examples')
ax[0].set_title('Model Behavior vs Dataset Quality Errors', fontweight='bold')
ax[0].tick_params(axis='x', rotation=0)
# Put value labels just above each bar. A fixed +50 offset only suits bars in
# the thousands (small counts float far above their bar), so scale it instead.
label_offset = max(max(counts) * 0.01, 0.5)
for i, v in enumerate(counts):
    ax[0].text(i, v + label_offset, str(v), ha='center', va='bottom', fontweight='bold')

# Right: Breakdown percentages
total = len(df_enhanced)
neither = total - sum(counts)


def _share_of_total(n):
    """Return n as a percentage of all analyzed examples (0.0 when there are none)."""
    return n / total * 100 if total else 0.0


percentages = [_share_of_total(c) for c in counts]
percentages.append(_share_of_total(neither))  # Neither
labels_pie = categories + ['Neither']
colors_pie = colors + ['#E0E0E0']

ax[1].pie(percentages, labels=labels_pie, autopct='%1.1f%%', colors=colors_pie, startangle=90)
ax[1].set_title('Distribution of Error Types', fontweight='bold')

plt.tight_layout()
plt.savefig('model_vs_dataset_errors.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Model behavior errors only: {model_only.sum()} ({_share_of_total(model_only.sum()):.1f}%)")
print(f"Both model & dataset issues: {both.sum()} ({_share_of_total(both.sum()):.1f}%)")
print(f"Dataset quality errors only: {dataset_only.sum()} ({_share_of_total(dataset_only.sum()):.1f}%)")
print(f"Neither: {neither} ({_share_of_total(neither):.1f}%)")

### Advanced Visualizations for Error Analysis

In [None]:
# 5. SUNBURST CHART - Interactive Drill-down of Error Categories
import plotly.express as px

# Build sunburst data. Every node's `values` is the size of its WHOLE subtree
# and the chart is drawn with branchvalues='total'. Under Plotly's default
# ('remainder'), a parent's value is added on top of its children's, which
# double-counts every inner ring and halves the "percent parent" labels.
sunburst_data = []

# Add root
sunburst_data.append({
    'labels': 'All Errors',
    'parents': '',
    'values': len(df_enhanced),
    'text': f'{len(df_enhanced)} total'
})

# Level 1: Error vs Non-error (partitions the root exactly)
has_error = df_enhanced['has_model_behavior_error'] | df_enhanced['is_dataset_error']
sunburst_data.extend([
    {'labels': 'Has Errors', 'parents': 'All Errors', 
     'values': has_error.sum(), 'text': f"{has_error.sum()/len(df_enhanced)*100:.1f}%"},
    {'labels': 'No Errors', 'parents': 'All Errors', 
     'values': (~has_error).sum(), 'text': f"{(~has_error).sum()/len(df_enhanced)*100:.1f}%"}
])

# Level 2: Model vs Dataset (mutually exclusive; together they equal "Has Errors")
model_only = df_enhanced['has_model_behavior_error'] & ~df_enhanced['is_dataset_error']
dataset_only = ~df_enhanced['has_model_behavior_error'] & df_enhanced['is_dataset_error']
both = df_enhanced['has_model_behavior_error'] & df_enhanced['is_dataset_error']

sunburst_data.extend([
    {'labels': 'Model Only', 'parents': 'Has Errors', 
     'values': model_only.sum(), 'text': f"{model_only.sum()}"},
    {'labels': 'Dataset Only', 'parents': 'Has Errors', 
     'values': dataset_only.sum(), 'text': f"{dataset_only.sum()}"},
    {'labels': 'Both', 'parents': 'Has Errors', 
     'values': both.sum(), 'text': f"{both.sum()}"}
])

# Level 3: Specific model patterns under "Model Only" and "Both".
# One example can match several patterns, but under branchvalues='total' a
# node's children must not add up to more than the node itself. Each example
# is therefore credited to the FIRST pattern it matches, in the order below.
# Examples matching none of them show up as an unlabelled gap in their parent.
model_patterns = {
    'Underextract': 'model_underextraction_pattern',
    'Entity Confusion': 'model_entity_confusion_pattern',
    'Position Bias': 'model_position_bias_pattern',
    'Boundary Fail': 'model_boundary_failure_pattern'
}

for parent_label, suffix, parent_mask in (('Model Only', 'M', model_only), ('Both', 'B', both)):
    unassigned = parent_mask.astype(bool)
    for label, col in model_patterns.items():
        pattern_mask = df_enhanced[col].astype(bool)
        n_hits = int((unassigned & pattern_mask).sum())
        unassigned = unassigned & ~pattern_mask
        if n_hits > 0:
            sunburst_data.append({
                'labels': f'{label} ({suffix})', 'parents': parent_label,
                'values': n_hits, 'text': f"{n_hits}"
            })

sunburst_df = pd.DataFrame(sunburst_data)

# Create sunburst
fig = px.sunburst(
    sunburst_df,
    names='labels',
    parents='parents',
    values='values',
    branchvalues='total',
    title='Interactive Error Category Breakdown (click to drill down)',
    color='values',
    color_continuous_scale='RdYlBu_r',
    hover_data={'text': True}
)

fig.update_traces(textinfo='label+percent parent')
fig.update_layout(height=700, font_size=11)
fig.show()

Analysis of the unified results across all examples

In [None]:
def _first_gold_answer(answers_repr):
    """Parse the literal-encoded `answers` value stored in the CSV and return its first text."""
    return ast.literal_eval(answers_repr)['text'][0]


unified = pd.read_csv('../experiment_outputs/unified_analysis.csv')
# The CSV stores `answers` as a Python literal string; pull the first gold text into its own column.
unified['answer'] = unified['answers'].map(_first_gold_answer)


Cluster `-1` holds the noise points; its examples do not share any particular pattern.

In [None]:
# Cluster -1 holds the noise points; show them with rule 3 / rule 5 flags and their category.
pd.set_option('display.max_colwidth', None)
noise_mask = unified['cluster'] == -1
unified.loc[noise_mask, ['question', 'answer', 'rule3_low_q_similarity', 'rule5_qtype_mismatch', 'category']]

This lists the hard examples that are very likely noise because they appear to be mislabeled.

In [None]:
# Hard examples in cluster 2 that trigger rule 3 (low question similarity).
# Note: lifting the row cap persists for the rest of the session.
pd.set_option('display.max_rows', None)
hard_low_sim_mask = (
    unified['category'].eq('hard')
    & unified['rule3_low_q_similarity'].eq(True)
    & unified['cluster'].eq(2)
)
unified.loc[hard_low_sim_mask, ['question', 'answer', 'rule3_low_q_similarity', 'cluster']]

Cluster-level analysis of the rule-based errors

In [None]:
# Rule flags are every column whose name starts with "rule".
rule_cols = [col for col in unified.columns if col.startswith('rule')]


def _rule_trigger_table(frame, columns):
    """Return each rule's trigger count and percentage within `frame`, most frequent first."""
    flags = frame[columns]
    return pd.DataFrame({
        'count': flags.sum(),
        'percent': flags.mean() * 100
    }).sort_values('count', ascending=False)


# Overall trigger counts across the whole unified dataset
rule_summary = _rule_trigger_table(unified, rule_cols)

print(f"Total examples in unified dataset: {len(unified)}")
print("\nOverall Rule trigger summary:")
print(rule_summary)

# Same table, computed separately for each cluster
print("\n" + "="*80)
print("Rule trigger summary per cluster:")
print("="*80)

for cluster_id in sorted(unified['cluster'].unique()):
    members = unified[unified['cluster'] == cluster_id]
    print(f"\n--- Cluster {cluster_id} (n={len(members)}) ---")
    print(_rule_trigger_table(members, rule_cols))

Template-based analysis of the train dataset and the errors dataset

In [None]:
# Apply rule-based analysis to failed examples for visualization.
# Extra positional arguments for the row function must go through `args=`:
# in DataFrame.apply(func, axis=0, raw=False, result_type=None, args=(), ...)
# a bare positional 2 binds to `axis`, clashes with axis=1 and raises TypeError.
# NOTE(review): what the 2 means (presumably a threshold) depends on
# compute_dataset_error_flags' signature in errors_analysis — confirm.
failed_error_flags = failed.apply(compute_dataset_error_flags, axis=1, args=(2,))
failed_with_flags = pd.concat([failed, failed_error_flags], axis=1)

# Compute summary for failed dataset: per-rule trigger count and percentage
failed_rule_cols = [c for c in failed_with_flags.columns if c.startswith("rule")]
failed_summary = pd.DataFrame({
    "count": failed_with_flags[failed_rule_cols].sum(),
    "percent": failed_with_flags[failed_rule_cols].mean() * 100
}).sort_values("count", ascending=False)

print(f"Total failed examples: {len(failed)}")
print("\nRule trigger summary for failed predictions:")
print(failed_summary)

Additional analysis of these rules on the ambiguous examples identified by dataset cartography

In [None]:
# Load ambiguous dataset from cartography
ambiguous_path = os.path.join("../cartography_analysis_big", "ambiguous_samples.json")
# Decode as UTF-8 explicitly; open() otherwise uses the locale's default
# encoding, which mis-decodes non-ASCII text on some platforms.
with open(ambiguous_path, "r", encoding="utf-8") as f:
    ambiguous_data = json.load(f)

ambiguous_df = pd.DataFrame(ambiguous_data)

# Flatten the nested gold answers, keeping the first gold text.
ambiguous_df['answer'] = [v['text'][0] for v in ambiguous_df["answers"].values]
# Apply the same rule-based analysis to ambiguous examples.
# NOTE(review): the failed-examples cell passes an extra argument (2) to
# compute_dataset_error_flags while this call relies on its default — confirm
# both summaries use the same setting before comparing them.
ambiguous_error_flags = ambiguous_df.apply(compute_dataset_error_flags, axis=1)
ambiguous_with_flags = pd.concat([ambiguous_df, ambiguous_error_flags], axis=1)

# Compute summary for ambiguous dataset: per-rule trigger count and percentage
ambiguous_rule_cols = [c for c in ambiguous_with_flags.columns if c.startswith("rule")]
ambiguous_summary = pd.DataFrame({
    "count": ambiguous_with_flags[ambiguous_rule_cols].sum(),
    "percent": ambiguous_with_flags[ambiguous_rule_cols].mean() * 100
}).sort_values("count", ascending=False)

print(f"Total ambiguous examples: {len(ambiguous_df)}")
print("\nRule trigger summary for ambiguous dataset:")
print(ambiguous_summary)

Visualizations for the error analysis

In [None]:
# Bar chart of the rule trigger percentages in `summary` (from detect_dataset_errors on `failed`).
plt.figure(figsize=(12, 6))
percent_ax = summary["percent"].plot(kind="bar", color="blue")
percent_ax.set_title("Dataset Error Rule Trigger Percentages")
percent_ax.set_ylabel("Percentage (%)")
percent_ax.set_xlabel("Rule")
plt.xticks(rotation=45, ha='right')
percent_ax.grid(axis="y", linestyle="--", alpha=0.4)
plt.show()

In [None]:
# Count failed examples hit by each of the two major artifact rules (2 and 7).
r2 = failed_with_flags[failed_with_flags["rule2_multi_clause"]]
r7 = failed_with_flags[failed_with_flags["rule7_boundary_weirdness"]]
counts = pd.Series({
    "Multi-Clause (Rule 2)": len(r2),
    "Boundary Weirdness (Rule 7)": len(r7),
})

plt.figure(figsize=(8, 5))
artifact_ax = counts.plot(kind="bar", color=["green", "pink"])
artifact_ax.set_title("Major Dataset Artifact Types")
artifact_ax.set_ylabel("Number of Affected Examples")
plt.xticks(rotation=0)
artifact_ax.grid(axis="y", alpha=0.4, linestyle="--")
plt.show()

###  Baseline EMRQA prediction error analysis
- High-level error patterns for `eval_baseline_emrqa/eval_predictions.jsonl`
- Label wrong predictions as `truncated_span`, `overrun_span`, `partial_overlap`, or `no_overlap` and show counts/examples.

In [None]:
# Helpers
# Word tokens (letters, digits, underscore) used for the overlap test.
token_re = re.compile(r"\w+")
# Relation labels. Indices 0-3 keep their original meaning; "normalized_match"
# is appended at the end so existing positional lookups are unaffected.
relation = ["truncated_span", "overrun_span", "partial_overlap", "no_overlap", "normalized_match"]


def norm(text):
    """Lower-case `text` and collapse every whitespace run to a single space.

    Non-string values (None, or NaN coming from pandas) are treated as empty
    text, so a missing prediction/answer gets classified instead of raising
    AttributeError.
    """
    if not isinstance(text, str):
        return ""
    return " ".join(text.lower().split())


def tokens(text: str):
    """Return the lower-cased word tokens of `text`."""
    return token_re.findall(text.lower())


def span_relation(prediction: str, actual: str):
    """Classify how a wrong predicted span relates to the gold span.

    Both strings are compared after `norm` (case- and whitespace-insensitive):

    - ``normalized_match``: identical after normalization, so the exact-match
      failure is only a casing/whitespace difference.
    - ``truncated_span``: the prediction is a substring of the gold answer.
    - ``overrun_span``: the gold answer is a substring of the prediction.
    - ``partial_overlap``: they share at least one word token.
    - ``no_overlap``: anything else, including empty or missing inputs.

    Containment is character-based, so e.g. "in" counts as inside "insulin".
    """
    p = norm(prediction)
    a = norm(actual)
    # Without this check an exact (normalized) match satisfies `p in a` and
    # would be mislabelled as a truncated span.
    if p and p == a:
        return relation[4]
    if p and p in a:
        return relation[0]
    if a and a in p:
        return relation[1]
    if set(tokens(p)) & set(tokens(a)):
        return relation[2]
    return relation[3]

def _row_relation(row):
    """Label one failed row by comparing its predicted span with its gold answer."""
    return span_relation(row.get("predicted_answer", ""), row.get("answer", ""))


# Tag every failed prediction with its span relation.
failed['relation'] = failed.apply(_row_relation, axis=1)

# Display counts from the failed dataframe
print(f"Total wrong predictions: {len(failed)}")
print("Wrong predictions by type:")
print(failed["relation"].value_counts())

# Deep-dive view: up to three examples per relation type, in `relation` order.
sample_cols = ["id", "question", "predicted_answer", "answer", "relation"]
samples = [
    failed.loc[failed["relation"] == rel, sample_cols].head(3)
    for rel in relation
]
pd.concat(samples, ignore_index=True)

### Deep Dive into Baseline EMRQA Prediction Errors

Cluster error types in failed examples for further analysis. This uses sentence transformers to embed the question, context, and predicted/true answers, then clusters them using KMeans. This seems to work better. However, the clusters are still not very meaningful. Further work is needed here.

### Manual Review of Sampled Errors

- Number of questions in the train dataset that start with "has": 62,872 / 130,956
- Number of questions in the validation dataset that start with "has": 15,913 / 32,739
- Number of questions in the failed examples that start with "has": 1,423 / 3,248

Insulin

In [None]:
# Failed examples whose question mentions insulin (case-insensitive; missing text = no match).
insulin_mask = failed['question'].str.contains('insulin', case=False, na=False)
insulin_questions = failed[insulin_mask]
len(insulin_questions)

Chest pain

In [None]:
# Failed examples whose question mentions chest pain (case-insensitive; missing text = no match).
chest_pain_mask = failed['question'].str.contains('chest pain', case=False, na=False)
chest_pain_questions = failed[chest_pain_mask]
len(chest_pain_questions)

Medication

In [None]:
# Failed examples whose question mentions "medication" (case-insensitive; missing text = no match).
medication_questions = failed[failed['question'].str.contains('medication', case=False, na=False)]
# Report the size of the medication subset computed just above.
len(medication_questions)

Gemfibrozil

In [None]:
# Failed examples whose question mentions gemfibrozil. Matched like the other
# keyword cells: case-insensitive, and a missing question counts as no match
# (without na=False a NaN question leaves NaN in the mask and indexing raises).
gemfibrozil_examples = failed[failed['question'].str.contains('gemfibrozil', case=False, na=False)]
len(gemfibrozil_examples)