# 03 - Model Evaluation: Checkpoint Analysis

**Objective**: Load and analyze the trained Legal-Longformer checkpoint.

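This notebook evaluates the best checkpoint saved during training, examining:
- Overall metrics (F1, Precision, Recall) on the held-out test set
- Per-class performance
- Prediction-distribution analysis (probability spread, labels predicted per sample)
- Global and per-class decision-threshold tuning (fit on validation, applied to test)
- Automation feasibility relative to human inter-annotator agreement (κ)
- Comparison with the TF-IDF baseline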

In [None]:
# Make the project package `src` importable; assumes this notebook lives one level
# below the project root (consistent with the '../data' and '../outputs' paths used later).
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Project-local modules (note: `DataLoader` here is the project's own loader, not torch's).
from src.data_loader import DataLoader
from src.model_trainer import DataPreparer
from src.model_evaluator import MultiLabelEvaluator
# NOTE(review): DeviceManager, LegalLongformerTrainer and `torch` are not referenced in the
# visible cells — confirm they are unused before removing them.
from src.bert_trainer import (
    DeviceManager, 
    HybridLegalClassifier,
    LegalLongformerTrainer,
)

# Notebook-wide plot style and table display width
sns.set_style('whitegrid')
pd.set_option('display.max_colwidth', 50)

## 1. Load Data and Checkpoint

In [None]:
# Rebuild the same split used for training: identical source file, label filter and seed
loader = DataLoader('../data/TRDataChallenge2023.txt')
preparer = DataPreparer(loader, min_label_count=50, random_state=42)
data = preparer.prepare(ngram_range=(1, 2), max_features=10000)

print(data.summary())
print("\nLabels:", len(data.label_names))

In [None]:
# Report which checkpoint artifacts exist and, when present, load the training-history JSON
import json

model_path = Path('../outputs/legal_longformer_best.pt')
history_path = Path('../outputs/legal_longformer_best.history.json')

history = None
print("Model files:")

if not model_path.exists():
    print(f"  {model_path.name}: NOT FOUND ❌")
else:
    size_mb = model_path.stat().st_size / 1e6
    print(f"  {model_path.name}: {size_mb:.1f} MB ✓")

if not history_path.exists():
    print(f"  {history_path.name}: NOT FOUND")
else:
    print(f"  {history_path.name} ✓")
    history = json.loads(history_path.read_text())
    print("\nTraining History:")
    print(f"  Best F1:    {history['best_f1']:.4f}")
    print(f"  Best epoch: {history['best_epoch']}")

In [None]:
# Per-epoch table from the saved history, with the best epoch flagged
if not history:
    print("No training history available")
else:
    print("Training progression:")
    print(f"{'Epoch':<8} {'Train Loss':<12} {'Val Loss':<12} {'Val F1 Micro':<14} {'Val F1 Macro':<14}")
    print("-" * 60)
    n_epochs = len(history['train_losses'])
    for epoch in range(1, n_epochs + 1):
        idx = epoch - 1  # history lists are 0-based, epochs are 1-based
        suffix = " ← best" if epoch == history['best_epoch'] else ""
        print(
            f"{epoch:<8} {history['train_losses'][idx]:<12.4f} {history['val_losses'][idx]:<12.4f} "
            f"{history['val_f1_micro'][idx]:<14.4f} {history['val_f1_macro'][idx]:<14.4f}{suffix}"
        )

## 2. Load Model from Checkpoint

In [None]:
# Build the hybrid classifier (summary cache under ../outputs/summaries) and restore the
# best weights saved after training.
hybrid_classifier = HybridLegalClassifier(
    num_labels=len(data.label_names),
    cache_dir='../outputs/summaries',
    device='auto',
)
hybrid_classifier.load('../outputs/legal_longformer_best.pt')

if not history:
    print("✓ Loaded model from legal_longformer_best.pt")
else:
    best_epoch = history['best_epoch']
    epoch_idx = best_epoch - 1  # per-epoch lists are 0-indexed
    print(f"✓ Loaded best model (epoch {best_epoch})")
    print(f"  Val F1 Micro: {history['val_f1_micro'][epoch_idx]:.4f}")
    print(f"  Val F1 Macro: {history['val_f1_macro'][epoch_idx]:.4f}")

## 3. Preprocess Validation and Test Data

In [None]:
# Hybrid-preprocess both evaluation splits (uses cached summaries):
# VAL drives threshold tuning, TEST is kept for the final numbers.
print("Preprocessing validation texts (for threshold optimization)...")
val_texts_processed = hybrid_classifier.preprocess_texts(data.val_texts)

print("\nPreprocessing test texts (for final evaluation)...")
test_texts_processed = hybrid_classifier.preprocess_texts(data.test_texts)

# Routing counters reported by the classifier after preprocessing
stats = hybrid_classifier.get_processing_stats()
print()
for caption, key in (("Total processed", 'total_processed'),
                     ("Direct", 'direct_classified'),
                     ("Summarized", 'summarized_first')):
    print(f"{caption}: {stats[key]}")

## 4. Generate Predictions

In [None]:
# Score VAL and TEST, keeping the probabilities (for threshold tuning / export)
# and binarizing at the default 0.5 cut-off.
print("Generating predictions on validation set...")
y_proba_val = hybrid_classifier.predict_proba(
    val_texts_processed, preprocess=False, batch_size=16
)
y_pred_val = np.asarray(y_proba_val >= 0.5, dtype=int)

print("Generating predictions on test set...")
y_proba_test = hybrid_classifier.predict_proba(
    test_texts_processed, preprocess=False, batch_size=16
)
y_pred_test = np.asarray(y_proba_test >= 0.5, dtype=int)

print(f"\nVal predictions shape: {y_pred_val.shape}")
print(f"Test predictions shape: {y_pred_test.shape}")
print(f"Positive predictions per sample (test): {y_pred_test.sum(axis=1).mean():.2f} avg")

In [None]:
# Persist TEST-set outputs (0.5-thresholded labels + raw probabilities) for ensemble analysis
lf_pred_file = Path('../outputs/legal_longformer_test_predictions.npz')
np.savez_compressed(lf_pred_file, y_pred=y_pred_test, y_proba=y_proba_test)
print("✓ Legal-Longformer predictions saved to outputs/legal_longformer_test_predictions.npz")

## 5. Overall Evaluation

In [None]:
# Held-out TEST evaluation at the default 0.5 cut-off
evaluator = MultiLabelEvaluator(data.label_names)
results = evaluator.evaluate(data.y_test, y_pred_test)
print("TEST SET EVALUATION (threshold=0.5):", results.summary(), sep="\n")

## 6. Per-Class Performance

In [None]:
# TEST-set classes with the highest F1 (threshold=0.5)
print("TOP 10 CLASSES (by F1) - TEST SET:")
results.get_top_classes(10, 'f1').loc[:, ['label', 'precision', 'recall', 'f1', 'support']]

In [None]:
# TEST-set classes with the lowest F1 (threshold=0.5)
print("BOTTOM 10 CLASSES (by F1):")
results.get_bottom_classes(10, 'f1').loc[:, ['label', 'precision', 'recall', 'f1', 'support']]

In [None]:
# All classes on TEST (threshold=0.5), ordered from best to worst F1
per_class = (
    results.per_class_metrics
    .copy()
    .sort_values('f1', ascending=False)
)
per_class

## 7. F1 Distribution by Class

In [None]:
# Per-class TEST F1 (threshold=0.5): distribution against the human-agreement band, plus a
# per-class bar chart (per_class is already sorted by F1) coloured by automation tier.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
hist_ax, bar_ax = axes
f1_values = per_class['f1']
median_f1 = float(f1_values.median())

hist_ax.hist(f1_values, bins=20, edgecolor='white', alpha=0.7)
for kappa, colour, caption in ((0.63, 'orange', 'Human κ low (0.63)'),
                               (0.93, 'green', 'Human κ high (0.93)')):
    hist_ax.axvline(x=kappa, color=colour, linestyle='--', label=caption)
hist_ax.axvline(x=median_f1, color='red', linestyle='-', label=f'Median: {median_f1:.2f}')
hist_ax.set(xlabel='F1 Score', ylabel='Number of Classes',
            title='F1 Score Distribution Across Classes')
hist_ax.legend()

# Tier colours: green >= 0.63, orange >= 0.5, red otherwise (NaN falls through to red)
colors = np.select(
    [f1_values >= 0.63, f1_values >= 0.5], ['green', 'orange'], default='red'
).tolist()
bar_ax.bar(range(len(per_class)), f1_values, color=colors, alpha=0.7)
for kappa, colour in ((0.63, 'orange'), (0.93, 'green')):
    bar_ax.axhline(y=kappa, color=colour, linestyle='--', alpha=0.7)
bar_ax.set(xlabel='Class Index (sorted by F1)', ylabel='F1 Score',
           title='F1 by Class (green=automatable, orange=review, red=human)')

plt.tight_layout()
plt.show()

## 8. Threshold Analysis

In [None]:
# Sweep global thresholds on VALIDATION only (TEST stays untouched) and record the scores
from sklearn.metrics import f1_score, precision_score, recall_score

thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
micro_kwargs = {'average': 'micro', 'zero_division': 0}
thresh_results = []

for thresh in thresholds:
    y_pred_t = (y_proba_val >= thresh).astype(int)
    row = {'threshold': thresh}
    row['f1_micro'] = f1_score(data.y_val, y_pred_t, **micro_kwargs)
    row['f1_macro'] = f1_score(data.y_val, y_pred_t, average='macro', zero_division=0)
    row['precision'] = precision_score(data.y_val, y_pred_t, **micro_kwargs)
    row['recall'] = recall_score(data.y_val, y_pred_t, **micro_kwargs)
    row['avg_preds'] = y_pred_t.sum(axis=1).mean()
    thresh_results.append(row)

thresh_df = pd.DataFrame(thresh_results)
print("Global threshold analysis (on VALIDATION set):")
thresh_df

In [None]:
# Plot the validation sweep and mark the threshold with the highest F1 Micro
best_row = thresh_df.loc[thresh_df['f1_micro'].idxmax()]
best_thresh = best_row['threshold']

fig, ax = plt.subplots(figsize=(10, 5))
curve_specs = [
    ('f1_micro', 'o-', 'F1 Micro', {'linewidth': 2}),
    ('f1_macro', 's-', 'F1 Macro', {'linewidth': 2}),
    ('precision', '^--', 'Precision', {'alpha': 0.7}),
    ('recall', 'v--', 'Recall', {'alpha': 0.7}),
]
for column, fmt, caption, extra in curve_specs:
    ax.plot(thresh_df['threshold'], thresh_df[column], fmt, label=caption, **extra)

ax.axvline(x=float(best_thresh), color='red', linestyle=':', label=f'Best threshold: {best_thresh}') # type: ignore

ax.set(xlabel='Threshold', ylabel='Score', title='Threshold Analysis')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nBest threshold: {best_thresh} → F1 Micro: {best_row['f1_micro']:.4f}")

In [None]:
# Evaluate TEST with the global threshold that maximized F1 Micro on VALIDATION.
# The value is read from the validation sweep (`thresh_df`) instead of being hard-coded,
# so it stays correct if the model, data or sweep grid change.
OPTIMAL_THRESHOLD = float(thresh_df.loc[thresh_df['f1_micro'].idxmax(), 'threshold'])

y_pred_optimal = (y_proba_test >= OPTIMAL_THRESHOLD).astype(int)
results_optimal = evaluator.evaluate(data.y_test, y_pred_optimal)

print(f"=== Results with Optimal Threshold ({OPTIMAL_THRESHOLD}) ===")
print(results_optimal.summary())

# Differences are absolute F1 changes expressed in percentage points (pp), not relative %.
print("\nImprovement over default (0.5):")
print(f"  F1 Micro: {results.f1_micro:.4f} → {results_optimal.f1_micro:.4f} ({(results_optimal.f1_micro - results.f1_micro)*100:+.1f} pp)")
print(f"  F1 Macro: {results.f1_macro:.4f} → {results_optimal.f1_macro:.4f} ({(results_optimal.f1_macro - results.f1_macro)*100:+.1f} pp)")

In [None]:
# TEST-set best and worst classes when binarizing at OPTIMAL_THRESHOLD
per_class_optimal = results_optimal.per_class_metrics
shown_cols = ['label', 'precision', 'recall', 'f1', 'support']

print(f"TOP 10 CLASSES (threshold={OPTIMAL_THRESHOLD}):")
display(per_class_optimal.nlargest(10, 'f1').loc[:, shown_cols])

print(f"\nBOTTOM 10 CLASSES (threshold={OPTIMAL_THRESHOLD}):")
display(per_class_optimal.nsmallest(10, 'f1').loc[:, shown_cols])

## 9. Prediction Distribution Analysis

In [None]:
# TEST-set diagnostics at the default cut-off: spread of predicted probabilities
# (log-scaled counts) and number of labels predicted per document.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
proba_ax, count_ax = axes

proba_ax.hist(y_proba_test.ravel(), bins=50, edgecolor='white', alpha=0.7)
proba_ax.axvline(x=0.5, color='red', linestyle='--', label='Threshold (0.5)')
proba_ax.set(xlabel='Predicted Probability', ylabel='Frequency',
             title='Overall Probability Distribution (Test Set)')
proba_ax.legend()
proba_ax.set_yscale('log')

preds_per_sample = y_pred_test.sum(axis=1)
max_preds = int(preds_per_sample.max())
mean_preds = float(preds_per_sample.mean())
count_ax.hist(preds_per_sample, bins=range(0, max_preds + 2), edgecolor='white', alpha=0.7)
count_ax.axvline(x=mean_preds, color='red', linestyle='--', label=f'Mean: {mean_preds:.2f}')
count_ax.set(xlabel='Number of Predictions per Sample', ylabel='Frequency',
             title='Predictions per Sample Distribution')
count_ax.legend()

plt.tight_layout()
plt.show()

no_prediction = preds_per_sample == 0
print(f"Samples with 0 predictions: {no_prediction.sum()} ({no_prediction.mean()*100:.1f}%)")

## 10. Feasibility Analysis

In [None]:
# Bucket classes by TEST F1 (threshold=0.5) relative to the human κ band
feasibility = results.get_feasibility_analysis(human_kappa_low=0.63, human_kappa_high=0.93)
n_classes = len(feasibility)
feasible_mask = feasibility['automation_feasible']
review_mask = feasibility['needs_review']

print("=== AUTOMATION FEASIBILITY ===")
print(f"Fully automatable (F1 ≥ 0.63):  {feasible_mask.sum()} / {n_classes}")
print(f"High confidence (F1 ≥ 0.93):    {feasibility['high_confidence'].sum()} / {n_classes}")
print(f"Needs review (0.50 ≤ F1 < 0.63): {review_mask.sum()} / {n_classes}")
print(f"Not feasible (F1 < 0.50):       {(~feasible_mask & ~review_mask).sum()} / {n_classes}")

In [None]:
# Classes clearing the lower human-agreement bound, best first
print("\nCLASSES FEASIBLE FOR AUTOMATION:")
(
    feasibility.loc[feasibility['automation_feasible'], ['label', 'f1', 'support']]
    .sort_values('f1', ascending=False)
)

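## 11. Per-Class Threshold Optimization

Find the F1-maximizing threshold for each class independently on the **validation** set, then apply those thresholds to the **test** set so the final comparison stays unbiased.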

In [None]:
# Find the F1-maximizing threshold for each class on the VALIDATION set.
# The thresholds are applied to the TEST set in a later cell for an unbiased final evaluation.
from sklearn.metrics import f1_score

# np.arange with a float step accumulates rounding error (0.15000000000000002,
# 0.5000000000000001, ...). Rounding keeps the grid exact, so reported thresholds are
# clean and the grid really contains the 0.5 default used as the baseline below.
thresholds_to_try = np.round(np.arange(0.1, 0.95, 0.05), 2)
optimal_thresholds = []


def _class_f1_or_none(y_true, y_pred):
    """Return F1 for one class, or None when the prediction vector is degenerate.

    All-negative and all-positive predictions are treated as uninformative, so the
    search cannot settle on a trivial "predict everything" threshold. The same rule
    is applied to the 0.5 baseline so both numbers are comparable.
    """
    n_positive = int(y_pred.sum())
    if n_positive == 0 or n_positive == len(y_pred):
        return None
    return f1_score(y_true, y_pred, zero_division=0)


for class_idx, label in enumerate(data.label_names):
    # Use the VALIDATION set for optimization
    y_true_class = data.y_val[:, class_idx]
    y_proba_class = y_proba_val[:, class_idx]

    # Baseline F1 at the default threshold (0 when the predictions are degenerate)
    f1_at_05 = _class_f1_or_none(y_true_class, (y_proba_class >= 0.5).astype(int))
    if f1_at_05 is None:
        f1_at_05 = 0

    # Class-specific names so this loop does not overwrite the global `best_thresh`
    # from the global-threshold sweep. Starting at F1 = 0 keeps the 0.5 default for a
    # class that never scores above 0; strict `>` keeps the lowest threshold among ties.
    class_best_f1, class_best_thresh = 0, 0.5
    for t in thresholds_to_try:
        f1 = _class_f1_or_none(y_true_class, (y_proba_class >= t).astype(int))
        if f1 is not None and f1 > class_best_f1:
            class_best_f1, class_best_thresh = f1, float(t)

    optimal_thresholds.append({
        'class_idx': class_idx,
        'label': label,
        'optimal_threshold': class_best_thresh,
        'f1_at_0.5': f1_at_05,
        'f1_at_optimal': class_best_f1,
        'improvement': class_best_f1 - f1_at_05,
        'support': int(y_true_class.sum()),
    })

threshold_df = pd.DataFrame(optimal_thresholds)
# 0.5 is on the grid and both numbers share the degenerate rule, so tuning can never lose F1.
assert (threshold_df['improvement'] >= -1e-12).all(), "per-class search underperformed the 0.5 baseline"

print(f"Per-class threshold optimization (on VALIDATION set):")
print(f"  Threshold range: {threshold_df['optimal_threshold'].min():.2f} - {threshold_df['optimal_threshold'].max():.2f}")
print(f"  Mean threshold:  {threshold_df['optimal_threshold'].mean():.2f}")

In [None]:
# Classes whose VALIDATION F1 gained the most from a tuned per-class threshold
improvement_cols = ['label', 'optimal_threshold', 'f1_at_0.5', 'f1_at_optimal', 'improvement', 'support']
print("TOP 10 CLASSES BY IMPROVEMENT:")
display(threshold_df.nlargest(10, 'improvement').loc[:, improvement_cols])

# Summary statistics of the chosen thresholds
print("\nThreshold distribution:")
print(threshold_df['optimal_threshold'].describe())

In [None]:
# Final comparison on the TEST set. Every threshold below was chosen on VALIDATION:
#   - LF (0.5):      default cut-off
#   - LF (<global>): best global threshold from the validation sweep (`thresh_df`)
#   - LF Per-Class:  per-class thresholds from `threshold_df`
# TF-IDF predictions come from NB 02 and are scored with the same evaluator.

# Fail fast if the TF-IDF baseline has not been produced yet.
tfidf_pred_path = Path('../outputs/tfidf_test_predictions.npz')
if not tfidf_pred_path.exists():
    raise FileNotFoundError(f"Run NB 02 first to generate {tfidf_pred_path}")

# np.load on an .npz returns an NpzFile holding an open file; the context manager closes it.
with np.load(tfidf_pred_path) as tfidf_data:
    y_pred_tfidf = tfidf_data['y_pred']
# Guard against predictions produced on a different split or label set.
assert y_pred_tfidf.shape == data.y_test.shape, (
    f"TF-IDF predictions {y_pred_tfidf.shape} do not match y_test {data.y_test.shape}"
)

# Per-class thresholds ordered by class index, so column j is compared to threshold j.
per_class_thresholds = threshold_df.sort_values('class_idx')['optimal_threshold'].to_numpy()
assert len(per_class_thresholds) == y_proba_test.shape[1], "expected one threshold per label"
# Broadcasting compares each probability column with its own threshold.
y_pred_test_perclass = (y_proba_test >= per_class_thresholds).astype(int)

# Global threshold = argmax of validation F1 Micro (not hard-coded, so it follows the sweep).
OPTIMAL_THRESHOLD = float(thresh_df.loc[thresh_df['f1_micro'].idxmax(), 'threshold'])
y_pred_test_optimal = (y_proba_test >= OPTIMAL_THRESHOLD).astype(int)

# Evaluate all configurations on TEST.
results_test_05 = evaluator.evaluate(data.y_test, y_pred_test)
# Name kept for compatibility; it holds the tuned global threshold (not necessarily 0.6).
results_test_06 = evaluator.evaluate(data.y_test, y_pred_test_optimal)
results_test_perclass = evaluator.evaluate(data.y_test, y_pred_test_perclass)

results_tfidf = evaluator.evaluate(data.y_test, y_pred_tfidf)
TFIDF_F1_MICRO = results_tfidf.f1_micro
TFIDF_F1_MACRO = results_tfidf.f1_macro
TFIDF_PRECISION = results_tfidf.precision_micro
TFIDF_RECALL = results_tfidf.recall_micro

lf_global_col = f"LF ({OPTIMAL_THRESHOLD:g})"
metric_rows = [
    ('F1 Micro', TFIDF_F1_MICRO, 'f1_micro'),
    ('F1 Macro', TFIDF_F1_MACRO, 'f1_macro'),
    ('Precision', TFIDF_PRECISION, 'precision_micro'),
    ('Recall', TFIDF_RECALL, 'recall_micro'),
]

print("=" * 80)
print("FINAL EVALUATION ON TEST SET (thresholds optimized on validation)")
print("=" * 80)
print(f"{'Metric':<15} {'TF-IDF':<12} {'LF (0.5)':<12} {lf_global_col:<12} {'LF Per-Class':<12}")
print("-" * 80)
for metric_name, tfidf_value, attr in metric_rows:
    lf_values = [getattr(r, attr) for r in (results_test_05, results_test_06, results_test_perclass)]
    print(f"{metric_name:<15} {tfidf_value:<12.4f} " + " ".join(f"{v:<12.4f}" for v in lf_values))
print("=" * 80)
# Gaps are absolute F1 differences in percentage points (pp), not relative percent.
print(f"\nGap to TF-IDF (F1 Micro): {(results_test_perclass.f1_micro - TFIDF_F1_MICRO)*100:+.1f} pp")
print(f"Gap to TF-IDF (F1 Macro): {(results_test_perclass.f1_macro - TFIDF_F1_MACRO)*100:+.1f} pp")

## 12. Summary

In [None]:
# Consolidated summary. Validation figures come from the training history and the
# threshold sweep; test figures use the default 0.5 threshold.
print("=" * 60)
print("MODEL EVALUATION SUMMARY")
print("=" * 60)
if history:
    best_epoch_idx = history['best_epoch'] - 1  # per-epoch lists are 0-indexed
    print(f"\nModel: Legal-Longformer (epoch {history['best_epoch']})")
    # Read micro/macro from the per-epoch lists, as the checkpoint-loading cell does, so each
    # value matches its label. `best_f1` is not guaranteed to be the micro score.
    print(f"Validation F1 Micro: {history['val_f1_micro'][best_epoch_idx]:.4f}")
    print(f"Validation F1 Macro: {history['val_f1_macro'][best_epoch_idx]:.4f}")

print("\nTest Set Performance (threshold=0.5):")
print(f"  F1 Micro:      {results.f1_micro:.4f}")
print(f"  F1 Macro:      {results.f1_macro:.4f}")
print(f"  F1 Weighted:   {results.f1_weighted:.4f}")
print(f"  Precision:     {results.precision_micro:.4f}")
print(f"  Recall:        {results.recall_micro:.4f}")

# The threshold sweep ran on VALIDATION; label it so it is not mistaken for a test number.
print("\nOptimal Threshold Analysis (validation sweep):")
best_idx = thresh_df['f1_micro'].idxmax()
print(f"  Best threshold:  {thresh_df.loc[best_idx, 'threshold']}")
print(f"  Val F1 Micro:    {thresh_df.loc[best_idx, 'f1_micro']:.4f}")

print("\nAutomation Feasibility (vs human κ=0.63-0.93, test, threshold=0.5):")
print(f"  High confidence (F1≥0.93): {feasibility['high_confidence'].sum()} / {len(feasibility)} classes")
print(f"  Automatable (F1≥0.63):     {feasibility['automation_feasible'].sum()} / {len(feasibility)} classes")
print("=" * 60)

In [None]:
# Export TEST predictions through the project helper (file name differs from the
# np.savez_compressed export earlier in this notebook).
# NOTE(review): both exports presumably hold the same arrays — confirm which file the
# ensemble notebook reads and drop the other.
from src.model_evaluator import save_predictions

pred_path = Path('../outputs/longformer_predictions.npz')
save_predictions(y_pred_test, y_proba_test, str(pred_path))
print("Saved Legal-Longformer predictions to", pred_path)