# Phase 6: Case Studies & Qualitative Analysis

**Question:** Can we trace predictions back to interpretable concepts in the original text?

This notebook:
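1. Selects 8 exemplar cases by confidence: 2 each of true positives (TP), true negatives (TN), false positives (FP), and false negatives (FN), where the positive class (label 1) is a hallucination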
2. Visualizes bipartite alignment graphs with attribution overlays
3. Maps important nodes/edges back to AMR concepts
4. Generates explanation cards for the first TP case and the first FP case

In [None]:
import sys
sys.path.insert(0, '../src')

import numpy as np
import matplotlib.pyplot as plt
import torch

from calamr_interp.utils.data_loading import load_and_split
from calamr_interp.utils.model_loading import create_model
from calamr_interp.utils.visualization import setup_style
from calamr_interp.phase4_attribution import GradientSaliency
from calamr_interp.phase6_case_studies import (
    CaseStudySelector,
    AlignmentGraphVisualizer,
    ExplanationGenerator,
)

# Apply the plotting style provided by calamr_interp.utils.visualization
# before any figure is drawn.
setup_style()

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 1. Load Model & Select Cases

In [None]:
# Load the train/validation/test splits; only the test split is used in this cell.
train_data, val_data, test_data = load_and_split()
test_list = list(test_data)  # materialise as a list so len() works
print(f'Test: {len(test_list)} graphs')

# Path to the trained EdgeAwareGAT weights. Set this to a checkpoint file to
# analyse a trained model; leave it as None to keep whatever create_model()
# returns on its own.
CHECKPOINT_PATH = None

model = create_model('EdgeAwareGAT')
if CHECKPOINT_PATH is not None:
    # map_location lets a checkpoint saved on GPU be opened on a CPU-only machine.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # NOTE(review): assumes the file holds either a bare state_dict or a dict
    # wrapping one under 'model_state_dict' -- confirm against the training script.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    model.load_state_dict(checkpoint)
else:
    # Case studies and attributions are only meaningful for trained weights,
    # so make the missing checkpoint visible in the notebook output.
    print(
        'WARNING: CHECKPOINT_PATH is not set; unless create_model() restores '
        'trained weights itself, the cases below reflect an untrained model.'
    )
model = model.to(device)
model.eval()  # inference mode for case selection

# Select exemplar cases: the selector returns a mapping from category name to
# a list of case dicts (keys used below: 'index', 'pred_prob', 'label').
selector = CaseStudySelector(model, device)
cases = selector.select(test_list, n_per_category=2)

for cat, items in cases.items():
    print(f'\n{cat}: {len(items)} cases')
    for item in items:
        print(f'  idx={item["index"]}, prob={item["pred_prob"]:.3f}, label={item["label"]}')

## 2. Compute Attributions

In [None]:
# Run gradient saliency once per selected case, keyed by the case's index.
saliency = GradientSaliency(model, device)

# Flatten every category's case list into a single index -> attribution map.
all_attributions = {
    case['index']: saliency.attribute(case['data'])
    for category_cases in cases.values()
    for case in category_cases
}

print(f'Computed attributions for {len(all_attributions)} cases')

## 3. Visualization: Bipartite Alignment Graphs

In [None]:
# Draw one bipartite alignment graph per selected case, category by category.
visualizer = AlignmentGraphVisualizer()

for cat in ['TP', 'TN', 'FP', 'FN']:
    # A missing category and an empty one are reported the same way.
    category_cases = cases.get(cat)
    if not category_cases:
        print(f'No {cat} cases found')
        continue

    for item in category_cases:
        idx = item['index']
        # Cases without stored attributions fall back to an empty dict, so
        # both importance overlays are passed as None.
        case_attrs = all_attributions.get(idx, {})
        edge_scores = case_attrs.get('edge_saliency')
        node_scores = case_attrs.get('node_saliency')

        title = f"{cat}: idx={idx}, prob={item['pred_prob']:.3f}, true={'Hallu' if item['label']==1 else 'Truth'}"
        fig = visualizer.plot_bipartite(
            item['data'],
            edge_importance=edge_scores,
            node_importance=node_scores,
            title=title,
        )
        plt.show()

## 4. Explanation Cards

In [None]:
# Build and print a text explanation card for the first TP and first FP case.
explainer = ExplanationGenerator(model, device)

for cat in ['TP', 'FP']:
    category_cases = cases.get(cat)
    if not category_cases:
        continue

    # First case in the category -- presumably the most confident one; this
    # depends on how CaseStudySelector orders its results (verify).
    top_case = category_cases[0]
    idx = top_case['index']
    case_attrs = all_attributions.get(idx, {})

    card_inputs = {
        'data': top_case['data'],
        'pred_prob': top_case['pred_prob'],
        'label': top_case['label'],
        'edge_importance': case_attrs.get('edge_saliency'),
        'node_importance': case_attrs.get('node_saliency'),
    }
    card = explainer.generate_card(**card_inputs)

    # Header banner
    print(f"\n{'='*60}")
    print(f"EXPLANATION CARD: {cat} (index={idx})")
    print(f"{'='*60}")

    # Prediction summary
    pred = card['prediction']
    print(f"Predicted: {pred['predicted_label']} (prob={pred['probability']:.3f})")
    print(f"Actual:    {pred['true_label']}")
    print(f"Correct:   {pred['correct']}")

    # Graph-level statistics
    stats = card['graph_stats']
    print(f"\nGraph: {stats['n_nodes']} nodes, {stats['n_edges']} edges")
    print(f"  Source: {stats['n_source_nodes']}, Summary: {stats['n_summary_nodes']}")
    print(f"  Alignment edges: {stats['n_alignment_edges']}")
    print(f"  Mean alignment flow: {stats['mean_alignment_flow']:.3f}")
    print(f"  Max alignment flow:  {stats['max_alignment_flow']:.3f}")

    # Edge ranking is optional in the card; list at most the first five entries.
    if 'top_edges' in card:
        print(f"\nTop 5 important edges:")
        for e in card['top_edges'][:5]:
            if e['is_alignment']:
                edge_type = 'ALIGN'
            else:
                edge_type = 'INTRL'
            print(f"  [{edge_type}] {e['source']}->{e['target']} imp={e['importance']:.4f} flow={e['flow']:.3f}")

    plt.show()

In [None]:
# Closing reading guide, printed one entry per line.
interpretation_notes = [
    "\n=== Key Interpretation ===",
    "Look for patterns like:",
    "- TP hallucinations: summary nodes with weak/no alignment, low flow values",
    "- FP errors: graphs where strong alignment exists but model is confused",
    "- FN misses: hallucination graphs that look structurally similar to truth",
    "\nThe most interpretable findings will come from combining:",
    "1. Attribution overlays (which edges/nodes matter)",
    "2. Text mapping (what AMR concepts these correspond to)",
    "3. Flow values (how strongly aligned are the concepts)",
]
for note in interpretation_notes:
    print(note)