# Model Evaluation
## Endangered Species Classifier Performance Analysis

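This notebook evaluates the trained multi-task model on the test set: it reports overall and per-class conservation-class metrics, plots the confusion matrix and class distribution, measures per-region geographic prediction accuracy, and saves the results.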

In [None]:
# Import required libraries
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import json

from src.multi_task_model import load_multi_task_model
from src.data_loader import create_dataloaders, load_species_data
from src.evaluate import *
from config.model_config import CONSERVATION_CLASSES, MODEL_PATHS

## Load Trained Model

In [None]:
# Load the trained model onto the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model_path = MODEL_PATHS['multi_task_model']
print(f"\nLoading model from: {model_path}")

try:
    model = load_multi_task_model(model_path)
    model = model.to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure the model has been trained and saved.")
    # Re-raise after printing the hint: every later cell needs `model`, so
    # swallowing the error here would only resurface as a confusing NameError.
    raise
else:
    # Put dropout / batch-norm layers into inference behaviour before scoring.
    # NOTE(review): redundant but harmless if the loader or evaluate_model
    # already switches the model to eval mode -- confirm and drop if so.
    model.eval()
    print("✓ Model loaded successfully!")

## Load Test Data

In [None]:
# Load the dataset and keep only the third loader returned by
# create_dataloaders, which this notebook uses as the test split.
print("Loading test data...")

try:
    dataset = load_species_data()
    _, _, test_loader = create_dataloaders(dataset)
except Exception as e:
    print(f"Error loading data: {e}")
    # Re-raise: the evaluation cell cannot run without `test_loader`, and
    # continuing would hide the real cause behind a later NameError.
    raise
else:
    print(f"Test batches: {len(test_loader)}")

## Evaluate Model

In [None]:
# Run evaluation
# `predictions` is indexed by key in the cells below ('conservation_labels',
# 'conservation_preds', 'geographic_preds', 'geographic_labels').
print("Evaluating model on test set...")

# Requires `model`, `test_loader` and `device` from the earlier cells.
predictions = evaluate_model(model, test_loader, device)
print("✓ Evaluation complete!")

## Calculate Metrics

In [None]:
# Compute the aggregate metrics once, then print each as a fraction and a percentage.
metrics = calculate_metrics(predictions)

banner = "=" * 50
print(banner)
print("MODEL PERFORMANCE METRICS")
print(banner)

# (label, key) pairs in display order; the first label carries the leading
# blank line that separates the header from the figures.
headline_metrics = (
    ("\nOverall Accuracy", 'accuracy'),
    ("Precision", 'precision'),
    ("Recall", 'recall'),
    ("F1-Score", 'f1'),
)
for label, key in headline_metrics:
    print(f"{label}: {metrics[key]:.4f} ({metrics[key]*100:.2f}%)")

## Per-Class Performance

In [None]:
# Print a fixed-width table with one row per conservation class.
table_rule = "-" * 70
print("\nPer-Class Performance:")
print(table_rule)
print(f"{'Class':<5} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print(table_rule)

for class_name in CONSERVATION_CLASSES:
    # Classes without an entry in the per-class metrics are left out of the table.
    if class_name not in metrics['per_class_metrics']:
        continue

    scores = metrics['per_class_metrics'][class_name]
    cells = (
        f"{class_name:<5}",
        f"{scores['precision']:<12.4f}",
        f"{scores['recall']:<12.4f}",
        f"{scores['f1-score']:<12.4f}",
        f"{scores['support']:<10}",
    )
    print(" ".join(cells))

## Confusion Matrix

In [None]:
# Generate and plot the confusion matrix for the conservation-class predictions.
from pathlib import Path

confusion_matrix_path = '../visualizations/confusion_matrix.png'

# Make sure the output directory exists before plotting, so a fresh checkout
# does not fail at save time. No-op when the directory is already there.
# NOTE(review): redundant if plot_confusion_matrix creates the directory itself.
Path(confusion_matrix_path).parent.mkdir(parents=True, exist_ok=True)

plot_confusion_matrix(
    predictions['conservation_labels'],
    predictions['conservation_preds'],
    CONSERVATION_CLASSES,
    save_path=confusion_matrix_path
)

print("Confusion matrix generated and saved!")

## Class Distribution

In [None]:
# Plot the class distribution of the test-set conservation labels.
from pathlib import Path

class_distribution_path = '../visualizations/class_distribution.png'

# Make sure the output directory exists before plotting, so a fresh checkout
# does not fail at save time. No-op when the directory is already there.
# NOTE(review): redundant if plot_class_distribution creates the directory itself.
Path(class_distribution_path).parent.mkdir(parents=True, exist_ok=True)

plot_class_distribution(
    predictions['conservation_labels'],
    save_path=class_distribution_path
)

print("Class distribution plot saved!")

## Geographic Region Performance

In [None]:
# Report accuracy separately for each geographic region.
from sklearn.metrics import multilabel_confusion_matrix
from config.model_config import GEOGRAPHIC_REGIONS

# Threshold the region scores at 0.5 to get 0/1 predictions.
# NOTE(review): assumes 'geographic_preds' holds probabilities rather than raw
# logits, and that both arrays are (samples, regions) -- confirm against
# evaluate_model.
region_scores = predictions['geographic_preds']
geographic_preds_binary = (region_scores > 0.5).astype(int)
geographic_labels_binary = predictions['geographic_labels'].astype(int)

print("Geographic Region Prediction Performance:")
print("-" * 50)

# Column `column` of both arrays is scored against the region at that position.
for column, region in enumerate(GEOGRAPHIC_REGIONS):
    region_accuracy = accuracy_score(
        geographic_labels_binary[:, column],
        geographic_preds_binary[:, column],
    )
    print(f"{region:<20} Accuracy: {region_accuracy:.4f}")

## Save Results

In [None]:
# Persist the metrics and predictions, then list what this notebook has written.
save_results(metrics, predictions, output_dir='../results/')

print("\nAll evaluation results saved!")
saved_artifacts = (
    "Model metrics (JSON)",
    "Classification report (TXT)",
    "Predictions (CSV)",
    "Visualizations (PNG)",
)
for artifact in saved_artifacts:
    print(f"- {artifact}")

## Summary

In [None]:
# Closing recap: sample count, headline scores and output locations.
summary_rule = "=" * 50
print("\n" + summary_rule)
print("EVALUATION SUMMARY")
print(summary_rule)

summary_sample_count = len(predictions['conservation_labels'])
print(f"\n✓ Model evaluated on {summary_sample_count} test samples")

summary_accuracy_pct = metrics['accuracy'] * 100
print(f"✓ Overall accuracy: {summary_accuracy_pct:.2f}%")

summary_f1_pct = metrics['f1'] * 100
print(f"✓ Weighted F1-score: {summary_f1_pct:.2f}%")

# Fixed output locations used by the earlier cells; no interpolation needed.
print("✓ Results saved to: ../results/")
print("✓ Visualizations saved to: ../visualizations/")
print("\nEvaluation complete!")