# Model Evaluation

Comprehensive evaluation of trained ice classification models:
- Load trained model
- Evaluate on test set
- Generate confusion matrix
- Calculate per-class metrics
- Error analysis (misclassified samples)
- Prediction confidence distribution

In [None]:
import sys
sys.path.append('../training')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm.notebook import tqdm

# Import dataset and model
from train_ice_classifier import IceDataset, IceClassifier

print("Imports successful!")

## 1. Load Trained Model

In [None]:
# Configuration: checkpoint location, dataset root, and compute device.
MODEL_PATH = '../models/ice_classifier_resnet50.pth'
DATA_DIR = '../data/processed'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU only when one is visible

# Rebuild the architecture without pretrained weights; the weights are
# restored from the checkpoint's 'model_state_dict' entry below.
model = IceClassifier(num_classes=3, pretrained=False)

# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a
# GPU can still be opened on a CPU-only machine.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)
model.eval()  # switch to evaluation mode before running inference

# Report the training metadata stored alongside the weights.
print(f"✅ Model loaded from: {MODEL_PATH}")
print(f"   Epoch: {checkpoint['epoch']}")
print(f"   Val Loss: {checkpoint['val_loss']:.4f}")
print(f"   Val Acc: {checkpoint['val_acc']:.2f}%")

## 2. Load Test Dataset

In [None]:
# Build the 'test' split and wrap it in a batched loader.
# shuffle=False keeps the loader order equal to the dataset order, so a
# position in the collected predictions maps to the same index in test_dataset.
test_dataset = IceDataset(DATA_DIR, split='test')
loader_options = dict(batch_size=16, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, **loader_options)

# Quick sanity check on the split size and the number of batches it yields.
num_test_samples = len(test_dataset)
num_test_batches = len(test_loader)
print(f"Test samples: {num_test_samples}")
print(f"Test batches: {num_test_batches}")

## 3. Evaluate on Test Set

In [None]:
# Accumulators filled batch by batch, converted to NumPy arrays afterwards.
all_predictions = []    # predicted class index per sample
all_labels = []         # ground-truth class index per sample
all_probabilities = []  # softmax probability vector per sample

print("Evaluating model on test set...")

# Re-assert evaluation mode so this cell is safe to re-run on its own,
# even if the model was put back into training mode in between.
model.eval()

# Gradients are not needed for inference; disabling them saves memory and time.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(DEVICE)

        outputs = model(images)
        # Softmax over dim=1 turns the per-sample outputs into class probabilities.
        probabilities = torch.softmax(outputs, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

        all_predictions.extend(predictions.cpu().numpy())
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the
        # loader ever yields labels that already live on the GPU.
        all_labels.extend(labels.cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

# Overall accuracy, reported as a percentage.
accuracy = accuracy_score(all_labels, all_predictions) * 100
print(f"\n✅ Test Accuracy: {accuracy:.2f}%")

## 4. Confusion Matrix

In [None]:
# Human-readable class names; list position i is the name of class index i.
class_names = ['Open Water', 'Thin Ice', 'Thick Ice']

# Pin the label set explicitly. Without `labels=`, scikit-learn builds the
# matrix only from classes that actually occur in the labels or predictions,
# so a class missing from the test set would shrink the matrix and misalign
# it with the tick labels below.
cm = confusion_matrix(all_labels, all_predictions,
                      labels=list(range(len(class_names))))

# Plot the confusion matrix: rows are true classes, columns are predicted classes.
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix', fontsize=16, fontweight='bold', pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
# Save before show(): the figure is written to disk while it is still active.
plt.savefig('../models/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nConfusion Matrix:")
print(cm)

## 5. Per-Class Metrics

In [None]:
# Classification report (precision / recall / F1 / support per class).
# `labels=` pins the full class set: without it, classification_report raises
# a ValueError when the number of classes present in the data differs from
# len(target_names). zero_division=0 reports 0 for undefined metrics instead
# of emitting a warning.
report = classification_report(all_labels, all_predictions,
                               labels=list(range(len(class_names))),
                               target_names=class_names, digits=3,
                               zero_division=0)
print("\nClassification Report:")
print("="*60)
print(report)

# Per-class accuracy: the fraction of samples whose true class is i that were
# also predicted as class i.
print("\nPer-Class Accuracy:")
for i, class_name in enumerate(class_names):
    class_mask = all_labels == i
    if class_mask.sum() > 0:
        class_acc = (all_predictions[class_mask] == all_labels[class_mask]).mean() * 100
        print(f"  {class_name}: {class_acc:.2f}%")
    else:
        # Say so explicitly rather than silently omitting the class.
        print(f"  {class_name}: n/a (no test samples)")

## 6. Error Analysis

In [None]:
# Boolean mask of misclassified samples and their positions in the test set.
errors = all_predictions != all_labels
error_indices = np.where(errors)[0]

print(f"Total errors: {len(error_indices)} / {len(all_labels)} ({len(error_indices)/len(all_labels)*100:.1f}%)")

# Visualize up to six misclassified samples in a 2x3 grid.
if len(error_indices) > 0:
    num_errors_to_show = min(6, len(error_indices))
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, idx in zip(axes, error_indices[:num_errors_to_show]):
        # Positions in the prediction arrays match dataset indices because the
        # test loader was created with shuffle=False.
        # int() hands the dataset a plain Python index instead of a NumPy scalar.
        image, _ = test_dataset[int(idx)]
        # NOTE(review): assumes the dataset returns a CPU tensor in (C, H, W)
        # layout with values imshow can display directly; if IceDataset
        # normalizes its images, de-normalize here first. Confirm.
        img_display = image.permute(1, 2, 0).numpy()

        ax.imshow(img_display)
        ax.set_title(f"True: {class_names[all_labels[idx]]}\nPred: {class_names[all_predictions[idx]]}\n({all_probabilities[idx, all_predictions[idx]]*100:.1f}% conf)",
                     fontsize=10)
        ax.axis('off')

    # With fewer than six errors the remaining panels would render as empty
    # framed axes; hide them.
    for ax in axes[num_errors_to_show:]:
        ax.axis('off')

    plt.suptitle('Misclassified Samples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("🎉 Perfect accuracy! No errors to show.")

## 7. Confidence Distribution

In [None]:
# Confidence is the highest class probability for each sample, i.e. the
# probability assigned to the predicted class.
confidences = np.max(all_probabilities, axis=1)
correct_confidences = confidences[~errors]
incorrect_confidences = confidences[errors]

# Two panels: the overall distribution, and correct vs incorrect predictions.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Overall confidence distribution, with the mean marked.
mean_confidence = np.mean(confidences)
ax1.hist(confidences, bins=20, alpha=0.7, color='blue', edgecolor='black')
ax1.set_xlabel('Confidence', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Prediction Confidence Distribution', fontsize=12, fontweight='bold')
ax1.axvline(mean_confidence, color='red', linestyle='--',
            label=f'Mean: {mean_confidence:.3f}')
ax1.legend()
ax1.grid(alpha=0.3)

# Correct vs incorrect: only meaningful when at least one error exists.
if len(incorrect_confidences) > 0:
    ax2.hist([correct_confidences, incorrect_confidences], bins=15,
             label=['Correct', 'Incorrect'], color=['green', 'red'],
             alpha=0.6, edgecolor='black')
    ax2.set_xlabel('Confidence', fontsize=11)
    ax2.set_ylabel('Count', fontsize=11)
    ax2.set_title('Confidence: Correct vs Incorrect', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(alpha=0.3)
else:
    ax2.text(0.5, 0.5, 'No errors!', ha='center', va='center', fontsize=16)
    ax2.set_xlim(0, 1)

plt.tight_layout()
plt.show()

# np.mean of an empty array is nan and emits a RuntimeWarning, so each
# average is printed only when its group is non-empty.
if len(correct_confidences) > 0:
    print(f"\nAverage confidence (correct): {np.mean(correct_confidences):.3f}")
if len(incorrect_confidences) > 0:
    print(f"Average confidence (incorrect): {np.mean(incorrect_confidences):.3f}")

## Summary

Model evaluation complete!

**Key Metrics**:
- Test Accuracy: See output above
- Confusion Matrix: Visualized above
- Per-class performance: Check classification report

**Next Steps**:
- If accuracy is low, try:
  - More training data
  - Longer training
  - Data augmentation
- If ready, deploy model to production
- Continue to `04_predictions_visualization.ipynb`