# ECONTRAIL Detection - Model Evaluation

This notebook provides tools for evaluating contrail detection model performance.

## Setup

Install the package in editable mode from the repository root before running this notebook:
```bash
pip install -e .
```

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Import ECONTRAIL detection package
from econtrail_detection import predict_contrails
from econtrail_detection.utils import (
    load_image,
    preprocess_image,
    calculate_metrics,
    save_prediction
)

print("Imports successful!")

## 1. Load Test Images

Locate the PNG and JPG test images in `test_images/` and display the first one as a sanity check. Ground-truth masks are loaded later, in Section 4.

In [None]:
# Input locations, relative to the notebook's working directory.
test_images_dir = Path('test_images')
data_dir = Path('data')  # Section 4 reads ground-truth masks from data_dir / 'ground_truth'

# Collect PNG and JPG test images. Path.glob yields files in filesystem-dependent
# order, so sort the result: later cells rely on positional correspondence between
# `test_images` and `predictions`, and re-runs should process images in the same order.
IMAGE_SUFFIXES = ('.png', '.jpg')
test_images = sorted(
    path
    for suffix in IMAGE_SUFFIXES
    for path in test_images_dir.glob(f'*{suffix}')
)
print(f"Found {len(test_images)} test images")

# Display the first test image (if any) as a quick sanity check.
if test_images:
    sample_image = load_image(test_images[0])
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(sample_image)
    ax.set_title(f"Sample Test Image: {test_images[0].name}")
    ax.axis('off')
    plt.show()
else:
    print("No test images found. Add images to the 'test_images' directory.")

## 2. Run Predictions

Run the contrail detection model on test images.

In [None]:
# Inference settings -- change these to use a custom checkpoint or a GPU.
MODEL_PATH = None  # None selects the default model; set a path for a custom model
THRESHOLD = 0.5
DEVICE = 'cpu'  # or 'cuda' if GPU is available

# Results are appended in the same order as `test_images`, so predictions[k]
# belongs to test_images[k].
predictions = []

if not test_images:
    print("No test images available for prediction.")
else:
    print(f"Running predictions on {len(test_images)} images...")

    for img_path in test_images:
        pred = predict_contrails(
            str(img_path),
            device=DEVICE,
            threshold=THRESHOLD,
            model_path=MODEL_PATH,
        )
        predictions.append(pred)
        print(f"  Processed: {img_path.name}")

    print("Predictions complete!")

## 3. Visualize Predictions

Visualize the predictions alongside original images.

In [None]:
# Show up to N_VIS_SAMPLES test images side by side with their predictions.
N_VIS_SAMPLES = 3

# Bound by both lists: if the prediction cell was interrupted part-way,
# `predictions` is shorter than `test_images` and indexing it would fail.
n_samples = min(N_VIS_SAMPLES, len(test_images), len(predictions))

if n_samples > 0:
    # squeeze=False always returns a 2-D array of Axes, even for a single row,
    # so axes[i, col] works without a manual reshape.
    fig, axes = plt.subplots(n_samples, 2, figsize=(12, 4 * n_samples), squeeze=False)

    for i in range(n_samples):
        # Left column: the original image, reloaded from disk.
        image = load_image(test_images[i])
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"Original: {test_images[i].name}")
        axes[i, 0].axis('off')

        # Right column: the output of predict_contrails for the same image
        # (presumably a 2-D mask / score map -- TODO confirm return type).
        axes[i, 1].imshow(predictions[i], cmap='jet', alpha=0.7)
        axes[i, 1].set_title("Contrail Prediction")
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("No predictions to visualize.")

## 4. Evaluate Performance (if ground truth available)

Calculate per-image metrics (accuracy, precision, recall, F1-score, IoU) and their averages when ground-truth masks are present in `data/ground_truth/`.

In [None]:
# Load ground-truth masks (if available) and score each prediction against its mask.
ground_truth_dir = data_dir / 'ground_truth'


def _pair_with_ground_truth(image_paths, mask_paths):
    """Return ``(image_index, mask_path)`` pairs to evaluate.

    Masks are matched to images by file stem (``scene01.jpg`` <-> ``scene01.png``),
    so the pairing does not depend on the order ``Path.glob`` happens to return.
    If no stems match but both lists have the same length, fall back to pairing
    them positionally in sorted-filename order. Returns an empty list when no
    pairing is possible.
    """
    masks_by_stem = {mask.stem: mask for mask in mask_paths}
    pairs = [
        (idx, masks_by_stem[path.stem])
        for idx, path in enumerate(image_paths)
        if path.stem in masks_by_stem
    ]
    if not pairs and mask_paths and len(mask_paths) == len(image_paths):
        image_order = sorted(range(len(image_paths)), key=lambda idx: image_paths[idx].name)
        pairs = list(zip(image_order, sorted(mask_paths, key=lambda mask: mask.name)))
    return pairs


if ground_truth_dir.exists():
    gt_masks = list(ground_truth_dir.glob('*.png')) + list(ground_truth_dir.glob('*.jpg'))

    # Only images that actually received a prediction can be evaluated;
    # predictions[k] corresponds to test_images[k].
    n_predicted = min(len(test_images), len(predictions))
    pairs = _pair_with_ground_truth(test_images[:n_predicted], gt_masks)

    if pairs:
        print(f"Calculating metrics for {len(pairs)} image(s)...\n")

        all_metrics = []
        for i, gt_path in pairs:
            gt = load_image(gt_path)
            if len(gt.shape) == 3:
                gt = gt[:, :, 0]  # Use first channel if multi-channel

            # NOTE(review): the mask is passed exactly as loaded; if load_image returns
            # 0/255 values, confirm calculate_metrics binarizes it.
            metrics = calculate_metrics(predictions[i], gt)
            all_metrics.append(metrics)

            print(f"Image {i+1}: {test_images[i].name} (ground truth: {gt_path.name})")
            print(f"  Accuracy:  {metrics['accuracy']:.3f}")
            print(f"  Precision: {metrics['precision']:.3f}")
            print(f"  Recall:    {metrics['recall']:.3f}")
            print(f"  F1-Score:  {metrics['f1_score']:.3f}")
            print(f"  IoU:       {metrics['iou']:.3f}")
            print()

        # Average each metric over all evaluated images.
        avg_metrics = {
            key: np.mean([m[key] for m in all_metrics])
            for key in all_metrics[0].keys()
        }

        print("="*50)
        print("AVERAGE METRICS")
        print("="*50)
        for key, value in avg_metrics.items():
            print(f"{key.capitalize():12s}: {value:.3f}")
    elif not gt_masks:
        print(f"Ground truth directory contains no .png/.jpg masks: {ground_truth_dir}")
    else:
        print(
            f"Ground truth available but could not be paired with predictions: "
            f"{len(gt_masks)} GT vs {n_predicted} predictions "
            f"(name each mask after its test image, e.g. scene01.png for scene01.jpg)"
        )
else:
    print(f"No ground truth directory found at: {ground_truth_dir}")
    print("Create 'data/ground_truth' directory and add ground truth masks to evaluate performance.")

## 5. Save Predictions

Save prediction masks to disk for further analysis.

In [None]:
# Write each prediction to output/predictions/pred_<image stem>.png.
output_dir = Path('output/predictions')
output_dir.mkdir(parents=True, exist_ok=True)

if predictions:
    print(f"Saving predictions to {output_dir}...")

    used_names = set()
    for pred, img_path in zip(predictions, test_images):
        output_name = f"pred_{img_path.stem}.png"
        if output_name in used_names:
            # Inputs such as scene.png and scene.jpg share a stem; without this the
            # second prediction would silently overwrite the first file.
            output_name = f"pred_{img_path.stem}_{img_path.suffix.lstrip('.')}.png"
        used_names.add(output_name)

        output_path = output_dir / output_name
        # NOTE(review): colormap=True presumably writes a colour-mapped rendering rather
        # than raw mask values -- confirm in save_prediction before using these files
        # for quantitative analysis.
        save_prediction(pred, output_path, colormap=True)
        print(f"  Saved: {output_path.name}")

    print("All predictions saved!")
else:
    print("No predictions to save.")

## Summary

This notebook demonstrated:
1. Loading test images
2. Running contrail detection predictions
3. Visualizing results
4. Evaluating performance (with ground truth)
5. Saving predictions

For more information, see the [research paper](https://doi.org/10.1109/TGRS.2025.3629628).