# 04 - Evaluate Model

This notebook evaluates the trained DFG classifier on the held-out test split, covering:
- Detailed performance metrics
- Confusion matrices
- Per-class metrics
- Error analysis
- Sample predictions


In [1]:
# Import libraries
import os
import sys
import json
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add src to path
sys.path.insert(0, str(Path().absolute().parent / 'src'))

from data_processor import DFGDataset, load_config, load_dfg_mapping
from model import DFGClassifier

# Set style: seaborn "whitegrid" theme plus a larger default matplotlib figure size
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (14, 8)})

print("‚úì Libraries imported")


  from .autonotebook import tqdm as notebook_tqdm


‚úì Libraries imported


In [None]:
# Configuration — all paths are relative to the notebook's working directory.
# NOTE(review): DATA_PATH / MODEL_PATH sit under ../dfg-classifier/ while CONFIG_PATH and
# DFG_MAPPING_PATH point directly at ../ — confirm both layouts are intended.
DATA_PATH = '../dfg-classifier/data/processed'
MODEL_PATH = '../dfg-classifier/models/experiments/trained_model.pt'  # Change to your model path
CONFIG_PATH = '../config.yaml'
DFG_MAPPING_PATH = '../data/dfg_mapping.json'

# Device: use the GPU whenever torch can see one
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

# Load configuration and the DFG class mapping.
# NOTE(review): `config` is not read by any later cell of this notebook; the model
# hyperparameters are hard-coded in the model-loading cell instead.
config = load_config(CONFIG_PATH)
dfg_mapping = load_dfg_mapping(DFG_MAPPING_PATH)
print("‚úì Configuration loaded")


In [None]:
# Load test dataset
def load_processed_dataset(data_path: str, split: str = 'test'):
    """Read ``<data_path>/<split>.json`` and return its records as a list of dicts.

    Every JSON record must provide ``input_ids``, ``attention_mask`` and ``labels``
    (converted to tensors with ``torch.tensor``) plus ``filename``, ``title``,
    ``abstract`` and ``label`` (passed through unchanged). Any other keys are dropped.

    Raises:
        FileNotFoundError: if the split file does not exist.
    """
    file_path = os.path.join(data_path, f'{split}.json')
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Dataset file not found: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    tensor_fields = ('input_ids', 'attention_mask', 'labels')
    passthrough_fields = ('filename', 'title', 'abstract', 'label')
    return [
        {
            **{field: torch.tensor(record[field]) for field in tensor_fields},
            **{field: record[field] for field in passthrough_fields},
        }
        for record in records
    ]

print("üìä Loading test dataset...")
test_dataset = load_processed_dataset(DATA_PATH, 'test')
print(f"‚úì Test set: {len(test_dataset)} samples")

# Create data loader
def create_data_loader(dataset, batch_size=4):
    from data_processor import DFGDataset
    torch_dataset = DFGDataset(dataset)
    return DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )

test_loader = create_data_loader(test_dataset, batch_size=4)
print(f"‚úì Test batches: {len(test_loader)}")


In [None]:
# Load model
print("üß† Loading model...")

# One output unit per entry in the level-2 DFG class mapping
num_classes = len(dfg_mapping['level_2']['classes'])

# Initialize model architecture.
# NOTE(review): these hyperparameters are hard-coded rather than read from `config`;
# presumably they must match the training run, otherwise load_state_dict (strict by
# default) will reject the checkpoint — confirm against the training notebook.
model = DFGClassifier(
    model_name='allenai/scibert_scivocab_uncased',
    num_classes=num_classes,
    dropout_rate=0.3,
    freeze_bert=False,
    hierarchical=False,
    dfg_mapping=dfg_mapping
)

# Load trained weights.
# weights_only=True (torch >= 1.13) restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, so a tampered checkpoint cannot execute arbitrary code.
if os.path.exists(MODEL_PATH):
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"‚úì Model loaded from: {MODEL_PATH}")
else:
    # Fallback for smoke-testing only: every metric computed below is meaningless
    # with untrained weights.
    print(f"‚ö†Ô∏è  Model file not found: {MODEL_PATH}")
    print("   Initializing with random weights for demonstration")

model.to(DEVICE)
model.eval()  # disable dropout for deterministic inference
print("‚úì Model ready for evaluation")


## Model Evaluation

Evaluate the model on the test set and collect predictions.


In [None]:
# Evaluate model: a single no-grad pass over the test loader that collects predicted ids,
# true ids, softmax probabilities and per-sample metadata for the analysis cells below.
print("üîç Evaluating model...")

all_predictions = []
all_labels = []
all_probs = []
sample_data = []  # one dict per test sample, in loader order

model.eval()
with torch.no_grad():
    for batch in test_loader:
        labels_on_device = batch['labels'].to(DEVICE)
        outputs = model(
            batch['input_ids'].to(DEVICE),
            batch['attention_mask'].to(DEVICE),
            labels_on_device,
        )
        probs = torch.softmax(outputs['logits'], dim=-1)

        # Move this batch's results to host memory once rather than once per sample
        batch_preds = outputs['predictions'].cpu().numpy()
        batch_labels = labels_on_device.cpu().numpy()
        batch_probs = probs.cpu().numpy()

        all_predictions.extend(batch_preds)
        all_labels.extend(batch_labels)
        all_probs.extend(batch_probs)

        # Confidence = probability assigned to the predicted class
        for row in range(len(batch_labels)):
            pred_id = batch_preds[row]
            sample_data.append({
                'filename': batch['filename'][row],
                'title': batch['title'][row],
                'abstract': batch['abstract'][row],
                'true_label': batch['label'][row],
                'predicted_label_id': int(pred_id),
                'true_label_id': int(batch_labels[row]),
                'confidence': float(batch_probs[row][pred_id])
            })

all_predictions, all_labels, all_probs = (
    np.array(collected) for collected in (all_predictions, all_labels, all_probs)
)

print(f"‚úì Evaluation complete: {len(all_predictions)} samples")


In [None]:
# Overall metrics. Without `labels=`, sklearn reports one entry per label id that occurs in
# either the ground truth or the predictions (sorted), so the per-class arrays follow that order.
accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_predictions, average=None, zero_division=0
)

# Macro = plain mean over those classes; weighted = mean weighted by true-sample support
macro_precision, macro_recall, macro_f1 = (scores.mean() for scores in (precision, recall, f1))
weighted_precision, weighted_recall, weighted_f1 = (
    np.average(scores, weights=support) for scores in (precision, recall, f1)
)

print("üìä Overall Performance Metrics:")
print("=" * 80)
print(f"Accuracy: {accuracy:.4f}")
for heading, (avg_p, avg_r, avg_f) in (
    ("\nMacro Averages:", (macro_precision, macro_recall, macro_f1)),
    ("\nWeighted Averages:", (weighted_precision, weighted_recall, weighted_f1)),
):
    print(heading)
    print(f"  Precision: {avg_p:.4f}")
    print(f"  Recall: {avg_r:.4f}")
    print(f"  F1-Score: {avg_f:.4f}")
print("=" * 80)


In [None]:
# Create confusion matrix.
# First build display names indexed by class id: class_names[i] is the name of label id i.
level2_classes = dfg_mapping['level_2']['classes']
id_to_label = {v: k for k, v in level2_classes.items()}
if level2_classes and all(isinstance(v, (int, np.integer)) for v in level2_classes.values()):
    # Mapping looks like {name: id}: invert it so ids map back to names.
    class_names = [str(id_to_label.get(i, f'Class {i}')) for i in range(len(level2_classes))]
else:
    # Otherwise treat the values as display names ordered by class id (insertion order) —
    # the same positional convention the error-analysis cell uses.
    # NOTE(review): confirm this matches how data_processor assigns label ids.
    class_names = [str(name) for name in level2_classes.values()]

# Give every observed id its own row/column, even one beyond the mapping's range
max_observed_id = int(max(all_labels.max(initial=-1), all_predictions.max(initial=-1)))
class_names += [f'Class {i}' for i in range(len(class_names), max_observed_id + 1)]

# Pin the label order to 0..n-1 so row/column i of `cm` is class id i. Without `labels=`,
# sklearn keeps only ids that occur, which shifts every row after an absent class.
cm = confusion_matrix(all_labels, all_predictions, labels=np.arange(len(class_names)))

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(16, 14))

# Row-normalize (each row sums to 1); classes with no true samples stay all-zero
# instead of going through 0/0 -> NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0
)

# Plot
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt='.2f',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
    cbar_kws={'label': 'Normalized Frequency'}
)

ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
ax.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.show()

print("‚ÑπÔ∏è  Values represent the normalized frequency (0-1) of predictions")


## Per-Class Metrics

Analyze performance for each individual class.


In [None]:
# Per-class metrics.
# `precision`, `recall`, `f1` and `support` come from the overall-metrics cell, which calls
# sklearn without `labels=`. Their entries therefore follow the sorted label ids that occur
# in the ground truth or the predictions — not positions 0..n-1. Rebuild that id list so each
# row is named after its actual class id.
metric_label_ids = np.unique(np.concatenate([all_labels, all_predictions]))
assert len(metric_label_ids) == len(precision), \
    "per-class metric arrays do not line up with the observed label ids"

per_class_metrics = [
    {
        # Fall back to "Class <id>" if class_names does not cover this id
        'Class': str(class_names[label_id]) if label_id < len(class_names) else f'Class {label_id}',
        'Precision': precision[i],
        'Recall': recall[i],
        'F1-Score': f1[i],
        'Support': int(support[i])
    }
    for i, label_id in enumerate(metric_label_ids)
]

# Create DataFrame (explicit columns so an empty list still yields the expected schema)
metrics_df = pd.DataFrame(
    per_class_metrics, columns=['Class', 'Precision', 'Recall', 'F1-Score', 'Support']
)
metrics_df = metrics_df.sort_values('F1-Score', ascending=False)

# Display table
print("üìà Per-Class Performance Metrics:")
print("=" * 100)
print(metrics_df.to_string(index=False))
print("=" * 100)

# Visualize per-class F1 scores (ascending, so the best class is drawn at the top)
fig, ax = plt.subplots(figsize=(14, 8))
sorted_metrics = metrics_df.sort_values('F1-Score')

colors = ['coral' if f < 0.5 else 'steelblue' for f in sorted_metrics['F1-Score']]
ax.barh(range(len(sorted_metrics)), sorted_metrics['F1-Score'], color=colors, alpha=0.7)
ax.set_yticks(range(len(sorted_metrics)))
ax.set_yticklabels(sorted_metrics['Class'], fontsize=9)
ax.set_xlabel('F1-Score', fontsize=12)
ax.set_title('Per-Class F1 Scores', fontsize=14, fontweight='bold')
ax.axvline(x=0.5, color='red', linestyle='--', alpha=0.5, label='F1=0.5 threshold')
ax.legend()
ax.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

# Summary statistics
print(f"\nüìä Summary Statistics:")
print(f"  Classes with F1 > 0.8: {(metrics_df['F1-Score'] > 0.8).sum()}")
print(f"  Classes with F1 > 0.5: {(metrics_df['F1-Score'] > 0.5).sum()}")
print(f"  Classes with F1 <= 0.5: {(metrics_df['F1-Score'] <= 0.5).sum()}")
print(f"  Mean F1-Score: {metrics_df['F1-Score'].mean():.4f}")
print(f"  Std F1-Score: {metrics_df['F1-Score'].std():.4f}")


## Error Analysis

Identify common errors and analyze misclassified samples.


In [None]:
# Find misclassified samples.
# Predicted ids are named by position in the level-2 mapping: id k -> k-th value of
# dfg_mapping['level_2']['classes'] (insertion order), or "Class k" when out of range.
# NOTE(review): 'true_label' is the raw 'label' field from the dataset and may not use the
# same representation (code vs. display name) as the mapped predicted names — confirm.
level2_class_values = list(dfg_mapping['level_2']['classes'].values())  # built once, not per error

def _predicted_display_name(label_id):
    """Display name for a predicted class id, falling back to "Class <id>"."""
    if 0 <= label_id < len(level2_class_values):
        return level2_class_values[label_id]
    return f"Class {label_id}"

errors = [
    {**sample, 'predicted_label': _predicted_display_name(sample['predicted_label_id'])}
    for sample in sample_data
    if sample['true_label_id'] != sample['predicted_label_id']
]

# An empty test set would otherwise divide by zero
error_rate_pct = len(errors) / len(sample_data) * 100 if sample_data else 0.0

print(f"üîç Error Analysis:")
print(f"  Total errors: {len(errors)} out of {len(sample_data)} samples")
print(f"  Error rate: {error_rate_pct:.2f}%")

# Most common error types
if errors:
    error_pairs = [(e['true_label'], e['predicted_label']) for e in errors]
    error_counts = pd.Series(error_pairs).value_counts().head(10)
    
    print(f"\nüî¥ Top 10 Most Common Errors:")
    for (true_label, pred_label), count in error_counts.items():
        print(f"  {true_label} ‚Üí {pred_label}: {count} times")
    
    # Visualize error distribution
    fig, ax = plt.subplots(figsize=(12, 6))
    error_counts.plot(kind='barh', ax=ax, color='coral')
    ax.set_xlabel('Frequency', fontsize=12)
    ax.set_ylabel('Error Type (True ‚Üí Predicted)', fontsize=12)
    ax.set_title('Most Common Classification Errors', fontsize=14, fontweight='bold')
    ax.invert_yaxis()  # most frequent error on top
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()


In [None]:
# Show some example errors
if errors:
    print("\nüìù Example Misclassified Samples:")
    print("=" * 100)

    # The five mistakes the model was most confident about are usually the most informative
    most_confident_errors = sorted(errors, key=lambda err: err['confidence'], reverse=True)[:5]

    for rank, err in enumerate(most_confident_errors, start=1):
        print(f"\n[Error Example {rank}]")
        print(f"  True Label: {err['true_label']}")
        print(f"  Predicted Label: {err['predicted_label']}")
        print(f"  Confidence: {err['confidence']:.4f}")
        # Truncate long text fields, marking the cut with "..."
        title_text = err['title']
        shown_title = title_text[:100] + "..." if len(title_text) > 100 else title_text
        print(f"  Title: {shown_title}")
        abstract_text = err['abstract']
        shown_abstract = abstract_text[:200] + "..." if len(abstract_text) > 200 else abstract_text
        print(f"  Abstract: {shown_abstract}")
        print("-" * 100)


## Sample Predictions

Show some correctly classified examples with their confidence scores.


In [None]:
# Show correct predictions with high confidence
correct = list(filter(lambda s: s['true_label_id'] == s['predicted_label_id'], sample_data))
# Stable descending sort: ties keep their loader order
correct_by_confidence = sorted(correct, key=lambda s: s['confidence'], reverse=True)
high_confidence_correct = correct_by_confidence[:5]

if high_confidence_correct:
    print("‚úÖ Example Correct Predictions (High Confidence):")
    print("=" * 100)

    for rank, hit in enumerate(high_confidence_correct, start=1):
        print(f"\n[Correct Example {rank}]")
        print(f"  Label: {hit['true_label']}")
        print(f"  Confidence: {hit['confidence']:.4f}")
        # Truncate long text fields, marking the cut with "..."
        title_text = hit['title']
        shown_title = title_text[:100] + "..." if len(title_text) > 100 else title_text
        print(f"  Title: {shown_title}")
        abstract_text = hit['abstract']
        shown_abstract = abstract_text[:200] + "..." if len(abstract_text) > 200 else abstract_text
        print(f"  Abstract: {shown_abstract}")
        print("-" * 100)


In [None]:
# Confidence distribution.
# `confidence` is the softmax probability of the predicted class, so it lies in [0, 1].
# Shared, fixed bin edges make the overlaid correct/error histograms comparable;
# separate `bins=30` calls would give each series its own bin widths and offsets.
confidence_bin_edges = np.linspace(0.0, 1.0, 31)  # 30 equal-width bins over [0, 1]

confidences = [s['confidence'] for s in sample_data]
correct_confidences = [s['confidence'] for s in correct]
error_confidences = [s['confidence'] for s in errors] if errors else []

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Overall confidence distribution
axes[0].hist(confidences, bins=confidence_bin_edges, color='steelblue', alpha=0.7, edgecolor='black')
axes[0].set_xlabel('Prediction Confidence', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Overall Confidence Distribution', fontsize=14, fontweight='bold')
axes[0].grid(alpha=0.3)
if confidences:
    # Skip the mean marker (and its legend) when there is nothing to average
    mean_confidence = np.mean(confidences)
    axes[0].axvline(x=mean_confidence, color='red', linestyle='--',
                    label=f'Mean: {mean_confidence:.3f}')
    axes[0].legend()

# Confidence by correctness
if errors:
    axes[1].hist(correct_confidences, bins=confidence_bin_edges, alpha=0.6, label='Correct',
                 color='green', edgecolor='black')
    axes[1].hist(error_confidences, bins=confidence_bin_edges, alpha=0.6, label='Errors',
                 color='red', edgecolor='black')
    axes[1].set_xlabel('Prediction Confidence', fontsize=12)
    axes[1].set_ylabel('Frequency', fontsize=12)
    axes[1].set_title('Confidence Distribution: Correct vs Errors', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(alpha=0.3)

    # No correct predictions at all -> np.mean([]) would be NaN with a RuntimeWarning
    mean_correct_text = (
        f"{np.mean(correct_confidences):.4f}" if correct_confidences else "n/a"
    )
    print(f"\nüìä Confidence Statistics:")
    print(f"  Mean confidence (all): {np.mean(confidences):.4f}")
    print(f"  Mean confidence (correct): {mean_correct_text}")
    print(f"  Mean confidence (errors): {np.mean(error_confidences):.4f}")
else:
    axes[1].text(0.5, 0.5, 'No errors found!', 
                ha='center', va='center', transform=axes[1].transAxes, fontsize=14)

plt.tight_layout()
plt.show()


## Detailed Classification Report

Generate a detailed classification report with all metrics.


In [None]:
# Detailed classification report.
# Report exactly the label ids present in the ground truth or the predictions (sklearn's
# default set) and pass them explicitly, so `target_names` has one name per reported id.
# Slicing class_names by the number of *true* labels raised ValueError whenever the model
# predicted an id absent from y_true, and it paired names with ids by position.
report_label_ids = np.unique(np.concatenate([all_labels, all_predictions]))
report_target_names = [
    str(class_names[label_id]) if label_id < len(class_names) else f'Class {label_id}'
    for label_id in report_label_ids
]

print("üìã Detailed Classification Report:")
print("=" * 100)
print(classification_report(
    all_labels,
    all_predictions,
    labels=report_label_ids,
    target_names=report_target_names,
    digits=4,
    zero_division=0
))
print("=" * 100)


In [None]:
# Save evaluation results (overall metrics, per-class table, error summary) as JSON.
results = {
    'overall_metrics': {
        'accuracy': float(accuracy),
        'macro_precision': float(macro_precision),
        'macro_recall': float(macro_recall),
        'macro_f1': float(macro_f1),
        'weighted_precision': float(weighted_precision),
        'weighted_recall': float(weighted_recall),
        'weighted_f1': float(weighted_f1)
    },
    'per_class_metrics': metrics_df.to_dict('records'),
    'error_count': len(errors),
    'error_rate': len(errors)/len(sample_data) if sample_data else 0,
    'total_samples': len(sample_data),
    # Provenance: the model-loading cell falls back to random weights when MODEL_PATH is
    # missing, so record which checkpoint (if any) produced these numbers.
    'model_path': MODEL_PATH,
    'model_weights_found': os.path.exists(MODEL_PATH)
}

results_path = '../dfg-classifier/models/experiments/evaluation_results.json'
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print(f"‚úì Evaluation results saved to: {results_path}")
print("\n‚úÖ Evaluation complete!")
