# Downstream Evaluation: Land Cover Classification from Satellite Data

Land cover classification from satellite imagery is a useful downstream task for evaluating image restoration models: if restoration distorts spectral or spatial content, a classifier's accuracy on the restored images should drop. It is a good testbed for restoration quality because:

1. **Real-world Impact**: Satellite-based land cover classification directly supports environmental monitoring, urban planning, and climate research applications
2. **Multi-spectral Complexity**: Sentinel-2 satellite data contains 13 spectral bands, making it sensitive to restoration artifacts across different wavelengths
3. **Fine-grained Classification**: With 19 distinct land cover classes, the task requires preservation of subtle spectral signatures that could be lost during restoration
4. **Scale Sensitivity**: Classification performance depends on both local texture details and broader spatial patterns, testing restoration at multiple scales

The notebook below loads a pre-trained land cover classifier, runs it on a folder of (restored) images, and scores its predictions against ground-truth labels, so that classification metrics can be compared across restoration methods.

## Dataset and Model Information

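This evaluation uses Sentinel-2 satellite images from BigEarthNet, whose 19-class nomenclature is derived from the CORINE Land Cover (CLC) database. BigEarthNet is a multi-label dataset (one patch can carry several labels); this notebook evaluates a single-label simplification, comparing the top-1 softmax prediction for each image with one ground-truth class per image.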

**Class Labels:**
1. Urban fabric
2. Industrial or commercial units
3. Arable land
4. Permanent crops
5. Pastures
6. Complex cultivation patterns
7. Land principally occupied by agriculture, with significant areas of natural vegetation
8. Agro-forestry areas
9. Broad-leaved forest
10. Coniferous forest
11. Mixed forest
12. Natural grassland and sparsely vegetated areas
13. Moors, heathland and sclerophyllous vegetation
14. Transitional woodland, shrub
15. Beaches, dunes, sands
16. Inland wetlands
17. Coastal wetlands
18. Inland waters
19. Marine waters

### Required Files and Dependencies

**Model Weights:**
- Pre-trained models available on [Hugging Face Hub](https://huggingface.co/models?search=bigearthnet)
- Official BigEarthNet models: [BigEarthNet website](http://bigearth.net/)
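- Checkpoint formats loaded by this notebook: `.safetensors`, `.pth`, `.pt` (TensorFlow checkpoints are not loaded directly and would first need converting to a PyTorch state dict)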

**Dataset Access:**
- BigEarthNet dataset: Available through official channels or processed versions on Hugging Face
- Requires registration for full dataset access

**Key Citations:**
- Sumbul, G., et al. (2019). "BigEarthNet: A Large-Scale Benchmark Archive for Remote Sensing Image Understanding." IGARSS 2019.
- Sumbul, G., et al. (2021). "BigEarthNet-MM: A Large-Scale, Multimodal, Multilabel Benchmark Archive for Remote Sensing Image Classification and Retrieval." IEEE Geoscience and Remote Sensing Magazine.

## Installation and Setup

This section installs the dependencies of the satellite image evaluation pipeline: PyTorch and torchvision (models and transforms), safetensors (checkpoint loading), Pillow (image I/O), NumPy and pandas (numerical work and tables), Matplotlib and Seaborn (visualization), scikit-learn (evaluation metrics), and tqdm (progress bars).

The installation process automatically detects missing packages and installs them, ensuring a complete environment for running land cover classification evaluations on satellite imagery.
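Missing packages can be detected without importing them (importing `torch` just to check for it is slow). A minimal sketch, assuming the pip-name to import-name pairs below (`pillow` imports as `PIL`, `scikit-learn` as `sklearn`); `find_missing_modules` is a hypothetical helper, not part of the notebook:

```python
import importlib.util

# pip package name -> module name used in `import` (they differ for some packages)
PIP_TO_MODULE = {
    "torch": "torch",
    "torchvision": "torchvision",
    "safetensors": "safetensors",
    "pillow": "PIL",
    "numpy": "numpy",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
}


def find_missing_modules(pip_to_module):
    """Return pip names whose module cannot be located; nothing is imported."""
    return [
        pip_name
        for pip_name, module_name in pip_to_module.items()
        if importlib.util.find_spec(module_name) is None
    ]


missing = find_missing_modules(PIP_TO_MODULE)
print("Missing packages:", missing or "none")
```

The returned pip names can be passed straight to `pip install`.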

## Batch Evaluation Framework

This section provides the core functionality for running classification over a folder of satellite images, saving the predictions, and computing evaluation metrics, with per-image inference time recorded along the way.

In [None]:
# Install required packages (pip name -> import name; they differ for pillow and scikit-learn)
import importlib
import subprocess
import sys

packages_to_install = {
    'torch': 'torch',
    'torchvision': 'torchvision',
    'safetensors': 'safetensors',
    'pillow': 'PIL',
    'numpy': 'numpy',
    'pandas': 'pandas',
    'matplotlib': 'matplotlib',
    'seaborn': 'seaborn',
    'scikit-learn': 'sklearn',
    'tqdm': 'tqdm'
}

for package, module_name in packages_to_install.items():
    try:
        importlib.import_module(module_name)
        print(f"{package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])

print("\nAll packages installed successfully!")

In [None]:
# Import all required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from safetensors.torch import load_file
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, 
    confusion_matrix, classification_report
)
from sklearn.preprocessing import LabelEncoder
import os
import json
import time
from datetime import datetime
from tqdm.auto import tqdm
from collections import defaultdict, Counter
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("All libraries imported successfully!")

## Class Definitions and Configuration

This section defines the 19 land cover classes from the BigEarthNet dataset, based on the CORINE Land Cover classification system. The classes range from artificial surfaces and agricultural land to forests, semi-natural areas, and water bodies across the European countries covered by BigEarthNet.

`EVAL_CONFIG` sets a 224×224 input size and 10 input channels, matching models trained on ten Sentinel-2 bands; RGB inputs are expanded to 10 channels during preprocessing. `batch_size` and `num_workers` are placeholders for batched data loading: the single-image inference loop in this notebook does not use them.
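For reference, Sentinel-2 records 13 bands, BigEarthNet-S2 patches omit the cirrus band B10, and a common 10-channel choice is the four 10 m bands plus the six 20 m bands. The sketch below selects those 10 bands from a 12-band stack. The input band order (`S2_BANDS_12`), the output order (`BANDS_10`) and the `select_bands` helper are assumptions for illustration; check them against how your data and checkpoint were prepared.

```python
import numpy as np

# Assumed order of the 12 Sentinel-2 bands in a BigEarthNet-S2 patch (B10 absent)
S2_BANDS_12 = ["B01", "B02", "B03", "B04", "B05", "B06",
               "B07", "B08", "B8A", "B09", "B11", "B12"]

# Assumed model input order: 10 m bands (B02, B03, B04, B08)
# plus 20 m bands (B05, B06, B07, B8A, B11, B12)
BANDS_10 = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"]


def select_bands(stack, available=S2_BANDS_12, wanted=BANDS_10):
    """Pick `wanted` bands from a (bands, H, W) array ordered as `available`."""
    if stack.shape[0] != len(available):
        raise ValueError(f"Expected {len(available)} bands, got {stack.shape[0]}")
    idx = [available.index(band) for band in wanted]
    return stack[idx]


# Dummy patch; real 20 m/60 m bands must first be resampled to the 10 m grid
patch = np.random.rand(12, 120, 120).astype(np.float32)
x = select_bands(patch)
print(x.shape)  # (10, 120, 120)
```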

In [None]:
# BigEarthNet 19-class nomenclature (derived from CORINE Land Cover).
# The order must match the output layer of the checkpoint being evaluated.
CLASSES = [
    "Urban fabric",
    "Industrial or commercial units",
    "Arable land",
    "Permanent crops",
    "Pastures",
    "Complex cultivation patterns",
    "Land principally occupied by agriculture, with significant areas of natural vegetation",
    "Agro-forestry areas",
    "Broad-leaved forest",
    "Coniferous forest",
    "Mixed forest",
    "Natural grassland and sparsely vegetated areas",
    "Moors, heathland and sclerophyllous vegetation",
    "Transitional woodland, shrub",
    "Beaches, dunes, sands",
    "Inland wetlands",
    "Coastal wetlands",
    "Inland waters",
    "Marine waters"
]

# Class ID mapping
CLASS_ID_TO_NAME = {i: name for i, name in enumerate(CLASSES)}
CLASS_NAME_TO_ID = {name: i for i, name in enumerate(CLASSES)}

print(f"Number of classes: {len(CLASSES)}")
print("Class definitions loaded")

# Configuration for different evaluation scenarios
EVAL_CONFIG = {
    'input_size': (224, 224),
    'input_channels': 10,  # Multi-spectral channels
    'batch_size': 32,
    'num_workers': 4,
    'confidence_threshold': 0.5,
    'supported_formats': ['.jpg', '.jpeg', '.png', '.tif', '.tiff'],
    'model_architectures': ['resnet50', 'resnet101', 'wide_resnet50_2'],
    'checkpoint_formats': ['.safetensors', '.pth', '.pt']
}

print("Evaluation configuration set")

## Model Loading and Preprocessing Pipeline

This section implements a comprehensive preprocessing pipeline specifically designed for satellite imagery evaluation. The pipeline handles the complexities of multi-spectral satellite data while maintaining compatibility with standard RGB inputs.

**Key Features:**
- **Multi-spectral Support**: Expands 3-channel RGB input to 10 channels by appending fixed linear combinations of the R, G, and B channels as stand-ins for the missing Sentinel-2 bands
- **Flexible Architecture Loading**: Supports multiple CNN architectures (ResNet-50, ResNet-101, Wide ResNet-50) commonly used in remote sensing
- **Checkpoint Compatibility**: Handles various model checkpoint formats including SafeTensors and PyTorch formats
- **Adaptive Input Layers**: Automatically modifies model input layers to accommodate multi-spectral channels

The pseudo-bands (labelled NIR, SWIR, vegetation red edge, and coastal in the code) are heuristic: because they are linear combinations of RGB, they contain no real infrared information. This lets a model trained on 10-band Sentinel-2 data run on standard RGB imagery, but its results on such input are only a rough proxy for its performance on true multi-spectral data.
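Because every pseudo-band in `expand_rgb_to_multispectral` is a fixed linear combination of R, G and B, the whole expansion can be written as one 10×3 mixing matrix. The NumPy sketch below restates those coefficients (the matrix form and the `expand_with_matrix` name are illustrative, not part of the pipeline) and confirms that the 10 output channels span only a 3-dimensional space.

```python
import numpy as np

# Rows: output channels in the same order as expand_rgb_to_multispectral.
# Columns: weights applied to the (R, G, B) input channels.
MIXING = np.array([
    [1.0, 0.0, 0.0],   # R
    [0.0, 1.0, 0.0],   # G
    [0.0, 0.0, 1.0],   # B
    [1.1, 0.0, 0.0],   # "NIR1"
    [0.9, 0.0, 0.0],   # "NIR2"
    [0.5, 0.0, 0.5],   # "SWIR1"
    [0.0, 0.6, 0.6],   # "SWIR2"
    [0.7, 0.7, 0.0],   # "VRE1"
    [0.8, 0.8, 0.0],   # "VRE2"
    [0.0, 0.0, 0.95],  # "Coastal"
])


def expand_with_matrix(rgb):
    """Map a (3, H, W) array to (10, H, W) with the fixed mixing matrix."""
    return np.einsum("oc,chw->ohw", MIXING, rgb)


rgb = np.random.rand(3, 4, 4)
pseudo = expand_with_matrix(rgb)
print(pseudo.shape)                    # (10, 4, 4)
print(np.linalg.matrix_rank(MIXING))   # 3: only three independent channels
```

A rank of 3 is the formal version of the caveat above: the extra seven channels are redundant with the RGB input.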

In [None]:
class SatelliteImageProcessor:
    """
    Advanced preprocessing pipeline for satellite imagery
    Supports multi-spectral (10-channel) and RGB (3-channel) inputs
    """
    
    def __init__(self, input_channels=10, target_size=(224, 224)):
        self.input_channels = input_channels
        self.target_size = target_size
        
        # Standard RGB preprocessing (ImageNet normalization). Multi-spectral input is
        # produced by expanding the normalized RGB tensor, so both paths share one transform.
        self.rgb_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                std=[0.229, 0.224, 0.225]
            )
        ])
        # Alias kept so existing references to multispectral_transform still work
        self.multispectral_transform = self.rgb_transform
    
    def expand_rgb_to_multispectral(self, rgb_tensor):
        """Expand 3-channel RGB to 10-channel multi-spectral"""
        if rgb_tensor.shape[0] != 3:
            raise ValueError(f"Expected 3-channel input, got {rgb_tensor.shape[0]}")
        
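        # Heuristic: derive 7 pseudo-bands from linear combinations of R, G, B.
        # They carry no real NIR/SWIR/red-edge information; they only let a
        # 10-channel model run on RGB input. The channel order below is
        # illustrative and must match the band order the checkpoint expects.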
        channels = []
        
        # Original RGB channels
        channels.extend([rgb_tensor[0], rgb_tensor[1], rgb_tensor[2]])
        
        # Near-infrared approximations (modified red channel)
        channels.append(rgb_tensor[0] * 1.1)  # NIR1
        channels.append(rgb_tensor[0] * 0.9)  # NIR2
        
        # SWIR approximations (modified combinations)
        channels.append((rgb_tensor[0] + rgb_tensor[2]) * 0.5)  # SWIR1
        channels.append((rgb_tensor[1] + rgb_tensor[2]) * 0.6)  # SWIR2
        
        # Vegetation red edge (modified red/green)
        channels.append((rgb_tensor[0] + rgb_tensor[1]) * 0.7)  # VRE1
        channels.append((rgb_tensor[0] + rgb_tensor[1]) * 0.8)  # VRE2
        
        # Coastal aerosol (modified blue)
        channels.append(rgb_tensor[2] * 0.95)  # Coastal
        
        return torch.stack(channels)
    
    def process_image(self, image_path):
        """Process a single image"""
        try:
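            # Load image as RGB. PIL handles standard 8-bit images; multi-band
            # Sentinel-2 GeoTIFFs generally need a dedicated reader (e.g. rasterio).
            image = Image.open(image_path).convert('RGB')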
            
            # Apply basic preprocessing
            tensor = self.rgb_transform(image)
            
            # Expand to multi-spectral if needed
            if self.input_channels == 10:
                tensor = self.expand_rgb_to_multispectral(tensor)
            
            return tensor
            
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            return None

def load_satellite_model(checkpoint_path, architecture='resnet50', num_classes=19, input_channels=10):
    """
    Load a satellite image classification model from various checkpoint formats
    """
    print(f"Loading {architecture} model from {checkpoint_path}")
    
    # Initialize model architecture without ImageNet weights
    # (weights=None replaces the deprecated pretrained=False, torchvision >= 0.13)
    if architecture == 'resnet50':
        model = models.resnet50(weights=None)
    elif architecture == 'resnet101':
        model = models.resnet101(weights=None)
    elif architecture == 'wide_resnet50_2':
        model = models.wide_resnet50_2(weights=None)
    else:
        raise ValueError(f"Unsupported architecture: {architecture}")
    
    # Modify first layer for multi-spectral input
    if input_channels != 3:
        original_conv1 = model.conv1
        model.conv1 = nn.Conv2d(
            input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        
        # Initialize new conv1 weights: copy the original (randomly initialized) RGB
        # filters into the first 3 channels and use small noise for the rest.
        # These values are overwritten if the checkpoint contains conv1.weight.
        with torch.no_grad():
            if input_channels > 3:
                model.conv1.weight[:, :3, :, :] = original_conv1.weight
                model.conv1.weight[:, 3:, :, :] = torch.randn(64, input_channels - 3, 7, 7) * 0.01
    
    # Modify final layer for correct number of classes
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'classifier'):
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    
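    # Load weights
    if not checkpoint_path.endswith(('.safetensors', '.pth', '.pt')):
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

    try:
        if checkpoint_path.endswith('.safetensors'):
            weights = load_file(checkpoint_path)
        else:
            weights = torch.load(checkpoint_path, map_location='cpu')
            if isinstance(weights, dict) and 'state_dict' in weights:
                weights = weights['state_dict']

        # Strip the 'module.' prefix added by nn.DataParallel, if present
        weights = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in weights.items()
        }

        # strict=False tolerates key mismatches, so report them instead of hiding them
        incompatible = model.load_state_dict(weights, strict=False)
        print(f"Loaded weights from {checkpoint_path}")
        if incompatible.missing_keys:
            print(f"Missing keys (left randomly initialized): {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"Unexpected keys (ignored): {incompatible.unexpected_keys}")

    except Exception as e:
        print(f"Warning: Could not load all weights - {e}")
        print("Proceeding with randomly initialized weights for missing parameters")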
    
    model.eval()
    return model.to(device)

# Initialize processor
processor = SatelliteImageProcessor(
    input_channels=EVAL_CONFIG['input_channels'],
    target_size=EVAL_CONFIG['input_size']
)

print("Model loading utilities ready")

## Batch Inference Pipeline

This section implements inference over a folder of satellite images. Each image goes through the preprocessing defined above, is classified, and its predicted class, class probabilities, and timing are recorded.

**Pipeline Components:**
- **Image Discovery**: Collects files with supported extensions, optionally capped by `max_images`
- **Per-image Inference**: Processes images one at a time (effective batch size 1); `EVAL_CONFIG['batch_size']` and `num_workers` are not used by this loop
- **Memory**: Single-image processing under `torch.no_grad()` keeps memory use low
- **Progress Tracking**: A tqdm progress bar plus per-image and total timing
- **Error Handling**: Images that fail to load or run are counted as failures instead of stopping the run
- **Output Management**: Results are returned as a dictionary and optionally written to CSV for metric computation

The batch inference system supports various image formats commonly used in satellite imagery (TIFF, JPEG, PNG) and automatically handles preprocessing, inference, and result aggregation across the entire dataset.
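Since the `run_batch_inference` function in the next cell feeds one image at a time, grouping images into true mini-batches is the natural next step. A minimal sketch of that grouping, written with NumPy so it runs without a GPU; `iter_batches` and `stack_batch` are hypothetical helpers, and in PyTorch the stacked array would correspond to `torch.stack` over the per-image tensors passed to the model in a single forward call.

```python
import numpy as np


def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def stack_batch(arrays):
    """Stack per-image (C, H, W) arrays into one (N, C, H, W) batch."""
    return np.stack(arrays, axis=0)


# Dummy preprocessed images with a small spatial size to keep the example light
images = [np.random.rand(10, 8, 8).astype(np.float32) for _ in range(70)]
batch_shapes = [stack_batch(chunk).shape for chunk in iter_batches(images, batch_size=32)]
print(batch_shapes)  # [(32, 10, 8, 8), (32, 10, 8, 8), (6, 10, 8, 8)]
```

The last batch is simply shorter; no padding is needed because the model output is per-image.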

In [None]:
def save_results_to_csv(results, output_file):
    """Save inference results to CSV format"""
    
    rows = []
    for img_name, prediction in results['predictions'].items():
        if 'error' not in prediction:
            rows.append({
                'image_name': img_name,
                'predicted_class': prediction['predicted_class'],
                'predicted_class_id': prediction['predicted_class_id'],
                'confidence_score': prediction['confidence_score'],
                'inference_time': prediction.get('inference_time', 0)
            })
    
    df = pd.DataFrame(rows)
    df.to_csv(output_file, index=False)

def run_batch_inference(model, input_folder, output_file=None, max_images=None):
    """
    Run batch inference on a folder of satellite images
    
    Args:
        model: Loaded PyTorch model
        input_folder: Path to folder containing images
        output_file: Optional path to save results CSV
        max_images: Optional limit on number of images to process
    
    Returns:
        dict: Comprehensive results dictionary
    """
    
    print(f"Starting batch inference on: {input_folder}")
    
    # Find all image files (sorted so that max_images selects a reproducible subset)
    supported_extensions = tuple(ext.lower() for ext in EVAL_CONFIG['supported_formats'])
    image_files = sorted(
        f for f in os.listdir(input_folder)
        if f.lower().endswith(supported_extensions)
    )
    
    if max_images:
        image_files = image_files[:max_images]
    
    print(f"Found {len(image_files)} images to process")
    
    if len(image_files) == 0:
        print("No supported image files found!")
        return None
    
    # Initialize results tracking
    results = {
        'predictions': {},
        'metadata': {
            'model_name': model.__class__.__name__,
            'num_classes': len(CLASSES),
            'class_names': CLASSES,
            'input_resolution': EVAL_CONFIG['input_size'],
            'input_channels': EVAL_CONFIG['input_channels'],
            'timestamp': datetime.now().isoformat(),
            'input_folder': input_folder
        },
        'performance': {
            'total_images': len(image_files),
            'successful_predictions': 0,
            'failed_predictions': 0,
            'inference_times': [],
            'total_processing_time': 0
        }
    }
    
    start_time = time.time()
    
    # Process images with progress bar
    for img_name in tqdm(image_files, desc="Processing images"):
        img_path = os.path.join(input_folder, img_name)
        
        try:
            # Start timing
            inference_start = time.time()
            
            # Process image
            input_tensor = processor.process_image(img_path)
            if input_tensor is None:
                results['performance']['failed_predictions'] += 1
                continue
                
            # Add batch dimension and move to device
            input_batch = input_tensor.unsqueeze(0).to(device)
            
            # Run inference
            with torch.no_grad():
                raw_output = model(input_batch)
                probabilities = F.softmax(raw_output, dim=1)
                confidence, predicted = torch.max(probabilities, 1)
                
                # Extract results
                pred_idx = predicted.item()
                confidence_score = confidence.item()
                raw_logits = raw_output.squeeze().cpu().numpy().tolist()
                probs = probabilities.squeeze().cpu().numpy().tolist()
            
            # Record timing
            inference_time = time.time() - inference_start
            results['performance']['inference_times'].append(inference_time)
            
            # Store prediction results
            results['predictions'][img_name] = {
                'predicted_class': CLASSES[pred_idx],
                'predicted_class_id': pred_idx,
                'confidence_score': confidence_score,
                'raw_logits': raw_logits,
                'probabilities': probs,
                'inference_time': inference_time
            }
            
            results['performance']['successful_predictions'] += 1
            
        except Exception as e:
            print(f"Error processing {img_name}: {e}")
            results['predictions'][img_name] = {
                'error': str(e),
                'predicted_class': 'ERROR',
                'predicted_class_id': -1,
                'confidence_score': 0.0
            }
            results['performance']['failed_predictions'] += 1
    
    # Calculate final performance metrics
    total_time = time.time() - start_time
    results['performance']['total_processing_time'] = total_time
    results['performance']['average_inference_time'] = (
        np.mean(results['performance']['inference_times']) 
        if results['performance']['inference_times'] else 0
    )
    results['performance']['images_per_second'] = (
        results['performance']['successful_predictions'] / total_time
        if total_time > 0 else 0
    )
    
    # Print summary
    print(f"\n{'='*50}")
    print("BATCH INFERENCE COMPLETE")
    print(f"{'='*50}")
    print(f"Total images: {results['performance']['total_images']}")
    print(f"Successful: {results['performance']['successful_predictions']}")
    print(f"Failed: {results['performance']['failed_predictions']}")
    print(f"Success rate: {results['performance']['successful_predictions']/results['performance']['total_images']*100:.1f}%")
    print(f"Average inference time: {results['performance']['average_inference_time']:.3f}s")
    print(f"Processing speed: {results['performance']['images_per_second']:.1f} images/sec")
    print(f"Total time: {total_time:.1f}s")
    
    # Save results to CSV if requested
    if output_file:
        save_results_to_csv(results, output_file)
        print(f"Results saved to: {output_file}")
    
    return results
    
print("Batch inference pipeline ready")

## Evaluation Metrics and Classification Assessment

This section implements a comprehensive evaluation framework for assessing land cover classification performance. The evaluation system computes multiple complementary metrics to provide a thorough understanding of model performance across different aspects of the classification task.

### Primary Evaluation Metrics

**Accuracy Metrics:**
- **Overall Accuracy**: The fraction of images whose predicted class matches the ground-truth class
- **Top-k Accuracy**: Whether the correct class appears among the k highest-probability predictions (k=3, 5), useful for understanding near-miss predictions
- **Per-class Accuracy**: The fraction of each class's images that are classified correctly; in this single-label setting it equals per-class recall
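Top-k accuracy needs only each image's class probabilities and its true label. A NumPy sketch; the `topk_accuracy` helper is illustrative, while the notebook's own `compute_topk_accuracy` works on the results dictionary instead:

```python
import numpy as np


def topk_accuracy(probs, y_true, k):
    """Fraction of rows whose true label is among the k highest-scoring classes."""
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    # Indices of the k largest probabilities per row (order within the top k is irrelevant)
    topk = np.argsort(probs, axis=1)[:, -k:]
    hits = (topk == y_true[:, None]).any(axis=1)
    return float(hits.mean())


probs = np.array([
    [0.6, 0.3, 0.1],   # true class 0: ranked first
    [0.5, 0.4, 0.1],   # true class 1: ranked second
    [0.2, 0.3, 0.5],   # true class 0: ranked third
])
y_true = np.array([0, 1, 0])
print(topk_accuracy(probs, y_true, k=1))  # 1 of 3 correct
print(topk_accuracy(probs, y_true, k=2))  # 2 of 3 correct
```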

**Precision, Recall, and F1-Score:**
- **Precision**: For each class, the fraction of predicted instances that are actually correct (TP/(TP+FP))
- **Recall (Sensitivity)**: For each class, the fraction of actual instances that are correctly identified (TP/(TP+FN))
- **F1-Score**: Harmonic mean of precision and recall, providing a balanced measure of performance
- **Macro-averaged**: Unweighted mean across all classes, treating each class equally
- **Weighted-averaged**: Mean weighted by class support, accounting for class imbalance
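These definitions can be checked by hand from a confusion matrix whose rows are true classes and columns are predictions (the layout returned by scikit-learn's `confusion_matrix`). The 3-class counts below are made up for illustration:

```python
import numpy as np

# Rows = true class, columns = predicted class (toy counts)
cm = np.array([
    [8, 2, 0],
    [1, 3, 0],
    [1, 0, 5],
])

tp = np.diag(cm).astype(float)
fp = cm.sum(axis=0) - tp      # predicted as class c but actually another class
fn = cm.sum(axis=1) - tp      # actually class c but predicted as another class
support = cm.sum(axis=1)      # number of true samples per class

precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
recall = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
f1 = np.divide(2 * precision * recall, precision + recall,
               out=np.zeros_like(tp), where=(precision + recall) > 0)

macro_f1 = f1.mean()                            # every class counts equally
weighted_f1 = np.average(f1, weights=support)   # classes weighted by support
print(precision, recall, f1)
print(round(macro_f1, 4), round(weighted_f1, 4))
```

Here the small class 1 has the lowest F1, so the macro average (which gives it a full third of the weight) comes out lower than the weighted average.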

**Confusion Matrix Analysis:**
- **Raw Confusion Matrix**: Shows the distribution of predictions vs ground truth for detailed error analysis
- **Normalized Confusion Matrix**: Percentages normalized by true class, highlighting misclassification patterns
- **Class Confusion Patterns**: Identifies which classes are commonly confused with each other
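Normalizing by true class divides each row by its row sum. A small sketch that also keeps classes absent from the ground truth as zero rows instead of 0/0 values; `normalize_by_true_class` is a hypothetical helper, and scikit-learn's `confusion_matrix(y_true, y_pred, normalize='true')` is a built-in alternative:

```python
import numpy as np


def normalize_by_true_class(cm):
    """Row-normalize a confusion matrix (rows = true classes) into percentages.

    Rows with no samples stay all zeros instead of producing NaN.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return 100.0 * np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


cm = np.array([
    [8, 2, 0],
    [1, 3, 0],
    [0, 0, 0],   # class never present in the ground truth
])
print(normalize_by_true_class(cm))
```

The diagonal of the result is the per-class accuracy in percent.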

**Confidence Analysis:**
- **Prediction Confidence**: Distribution of model confidence scores across all predictions
- **Confidence by Correctness**: Comparison of confidence scores between correct and incorrect predictions
- **Confidence Calibration**: Assessment of whether high confidence correlates with correct predictions
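The bullets above do not name a specific calibration measure. One common choice is expected calibration error (ECE): bin predictions by confidence, then average the gap between mean confidence and accuracy in each bin, weighted by the fraction of samples in that bin. A sketch assuming 10 equal-width bins (the bin count and the `expected_calibration_error` name are choices made here, not taken from the notebook):

```python
import numpy as np


def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE with equal-width confidence bins over (0, 1].

    A confidence of exactly 0 falls in no bin; softmax maxima are always > 0.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if not np.any(in_bin):
            continue
        gap = abs(confidences[in_bin].mean() - correct[in_bin].mean())
        ece += in_bin.mean() * gap   # weight by fraction of samples in this bin
    return float(ece)


conf = np.array([0.95, 0.85, 0.65, 0.55])
correct = np.array([1, 1, 1, 0])
print(expected_calibration_error(conf, correct))
```

The inputs map directly onto `y_confidence` and `y_true == y_pred` from `compute_metrics`.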

### Advanced Evaluation Features

**Class Distribution Analysis:**
- Compares the distribution of true labels vs predicted labels to identify prediction biases
- Identifies under-predicted and over-predicted classes
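The comparison reduces to subtracting two label histograms, the same `np.bincount` counts stored in `class_distribution` by `compute_metrics`. A sketch with a hypothetical `prediction_bias` helper on toy labels; indices can be mapped back to names through `CLASSES`:

```python
import numpy as np


def prediction_bias(y_true, y_pred, num_classes):
    """Per-class (predicted count - true count); > 0 means over-predicted."""
    true_counts = np.bincount(y_true, minlength=num_classes)
    pred_counts = np.bincount(y_pred, minlength=num_classes)
    return pred_counts - true_counts


y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 1])
bias = prediction_bias(y_true, y_pred, num_classes=4)
over = np.flatnonzero(bias > 0)    # classes predicted more often than they occur
under = np.flatnonzero(bias < 0)   # classes predicted less often than they occur
print(bias, over, under)
```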

**Error Pattern Analysis:**
- Systematic analysis of common misclassification patterns
- Identifies challenging class pairs and potential systematic errors
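Challenging class pairs are the largest off-diagonal entries of the confusion matrix. A sketch with a hypothetical `top_confused_pairs` helper; the class names are real labels but the counts are toy values:

```python
import numpy as np


def top_confused_pairs(cm, class_names, n=3):
    """Return the n largest off-diagonal entries as (true_name, pred_name, count)."""
    off_diag = np.asarray(cm).copy()
    np.fill_diagonal(off_diag, 0)              # ignore correct predictions
    flat_order = np.argsort(off_diag, axis=None)[::-1]
    pairs = []
    for flat_idx in flat_order[:n]:
        i, j = np.unravel_index(flat_idx, off_diag.shape)
        if off_diag[i, j] == 0:
            break                              # no further confusions
        pairs.append((class_names[i], class_names[j], int(off_diag[i, j])))
    return pairs


names = ["Arable land", "Pastures", "Coniferous forest"]
cm = np.array([
    [50, 9, 1],
    [4, 30, 0],
    [0, 2, 40],
])
print(top_confused_pairs(cm, names, n=3))
```

Pairs are directional: "Arable land predicted as Pastures" and the reverse are reported separately.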

Land cover datasets are typically imbalanced, so reporting macro-averaged scores (each class counts equally) alongside weighted averages and overall accuracy helps reveal whether good overall numbers hide poor performance on rare classes.

In [None]:
class ClassificationEvaluator:
    """
    Comprehensive evaluation metrics for land cover classification
    """
    
    def __init__(self, class_names=None):
        self.class_names = class_names or CLASSES
        self.num_classes = len(self.class_names)
    
    def load_ground_truth(self, ground_truth_file):
        """
        Load ground truth labels from various formats
        
        Supported formats:
        - CSV with columns: image_name, true_class_id or true_class
        - JSON with image_name -> class mapping
        - TXT with format: image_name: class_id (one per line)
        """
        
        if ground_truth_file.endswith('.csv'):
            df = pd.read_csv(ground_truth_file)
            ground_truth = {}
            
            if 'true_class_id' in df.columns:
                for _, row in df.iterrows():
                    ground_truth[row['image_name']] = {
                        'true_class_id': int(row['true_class_id']),
                        'true_class': self.class_names[int(row['true_class_id'])]
                    }
            elif 'true_class' in df.columns:
                for _, row in df.iterrows():
                    class_id = CLASS_NAME_TO_ID.get(row['true_class'], -1)
                    ground_truth[row['image_name']] = {
                        'true_class_id': class_id,
                        'true_class': row['true_class']
                    }
            else:
                raise ValueError("CSV must contain 'true_class_id' or 'true_class' column")
                
        elif ground_truth_file.endswith('.json'):
            with open(ground_truth_file, 'r') as f:
                data = json.load(f)
            ground_truth = {}
            
            for img_name, label_info in data.items():
                if isinstance(label_info, dict):
                    ground_truth[img_name] = label_info
                else:
                    # Assume it's just the class name or ID
                    if isinstance(label_info, str):
                        class_id = CLASS_NAME_TO_ID.get(label_info, -1)
                        ground_truth[img_name] = {
                            'true_class_id': class_id,
                            'true_class': label_info
                        }
                    else:
                        ground_truth[img_name] = {
                            'true_class_id': int(label_info),
                            'true_class': self.class_names[int(label_info)]
                        }
                        
        elif ground_truth_file.endswith('.txt'):
            ground_truth = {}
            
            with open(ground_truth_file, 'r') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:  # Skip empty lines
                        continue
                    
                    try:
                        # Parse format: "image_name: class_id"
                        if ':' not in line:
                            print(f"Warning: Line {line_num} missing colon separator: {line}")
                            continue
                            
                        img_name, class_id_str = line.split(':', 1)
                        img_name = img_name.strip()
                        class_id_str = class_id_str.strip()
                        
                        # Convert class_id to integer
                        class_id = int(class_id_str)
                        
                        # Validate class_id is within valid range
                        if 0 <= class_id < len(self.class_names):
                            ground_truth[img_name] = {
                                'true_class_id': class_id,
                                'true_class': self.class_names[class_id]
                            }
                        else:
                            print(f"Warning: Line {line_num} has invalid class_id {class_id} (must be 0-{len(self.class_names)-1}): {line}")
                            
                    except ValueError as e:
                        print(f"Warning: Line {line_num} has invalid format: {line} - {e}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error parsing line {line_num}: {line} - {e}")
                        continue
            
            print(f"Loaded {len(ground_truth)} ground truth entries from TXT file")
                        
        else:
            raise ValueError("Unsupported ground truth format. Use .csv, .json, or .txt")
        
        return ground_truth
    
    def compute_metrics(self, predictions, ground_truth):
        """
        Compute comprehensive classification metrics
        
        Args:
            predictions: Dictionary from run_batch_inference
            ground_truth: Dictionary with true labels
            
        Returns:
            Dictionary with all computed metrics
        """
        
        # Align predictions and ground truth
        y_true = []
        y_pred = []
        y_confidence = []
        matched_images = []
        
        for img_name in predictions['predictions']:
            if img_name in ground_truth and 'error' not in predictions['predictions'][img_name]:
                true_class_id = ground_truth[img_name]['true_class_id']
                pred_class_id = predictions['predictions'][img_name]['predicted_class_id']
                confidence = predictions['predictions'][img_name]['confidence_score']
                
                if true_class_id >= 0 and pred_class_id >= 0:  # Valid labels
                    y_true.append(true_class_id)
                    y_pred.append(pred_class_id)
                    y_confidence.append(confidence)
                    matched_images.append(img_name)
        
        if len(y_true) == 0:
            raise ValueError("No matching images found between predictions and ground truth")
        
        print(f"Evaluating {len(y_true)} images with ground truth labels")
        
        # Convert to numpy arrays
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_confidence = np.array(y_confidence)
        
        # Compute basic metrics
        accuracy = accuracy_score(y_true, y_pred)
        
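        # Per-class metrics for all classes. labels= keeps array index i == class id i,
        # even when some classes are absent from both y_true and y_pred.
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, labels=list(range(self.num_classes)), average=None, zero_division=0
        )
        
        # Macro averages: tuple (precision, recall, f1, None) over classes present in the data
        macro_precision = precision_recall_fscore_support(
            y_true, y_pred, average='macro', zero_division=0
        )
        
        # Weighted averages: same tuple layout, weighted by class support
        weighted_precision = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )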
        
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=range(self.num_classes))
        
        # Get unique classes present in the data
        unique_classes = sorted(list(set(y_true) | set(y_pred)))
        present_class_names = [self.class_names[i] for i in unique_classes if i < len(self.class_names)]
        
        # Classification report (only for classes present in data)
        class_report = classification_report(
            y_true, y_pred, 
            labels=unique_classes,
            target_names=present_class_names,
            zero_division=0,
            output_dict=True
        )
        
        # Top-k accuracy (top-3, top-5)
        top3_acc = self.compute_topk_accuracy(predictions, ground_truth, k=3)
        top5_acc = self.compute_topk_accuracy(predictions, ground_truth, k=5)
        
        # Class distribution analysis
        class_distribution = {
            'true_distribution': np.bincount(y_true, minlength=self.num_classes),
            'pred_distribution': np.bincount(y_pred, minlength=self.num_classes)
        }
        
        # Confidence analysis
        confidence_stats = {
            'mean_confidence': float(np.mean(y_confidence)),
            'std_confidence': float(np.std(y_confidence)),
            'min_confidence': float(np.min(y_confidence)),
            'max_confidence': float(np.max(y_confidence)),
            'median_confidence': float(np.median(y_confidence))
        }
        
        # Confidence by correctness
        correct_mask = (y_true == y_pred)
        confidence_by_correctness = {
            'correct_predictions': {
                'mean_confidence': float(np.mean(y_confidence[correct_mask])) if np.any(correct_mask) else 0,
                'count': int(np.sum(correct_mask))
            },
            'incorrect_predictions': {
                'mean_confidence': float(np.mean(y_confidence[~correct_mask])) if np.any(~correct_mask) else 0,
                'count': int(np.sum(~correct_mask))
            }
        }
        
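        # Compile all metrics. macro_precision and weighted_precision hold the
        # (precision, recall, f1, support) tuples from precision_recall_fscore_support,
        # so indices 0, 1 and 2 are precision, recall and F1 respectively.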
        metrics = {
            'overall_metrics': {
                'accuracy': float(accuracy),
                'top3_accuracy': float(top3_acc),
                'top5_accuracy': float(top5_acc),
                'macro_precision': float(macro_precision[0]),
                'macro_recall': float(macro_precision[1]),
                'macro_f1': float(macro_precision[2]),
                'weighted_precision': float(weighted_precision[0]),
                'weighted_recall': float(weighted_precision[1]),
                'weighted_f1': float(weighted_precision[2]),
                'num_samples': len(y_true)
            },
            'per_class_metrics': {
                self.class_names[i]: {
                    'precision': float(precision[i]) if i < len(precision) else 0.0,
                    'recall': float(recall[i]) if i < len(recall) else 0.0,
                    'f1_score': float(f1[i]) if i < len(f1) else 0.0,
                    'support': int(support[i]) if i < len(support) else 0
                }
                for i in range(self.num_classes)
            },
            'confusion_matrix': cm.tolist(),
            'classification_report': class_report,
            'class_distribution': {
                'true_distribution': class_distribution['true_distribution'].tolist(),
                'pred_distribution': class_distribution['pred_distribution'].tolist()
            },
            'confidence_analysis': confidence_stats,
            'confidence_by_correctness': confidence_by_correctness,
            'matched_images': matched_images
        }
        
        return metrics
    
    def compute_topk_accuracy(self, predictions, ground_truth, k=3):
        """Compute top-k accuracy"""
        
        correct_topk = 0
        total = 0
        
        for img_name in predictions['predictions']:
            if img_name in ground_truth and 'error' not in predictions['predictions'][img_name]:
                true_class_id = ground_truth[img_name]['true_class_id']
                probs = predictions['predictions'][img_name]['probabilities']
                
                if true_class_id >= 0 and len(probs) >= k:
                    # Get top-k predictions
                    top_k_indices = np.argsort(probs)[-k:]
                    if true_class_id in top_k_indices:
                        correct_topk += 1
                    total += 1
        
        return correct_topk / total if total > 0 else 0
    
    def print_evaluation_summary(self, metrics):
        """Print a comprehensive evaluation summary"""
        
        print(f"\n{'='*60}")
        print("CLASSIFICATION EVALUATION SUMMARY")
        print(f"{'='*60}")
        
        # Overall metrics
        overall = metrics['overall_metrics']
        print(f"\nOVERALL PERFORMANCE:")
        print(f"  • Accuracy: {overall['accuracy']:.4f} ({overall['accuracy']*100:.2f}%)")
        print(f"  • Top-3 Accuracy: {overall['top3_accuracy']:.4f} ({overall['top3_accuracy']*100:.2f}%)")
        print(f"  • Top-5 Accuracy: {overall['top5_accuracy']:.4f} ({overall['top5_accuracy']*100:.2f}%)")
        print(f"  • Macro F1-Score: {overall['macro_f1']:.4f}")
        print(f"  • Weighted F1-Score: {overall['weighted_f1']:.4f}")
        print(f"  • Number of samples: {overall['num_samples']}")
        
        # Confidence analysis
        conf = metrics['confidence_analysis']
        print(f"\nCONFIDENCE ANALYSIS:")
        print(f"  • Mean confidence: {conf['mean_confidence']:.4f}")
        print(f"  • Confidence range: [{conf['min_confidence']:.4f}, {conf['max_confidence']:.4f}]")
        
        conf_correct = metrics['confidence_by_correctness']
        print(f"  • Correct predictions: {conf_correct['correct_predictions']['mean_confidence']:.4f} (n={conf_correct['correct_predictions']['count']})")
        print(f"  • Incorrect predictions: {conf_correct['incorrect_predictions']['mean_confidence']:.4f} (n={conf_correct['incorrect_predictions']['count']})")
        
        # Top and bottom performing classes
        per_class = metrics['per_class_metrics']
        f1_scores = [(name, data['f1_score']) for name, data in per_class.items() if data['support'] > 0]
        f1_scores.sort(key=lambda x: x[1], reverse=True)
        
        print(f"\nTOP 5 PERFORMING CLASSES (by F1-score):")
        for i, (class_name, f1) in enumerate(f1_scores[:5]):
            support = per_class[class_name]['support']
            print(f"  {i+1}. {class_name}: {f1:.4f} (n={support})")
        
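        print("\nBOTTOM 5 PERFORMING CLASSES (by F1-score):")
        bottom = f1_scores[-5:]
        # Ranks continue the descending order used above (1 = best F1)
        start_rank = len(f1_scores) - len(bottom) + 1
        for rank, (class_name, f1) in enumerate(bottom, start=start_rank):
            support = per_class[class_name]['support']
            print(f"  {rank}. {class_name}: {f1:.4f} (n={support})")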
        
        print(f"\n{'='*60}")

# Initialize evaluator
evaluator = ClassificationEvaluator(CLASSES)
print("Classification evaluator ready")
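`compute_topk_accuracy` loops over images one at a time. For each image, it checks whether the true class is among the k highest probabilities. Below is a vectorized NumPy sketch of the same check. It assumes the probabilities are already stacked into an (N, C) array and the labels into an (N,) integer array. The evaluator itself does not use this layout, so `topk_accuracy` is a hypothetical helper:

```python
import numpy as np

def topk_accuracy(probs: np.ndarray, labels: np.ndarray, k: int = 3) -> float:
    """Fraction of labeled samples whose true class is among the k highest-probability classes.

    probs:  (N, C) array of class probabilities (hypothetical stacked layout)
    labels: (N,) array of integer class ids; negative ids are treated as unlabeled
    """
    valid = labels >= 0
    # Mirror the evaluator: no usable samples (or fewer than k classes) gives 0
    if not np.any(valid) or probs.shape[1] < k:
        return 0.0
    # argsort is ascending, so the last k columns index the k largest scores
    topk = np.argsort(probs[valid], axis=1)[:, -k:]
    hits = (topk == labels[valid][:, None]).any(axis=1)
    return float(hits.mean())

probs = np.array([
    [0.1, 0.6, 0.3],   # true class 2 is the 2nd most likely
    [0.7, 0.2, 0.1],   # true class 0 is the most likely
    [0.2, 0.3, 0.5],   # true class 0 is the least likely
])
labels = np.array([2, 0, 0])
print(topk_accuracy(probs, labels, k=1))  # 1 of 3 correct
print(topk_accuracy(probs, labels, k=2))  # 2 of 3 correct
```

As in the loop version, ties in the probabilities are broken by `argsort` order. A sample sitting exactly at the k-th position may therefore count as a hit or a miss, depending on the class index.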

## Visualization and Analysis Tools

This section provides plotting utilities for analyzing classification performance and model behavior. They produce a normalized confusion matrix, per-class performance charts, and a panel of confidence and class-distribution summaries. A single helper writes all of these to disk, together with the full metrics as JSON and a text summary.

**Visualization Components:**
- **Confusion Matrix Heatmap**: Confusion matrix normalized by true class (each row sums to 1). This is the main view of misclassification patterns and systematic confusions between classes.
- **Per-Class Performance Charts**: A bar chart of per-class F1-scores colored by support, plus a precision-vs-recall scatter plot with bubble size proportional to support.
- **Confidence Summary Plots**: Bar charts of aggregate confidence statistics (mean, median, min, max) and of mean confidence for correct vs. incorrect predictions.
- **Class Distribution Comparison**: Side-by-side bars of true vs. predicted class counts.
- **Top-k Accuracy Comparison**: Top-1, top-3 and top-5 accuracy side by side.

Together, these plots show:
- where a model is strong or weak;
- which classes it confuses;
- whether its confidence tracks its correctness.

This information supports decisions about deployment and further improvement.
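`plot_confusion_matrix` divides each row by the number of samples in that true class. Entry (i, j) then becomes the fraction of class-i samples predicted as class j, and the diagonal is per-class recall. The sketch below shows that normalization. It uses `np.divide` with `where=` so that classes with no samples produce an all-zero row, with no divide-by-zero warning. `normalize_by_true_class` is a hypothetical helper name:

```python
import numpy as np

def normalize_by_true_class(cm) -> np.ndarray:
    """Row-normalize a confusion matrix (rows = true class, columns = predicted class)."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Only divide where a row has samples; other rows keep the zeros from `out`
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

cm = np.array([
    [8, 2, 0],
    [1, 3, 0],
    [0, 0, 0],   # class never present in the ground truth
])
norm = normalize_by_true_class(cm)
print(norm)
print(np.diag(norm))  # per-class recall
```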

In [None]:
def plot_confusion_matrix(metrics, figsize=(12, 10), save_path=None):
    """Plot confusion matrix with proper labels"""
    
    cm = np.array(metrics['confusion_matrix'])
    class_names = CLASSES
    
    plt.figure(figsize=figsize)
    
    # Normalize each row by its true-class count; rows with no samples stay all-zero
    cm = cm.astype(float)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    
    # Create heatmap
    sns.heatmap(cm_normalized, 
                annot=True, 
                fmt='.2f', 
                cmap='Blues',
                xticklabels=[name[:20] + '...' if len(name) > 20 else name for name in class_names],
                yticklabels=[name[:20] + '...' if len(name) > 20 else name for name in class_names],
                cbar_kws={'label': 'Normalized Frequency'})
    
    plt.title('Confusion Matrix - Land Cover Classification\n(Normalized by True Class)', 
              fontsize=14, fontweight='bold')
    plt.xlabel('Predicted Class', fontsize=12)
    plt.ylabel('True Class', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")
    
    plt.show()

def plot_class_performance(metrics, figsize=(15, 8), save_path=None):
    """Plot per-class performance metrics"""
    
    per_class = metrics['per_class_metrics']
    
    # Extract data
    classes = list(per_class.keys())
    f1_scores = [per_class[cls]['f1_score'] for cls in classes]
    precisions = [per_class[cls]['precision'] for cls in classes]
    recalls = [per_class[cls]['recall'] for cls in classes]
    supports = [per_class[cls]['support'] for cls in classes]
    
    # Create subplot
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
    
    # Plot 1: F1-scores with support size
    x_pos = np.arange(len(classes))
    bars = ax1.bar(x_pos, f1_scores, alpha=0.7, color='skyblue', edgecolor='navy')
    
    # Color bars by support size (the outer max() avoids division by zero when all supports are 0)
    max_support = max(max(supports, default=0), 1)
    for bar, support in zip(bars, supports):
        bar.set_color(plt.cm.viridis(support / max_support))
    
    ax1.set_xlabel('Land Cover Classes')
    ax1.set_ylabel('F1-Score')
    ax1.set_title('Per-Class F1-Scores (Color intensity = Support size)', fontweight='bold')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels([cls[:15] + '...' if len(cls) > 15 else cls for cls in classes], 
                        rotation=45, ha='right')
    ax1.grid(axis='y', alpha=0.3)
    ax1.set_ylim(0, 1)
    
    # Plot 2: Precision vs Recall scatter
    scatter = ax2.scatter(recalls, precisions, s=[s*2 for s in supports], 
                         alpha=0.6, c=supports, cmap='viridis')
    
    # Add class labels for points with good performance or high support
    for i, cls in enumerate(classes):
        if f1_scores[i] > 0.5 or supports[i] > np.percentile(supports, 75):
            ax2.annotate(cls[:10], (recalls[i], precisions[i]), 
                        xytext=(5, 5), textcoords='offset points', 
                        fontsize=8, alpha=0.7)
    
    ax2.set_xlabel('Recall')
    ax2.set_ylabel('Precision')
    ax2.set_title('Precision vs Recall (Bubble size = Support)', fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 1)
    ax2.set_ylim(0, 1)
    
    # Add diagonal line
    ax2.plot([0, 1], [0, 1], 'k--', alpha=0.5, linewidth=1)
    
    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax2)
    cbar.set_label('Support (Number of samples)')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Class performance plot saved to: {save_path}")
    
    plt.show()

def plot_confidence_analysis(metrics, figsize=(12, 8), save_path=None):
    """Plot confidence score analysis"""
    
    # The metrics dict stores only aggregate confidence statistics, not per-sample
    # scores, so these panels are bar charts of summaries rather than histograms
    
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
    
    # Confidence statistics from metrics
    conf_stats = metrics['confidence_analysis']
    conf_by_correct = metrics['confidence_by_correctness']
    
    # Plot 1: Basic confidence statistics
    stats_names = ['Mean', 'Median', 'Min', 'Max']
    stats_values = [
        conf_stats['mean_confidence'],
        conf_stats['median_confidence'], 
        conf_stats['min_confidence'],
        conf_stats['max_confidence']
    ]
    
    ax1.bar(stats_names, stats_values, color=['lightblue', 'lightgreen', 'lightcoral', 'lightyellow'])
    ax1.set_title('Confidence Score Statistics', fontweight='bold')
    ax1.set_ylabel('Confidence Score')
    ax1.set_ylim(0, 1)
    ax1.grid(axis='y', alpha=0.3)
    
    # Plot 2: Confidence by correctness
    correct_conf = conf_by_correct['correct_predictions']['mean_confidence']
    incorrect_conf = conf_by_correct['incorrect_predictions']['mean_confidence']
    correct_count = conf_by_correct['correct_predictions']['count']
    incorrect_count = conf_by_correct['incorrect_predictions']['count']
    
    bars = ax2.bar(['Correct\nPredictions', 'Incorrect\nPredictions'], 
                   [correct_conf, incorrect_conf], 
                   color=['green', 'red'], alpha=0.7)
    
    # Add count labels on bars
    ax2.text(0, correct_conf + 0.02, f'n={correct_count}', ha='center', fontweight='bold')
    ax2.text(1, incorrect_conf + 0.02, f'n={incorrect_count}', ha='center', fontweight='bold')
    
    ax2.set_title('Mean Confidence by Prediction Correctness', fontweight='bold')
    ax2.set_ylabel('Mean Confidence Score')
    ax2.set_ylim(0, 1)
    ax2.grid(axis='y', alpha=0.3)
    
    # Plot 3: Class distribution comparison
    true_dist = metrics['class_distribution']['true_distribution']
    pred_dist = metrics['class_distribution']['pred_distribution']
    
    x_pos = np.arange(len(true_dist))
    width = 0.35
    
    ax3.bar(x_pos - width/2, true_dist, width, label='True Distribution', alpha=0.7)
    ax3.bar(x_pos + width/2, pred_dist, width, label='Predicted Distribution', alpha=0.7)
    
    ax3.set_title('Class Distribution: True vs Predicted', fontweight='bold')
    ax3.set_xlabel('Class Index')
    ax3.set_ylabel('Count')
    ax3.legend()
    ax3.grid(axis='y', alpha=0.3)
    
    # Plot 4: Overall accuracy summary
    overall = metrics['overall_metrics']
    accuracy_metrics = {
        'Top-1': overall['accuracy'],
        'Top-3': overall['top3_accuracy'],
        'Top-5': overall['top5_accuracy']
    }
    
    bars = ax4.bar(accuracy_metrics.keys(), accuracy_metrics.values(), 
                   color=['darkblue', 'blue', 'lightblue'])
    
    # Add percentage labels on bars
    for bar, acc in zip(bars, accuracy_metrics.values()):
        height = bar.get_height()
        ax4.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{acc*100:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    ax4.set_title('Top-K Accuracy Comparison', fontweight='bold')
    ax4.set_ylabel('Accuracy')
    ax4.set_ylim(0, 1)
    ax4.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confidence analysis plot saved to: {save_path}")
    
    plt.show()

def generate_evaluation_report(metrics, output_dir='evaluation_results'):
    """Generate a comprehensive evaluation report with visualizations"""
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"Generating comprehensive evaluation report in: {output_dir}")
    
    # Generate visualizations
    plot_confusion_matrix(metrics, save_path=os.path.join(output_dir, 'confusion_matrix.png'))
    plot_class_performance(metrics, save_path=os.path.join(output_dir, 'class_performance.png'))
    plot_confidence_analysis(metrics, save_path=os.path.join(output_dir, 'confidence_analysis.png'))
    
    # Save detailed metrics as JSON
    with open(os.path.join(output_dir, 'detailed_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=2)
    
    # Generate summary report
    with open(os.path.join(output_dir, 'evaluation_summary.txt'), 'w') as f:
        f.write("LAND COVER CLASSIFICATION EVALUATION REPORT\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Generated on: {datetime.now().isoformat()}\n\n")
        
        # Overall metrics
        overall = metrics['overall_metrics']
        f.write("OVERALL PERFORMANCE:\n")
        f.write(f"  Accuracy: {overall['accuracy']:.4f} ({overall['accuracy']*100:.2f}%)\n")
        f.write(f"  Top-3 Accuracy: {overall['top3_accuracy']:.4f} ({overall['top3_accuracy']*100:.2f}%)\n")
        f.write(f"  Top-5 Accuracy: {overall['top5_accuracy']:.4f} ({overall['top5_accuracy']*100:.2f}%)\n")
        f.write(f"  Macro F1-Score: {overall['macro_f1']:.4f}\n")
        f.write(f"  Weighted F1-Score: {overall['weighted_f1']:.4f}\n")
        f.write(f"  Number of samples: {overall['num_samples']}\n\n")
        
        # Per-class performance summary (classes with no samples are excluded,
        # matching print_evaluation_summary)
        per_class = metrics['per_class_metrics']
        f1_scores = [(name, data['f1_score'], data['support'])
                     for name, data in per_class.items() if data['support'] > 0]
        f1_scores.sort(key=lambda x: x[1], reverse=True)
        
        f.write("TOP 10 PERFORMING CLASSES:\n")
        for i, (class_name, f1, support) in enumerate(f1_scores[:10]):
            f.write(f"  {i+1:2d}. {class_name:<45} F1: {f1:.4f} (n={support})\n")
        
        f.write("\nBOTTOM 10 PERFORMING CLASSES:\n")
        bottom = f1_scores[-10:]
        # Ranks stay correct even when fewer than 10 classes are present
        start_rank = len(f1_scores) - len(bottom) + 1
        for rank, (class_name, f1, support) in enumerate(bottom, start=start_rank):
            f.write(f"  {rank:2d}. {class_name:<45} F1: {f1:.4f} (n={support})\n")
    
    print(f"Evaluation report generated in: {output_dir}")
    print("  - confusion_matrix.png")
    print("  - class_performance.png")
    print("  - confidence_analysis.png")
    print("  - detailed_metrics.json")
    print("  - evaluation_summary.txt")

print("Visualization and analysis tools ready")
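The summary and the report both show macro and weighted F1. On imbalanced land cover data the two can diverge sharply:
- Macro F1 averages all classes equally.
- Weighted F1 weights each class by its support, so rare classes that a restoration method damages barely move it.

The illustration below uses scikit-learn's `f1_score` on made-up labels; the numbers are not results from this evaluation.

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy labels: class 0 is common and always predicted correctly,
# class 1 is rare and always missed (never predicted)
y_true = np.array([0] * 9 + [1])
y_pred = np.array([0] * 10)

macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)
print(f"macro F1:    {macro:.3f}")
print(f"weighted F1: {weighted:.3f}")
```

Here class 0 has F1 = 18/19 and class 1 has F1 = 0. Macro F1 is therefore about 0.474, while weighted F1 is about 0.853. In this situation, macro F1 is the more sensitive signal of lost rare-class information.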

## Complete Evaluation Pipeline Demo

This section walks through an end-to-end run of the satellite land cover classification evaluation. It starts with model loading and batch inference, then covers metric computation, visualization, and reporting.

**Pipeline Steps:**
1. **Model Loading**: Load pre-trained weights into the chosen architecture with `load_satellite_model`
2. **Batch Inference**: Run the model over the images in the input folder with `run_batch_inference` and save the predictions to CSV
3. **Ground Truth Alignment**: Load labels with `evaluator.load_ground_truth` and match them to predictions by image name
4. **Metric Computation**: Compute accuracy, top-k accuracy, precision, recall, F1, and confidence statistics with `evaluator.compute_metrics`
5. **Visualization Generation**: Plot the confusion matrix, per-class performance, and confidence summaries
6. **Report Generation**: Save the plots, the full metrics as JSON, and a text summary with `generate_evaluation_report`

To adapt the pipeline to your own checkpoints and data, edit the paths and parameters in the configuration cell below.
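Step 3 pairs each prediction with its label by image name. The sketch below infers the data layout from `compute_topk_accuracy`:
- probabilities are stored in `predictions['predictions'][name]['probabilities']`;
- failed images carry an optional `'error'` key;
- labels are stored in `ground_truth[name]['true_class_id']`.

It shows roughly what that alignment involves. `align_predictions` is a hypothetical helper, not part of the evaluator, and the exact field handling inside `compute_metrics` may differ.

```python
import numpy as np

def align_predictions(predictions: dict, ground_truth: dict):
    """Return (names, y_true, y_prob) for images with both a valid prediction and a label."""
    names, y_true, y_prob = [], [], []
    for name, pred in predictions['predictions'].items():
        if 'error' in pred or name not in ground_truth:
            continue  # failed inference or unlabeled image
        label = ground_truth[name]['true_class_id']
        if label < 0:
            continue  # unknown class id
        names.append(name)
        y_true.append(label)
        y_prob.append(pred['probabilities'])
    return names, np.array(y_true, dtype=int), np.array(y_prob, dtype=float)

# Toy inputs following the assumed layout
predictions = {'predictions': {
    'a.tif': {'probabilities': [0.7, 0.2, 0.1]},
    'b.tif': {'error': 'could not read file'},
    'c.tif': {'probabilities': [0.1, 0.1, 0.8]},
}}
ground_truth = {'a.tif': {'true_class_id': 0}, 'c.tif': {'true_class_id': 1}}

names, y_true, y_prob = align_predictions(predictions, ground_truth)
y_pred = y_prob.argmax(axis=1)  # top-1 class per image
print(names, y_true, y_pred)
```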

In [None]:
# EXAMPLE USAGE: Complete Evaluation Pipeline
# Modify these paths to match your setup

# Configuration
MODEL_CHECKPOINT = "model.safetensors"  # Path to your model weights
INPUT_FOLDER = "test_images"            # Folder containing test images  
GROUND_TRUTH_FILE = "Satellite_Downstream/ground_truth.txt"  # Ground truth labels (optional)
OUTPUT_DIR = "evaluation_results"       # Where to save results
# Check if files exist
files_exist = {
    "model": os.path.exists(MODEL_CHECKPOINT),
    "input_folder": os.path.exists(INPUT_FOLDER),
    "ground_truth": os.path.exists(GROUND_TRUTH_FILE) if GROUND_TRUTH_FILE else False
}

print("FILE STATUS CHECK:")
print(f"  Model checkpoint: {'FOUND' if files_exist['model'] else 'MISSING'} {MODEL_CHECKPOINT}")
print(f"  Input folder: {'FOUND' if files_exist['input_folder'] else 'MISSING'} {INPUT_FOLDER}")
print(f"  Ground truth: {'FOUND' if files_exist['ground_truth'] else 'MISSING'} {GROUND_TRUTH_FILE}")

if not files_exist['model']:
    print("\nModel checkpoint not found. Please:")
    print("   1. Download a model from Hugging Face or research repository")
    print("   2. Update MODEL_CHECKPOINT variable with correct path")
    print("   3. Supported formats: .safetensors, .pth, .pt")

if not files_exist['input_folder']:
    print("\nInput folder not found. Please:")
    print("   1. Create a folder with satellite images")
    print("   2. Update INPUT_FOLDER variable with correct path")
    print("   3. Supported formats: .jpg, .jpeg, .png, .tif, .tiff")

print("\n" + "="*50)
if files_exist['model'] and files_exist['input_folder']:
    print("Ready to run the evaluation pipeline. Run the cells below.")
else:
    print("Fix the missing files above before running the cells below.")
print("="*50)
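The check above only confirms that the input folder exists. It can also help to confirm that the folder actually contains files in one of the image formats printed by that cell. The sketch below uses `pathlib` to do so. `list_supported_images` is a hypothetical helper, and the extension set is an assumption copied from the printed list, not read from `run_batch_inference`.

```python
from pathlib import Path

# Assumed to match the formats listed in the cell above
SUPPORTED_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.tif', '.tiff'}

def list_supported_images(folder) -> list:
    """Return sorted paths of regular files in `folder` with a supported extension (case-insensitive)."""
    folder = Path(folder)
    if not folder.is_dir():
        return []
    return sorted(p for p in folder.iterdir()
                  if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGE_EXTS)

images = list_supported_images("test_images")
print(f"Found {len(images)} supported images")
```

The search is not recursive. Images in subfolders are ignored, which may or may not match how the inference function scans the folder.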

In [None]:
print("Loading model...")
model = load_satellite_model(
    checkpoint_path=MODEL_CHECKPOINT,
    architecture='resnet50',  # or 'resnet101', 'wide_resnet50_2'
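    num_classes=len(CLASSES),  # 19 CORINE-based land cover classes
    input_channels=10          # must match the number of bands the checkpoint was trained on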
)
print("Model loaded successfully!")

In [None]:
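# predictions.csv is written here, before generate_evaluation_report creates the folder
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Running batch inference...")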
predictions = run_batch_inference(
    model=model,
    input_folder=INPUT_FOLDER,
    output_file=os.path.join(OUTPUT_DIR, "predictions.csv"),
    max_images=100  # Remove or set to None to process all images
)
print("Batch inference completed!")

In [None]:
if GROUND_TRUTH_FILE and os.path.exists(GROUND_TRUTH_FILE):
    print("Loading ground truth and computing metrics...")
    
    # Load ground truth
    ground_truth = evaluator.load_ground_truth(GROUND_TRUTH_FILE)
    print(f"Loaded ground truth for {len(ground_truth)} images")
    
    # Compute comprehensive metrics
    metrics = evaluator.compute_metrics(predictions, ground_truth)
    print("Metrics computed!")
    
    # Print evaluation summary
    evaluator.print_evaluation_summary(metrics)
    
    # Generate comprehensive report with visualizations
    generate_evaluation_report(metrics, OUTPUT_DIR)
    
else:
    print("No ground truth file found. Skipping accuracy evaluation.")
    print("Predictions saved to CSV for manual review.")

In [None]:
print("\n" + "="*60)
print("SATELLITE LAND COVER CLASSIFICATION EVALUATION COMPLETE")
print("="*60)

if 'metrics' in locals():
    print("\nEVALUATION RESULTS SUMMARY:")
    print(f"  - Overall Accuracy: {metrics['overall_metrics']['accuracy']:.4f}")
    print(f"  - Macro F1-Score: {metrics['overall_metrics']['macro_f1']:.4f}")
    print(f"  - Images Evaluated: {metrics['overall_metrics']['num_samples']}")
    print(f"\nResults saved in: {OUTPUT_DIR}/")
    print("  - Confusion matrix visualization")
    print("  - Per-class performance charts") 
    print("  - Confidence analysis plots")
    print("  - Detailed metrics (JSON)")
    print("  - Evaluation summary report")

print(f"\nPredictions saved to: {OUTPUT_DIR}/predictions.csv")
print("\nThis completes the downstream evaluation for satellite land cover classification.")
print("Use the generated metrics and visualizations to assess restoration model quality.")
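Using these numbers for restoration assessment takes two runs:
1. Run the pipeline once on reference (undegraded) images.
2. Run it again on the restored images.
3. Compare the `overall_metrics` of the two runs.

This two-run workflow and the `compare_runs` helper below are assumptions for illustration, not part of the notebook. The metric values are made up.

```python
def compare_runs(reference: dict, restored: dict,
                 keys=('accuracy', 'top3_accuracy', 'top5_accuracy', 'macro_f1', 'weighted_f1')) -> dict:
    """Return restored-minus-reference differences for selected overall metrics."""
    ref, res = reference['overall_metrics'], restored['overall_metrics']
    return {k: res[k] - ref[k] for k in keys}

# Illustrative values only, not results from this evaluation
reference_metrics = {'overall_metrics': {'accuracy': 0.80, 'top3_accuracy': 0.95, 'top5_accuracy': 0.98,
                                         'macro_f1': 0.70, 'weighted_f1': 0.79}}
restored_metrics = {'overall_metrics': {'accuracy': 0.75, 'top3_accuracy': 0.93, 'top5_accuracy': 0.97,
                                        'macro_f1': 0.62, 'weighted_f1': 0.74}}

deltas = compare_runs(reference_metrics, restored_metrics)
for name, delta in deltas.items():
    print(f"{name:>14}: {delta:+.3f}")  # negative = restored images classify worse
```

The same pattern extends to `per_class_metrics`. Comparing per-class F1 between the two runs shows which land cover classes lose the most information during restoration.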