# Model Benchmark Evaluation

This notebook evaluates all trained models on the same test dataset and compares their performance metrics.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
import warnings
warnings.filterwarnings('ignore')

# Plot appearance: dark-grid matplotlib style with the 'husl' seaborn palette
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Prefer the GPU when CUDA is available; otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 1. Load Test Dataset

In [None]:
# Import data preparation utilities
import sys
sys.path.append('./')

from TCN.dataset import prepare_data

# Build the train/validation/test loaders from one shared configuration
train_loader, val_loader, test_loader, data_info = prepare_data(
    data_path="./dataset",
    patch_size=64, stride=32,
    test_size=0.2, val_size=0.1,
    random_state=42,  # fixed seed for reproducibility
    batch_size=16,
    num_workers=0,
)

# Display name of each class index (index = position in the list below)
class_names = dict(enumerate([
    'Corn', 'Wheat', 'Sunflower', 'Pumpkin',
    'Artificial_Surface', 'Water', 'Road', 'Other',
]))

# The class count comes from the dataset metadata, not from the name table
num_classes = data_info['num_classes']

print(f"Number of test samples: {len(test_loader.dataset)}")
print(f"Number of classes: {num_classes}")
print(f"Test batch size: {test_loader.batch_size}")

## 2. Define Evaluation Functions

In [None]:
def calculate_iou(pred, target, num_classes):
    """Compute the intersection-over-union (IoU) of every class.

    Args:
        pred: Predicted class labels; any array-like (list, tuple, ndarray).
        target: Ground-truth class labels with the same shape as ``pred``.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are scored.

    Returns:
        A float64 ndarray of length ``num_classes``. A class that occurs in
        neither ``pred`` nor ``target`` (empty union) is given an IoU of 0.
    """
    # Coerce so that plain sequences compare element-wise instead of as a whole
    pred = np.asarray(pred)
    target = np.asarray(target)

    # Pre-filled with zeros: this is the value kept for classes with an empty union
    ious = np.zeros(num_classes, dtype=np.float64)
    for cls in range(num_classes):
        pred_cls = pred == cls
        target_cls = target == cls
        union = np.logical_or(pred_cls, target_cls).sum()
        if union > 0:
            ious[cls] = np.logical_and(pred_cls, target_cls).sum() / union
    return ious

def _predict_labels(outputs):
    """Reduce raw model outputs to predicted class indices.

    4-D outputs are reduced over their last dimension and 3-D outputs over
    dimension 1; any other rank raises ``ValueError``.
    """
    if outputs.dim() == 4:  # (batch, height, width, classes)
        _, labels = outputs.max(-1)
    elif outputs.dim() == 3:  # (batch, classes, height*width)
        # No reshape to (height, width): every metric below works on flattened
        # labels, so only the class dimension has to be reduced.
        _, labels = outputs.max(1)
    else:
        raise ValueError(f"Unexpected output shape: {outputs.shape}")
    return labels


def evaluate_model(model, test_loader, device, model_name="Model", n_classes=None):
    """Evaluate a model on the test dataset and collect its metrics.

    Args:
        model: Module to evaluate; it is switched to eval mode and called on
            each input batch after the batch is moved to ``device``.
        test_loader: Iterable yielding ``(data, targets)`` tensor pairs.
        device: Device the input batches are moved to before the forward pass.
        model_name: Label used in progress output and stored in the result.
        n_classes: Number of classes to score. When omitted, the notebook-level
            ``num_classes`` is used.

    Returns:
        A dict with the keys ``model_name``, ``overall_accuracy``,
        ``mean_accuracy``, ``mean_iou``, ``f1_micro``, ``f1_macro``,
        ``f1_weighted``, ``per_class_accuracy``, ``per_class_iou``,
        ``confusion_matrix``, ``predictions`` and ``targets`` (the last two are
        the flattened label arrays).

    Raises:
        ValueError: If the model output is neither 3-D nor 4-D, or if
            ``test_loader`` yields no batches.

    Note:
        Classes that never occur in the targets contribute an accuracy of 0 to
        ``mean_accuracy``.
    """
    class_count = num_classes if n_classes is None else n_classes

    model.eval()
    batch_preds = []
    batch_targets = []

    print(f"\nEvaluating {model_name}...")

    with torch.no_grad():
        for data, targets in tqdm(test_loader, desc=f'Testing {model_name}'):
            outputs = model(data.to(device))
            batch_preds.append(_predict_labels(outputs).cpu().numpy())
            # .cpu() keeps this working even if the loader yields device tensors
            batch_targets.append(targets.cpu().numpy())

    if not batch_preds:
        raise ValueError(f"test_loader yielded no batches for {model_name}")

    # Merge all batches and flatten: the metrics are computed per element
    all_preds_flat = np.concatenate(batch_preds).flatten()
    all_targets_flat = np.concatenate(batch_targets).flatten()

    overall_acc = accuracy_score(all_targets_flat, all_preds_flat)

    # Per-class accuracy: fraction of each class's target elements predicted correctly
    per_class_acc = []
    for cls in range(class_count):
        mask = all_targets_flat == cls
        if mask.any():
            per_class_acc.append(
                accuracy_score(all_targets_flat[mask], all_preds_flat[mask])
            )
        else:
            per_class_acc.append(0)

    per_class_iou = calculate_iou(all_preds_flat, all_targets_flat, class_count)

    f1_micro = f1_score(all_targets_flat, all_preds_flat, average='micro')
    f1_macro = f1_score(all_targets_flat, all_preds_flat, average='macro')
    f1_weighted = f1_score(all_targets_flat, all_preds_flat, average='weighted')

    # Fixed label list so the matrix is always class_count x class_count
    cm = confusion_matrix(
        all_targets_flat, all_preds_flat, labels=list(range(class_count))
    )

    return {
        'model_name': model_name,
        'overall_accuracy': overall_acc,
        'mean_accuracy': np.mean(per_class_acc),
        'mean_iou': np.mean(per_class_iou),
        'f1_micro': f1_micro,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'per_class_accuracy': per_class_acc,
        'per_class_iou': per_class_iou,
        'confusion_matrix': cm,
        'predictions': all_preds_flat,
        'targets': all_targets_flat
    }

## 3. Load and Evaluate Models

In [None]:
# Dictionary to store all results, keyed by model name; each model cell below
# adds its entry only when evaluation succeeds
benchmark_results = {}

# Get input shape from data: take one test batch and unpack its five
# dimensions. height, width, temporal and spectral are passed as input_shape
# to the model factories in the following cells.
sample_batch, _ = next(iter(test_loader))
batch_size, height, width, temporal, spectral = sample_batch.shape
print(f"Input shape: {sample_batch.shape}")

### 3.1 Evaluate TCN Model

In [None]:
# Build the TCN, restore its checkpoint and benchmark it on the test set.
# Any failure is reported and the notebook carries on with the other models.
try:
    from TCN.model import create_tcn_model
    from TCN.compat import load_checkpoint

    # Network sized for the dimensions read from the test batch
    tcn_model = create_tcn_model(
        input_shape=(height, width, temporal, spectral),
        num_classes=num_classes,
        hidden_channels=64,
        kernel_size=3,
        dropout=0.2,
    ).to(device)

    checkpoint_path = './TCN/checkpoints_test/best_model.pth'
    if not Path(checkpoint_path).exists():
        print(f"TCN checkpoint not found at {checkpoint_path}")
    else:
        # Restore the trained weights before evaluating
        checkpoint = load_checkpoint(checkpoint_path, map_location=device)
        tcn_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"TCN model loaded from {checkpoint_path}")

        # Score the model and register the result under its name
        tcn_results = benchmark_results['TCN'] = evaluate_model(
            tcn_model, test_loader, device, "TCN"
        )
except Exception as exc:
    print(f"Error loading TCN model: {exc}")

### 3.2 Evaluate Transformer Model

In [None]:
# Build the Transformer, restore its checkpoint and benchmark it on the test
# set. Any failure is reported and the notebook carries on with the other models.
try:
    from Transformer.model import create_transformer_model

    # Network sized for the dimensions read from the test batch
    transformer_model = create_transformer_model(
        input_shape=(height, width, temporal, spectral),
        num_classes=num_classes,
        embed_dim=128,
        num_heads=8,
        num_layers=6,
        dropout=0.1,
    ).to(device)

    checkpoint_path = './Transformer/checkpoints/best_model.pth'
    if not Path(checkpoint_path).exists():
        print(f"Transformer checkpoint not found at {checkpoint_path}")
    else:
        # Restore the trained weights before evaluating
        checkpoint = torch.load(checkpoint_path, map_location=device)
        transformer_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Transformer model loaded from {checkpoint_path}")

        # Score the model and register the result under its name
        transformer_results = benchmark_results['Transformer'] = evaluate_model(
            transformer_model, test_loader, device, "Transformer"
        )
except Exception as exc:
    print(f"Error loading Transformer model: {exc}")

### 3.3 Evaluate Swin Transformer Model

In [None]:
# Build the Swin Transformer, restore its checkpoint and benchmark it on the
# test set. Any failure is reported and the notebook carries on.
try:
    from Swin_Transformer.model import create_swin_model

    # Network sized for the dimensions read from the test batch
    swin_model = create_swin_model(
        input_shape=(height, width, temporal, spectral),
        num_classes=num_classes,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
    ).to(device)

    # NOTE(review): the package is imported as `Swin_Transformer` (underscore)
    # but the checkpoint directory is `Swin-Transformer` (hyphen) -- confirm
    # that this path is intended.
    checkpoint_path = './Swin-Transformer/checkpoints/best_model.pth'
    if not Path(checkpoint_path).exists():
        print(f"Swin Transformer checkpoint not found at {checkpoint_path}")
    else:
        # Restore the trained weights before evaluating
        checkpoint = torch.load(checkpoint_path, map_location=device)
        swin_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Swin Transformer model loaded from {checkpoint_path}")

        # Score the model and register the result under its name
        swin_results = benchmark_results['Swin-Transformer'] = evaluate_model(
            swin_model, test_loader, device, "Swin-Transformer"
        )
except Exception as exc:
    print(f"Error loading Swin Transformer model: {exc}")

## 4. Results Comparison

In [None]:
# Summary table: one row per evaluated model, headline metrics to 4 decimals
if benchmark_results:
    # Column title -> key in each model's result dict, in display order
    column_to_metric = {
        'Overall Accuracy': 'overall_accuracy',
        'Mean Accuracy': 'mean_accuracy',
        'Mean IoU': 'mean_iou',
        'F1-Score (Micro)': 'f1_micro',
        'F1-Score (Macro)': 'f1_macro',
        'F1-Score (Weighted)': 'f1_weighted',
    }

    comparison_data = []
    for model_name, results in benchmark_results.items():
        row = {'Model': model_name}
        for column_title, metric_key in column_to_metric.items():
            row[column_title] = f"{results[metric_key]:.4f}"
        comparison_data.append(row)

    comparison_df = pd.DataFrame(comparison_data)

    separator = "=" * 80
    print("\n" + separator)
    print("MODEL BENCHMARK RESULTS")
    print(separator)
    print(comparison_df.to_string(index=False))
    print(separator)

## 5. Visualization

### 5.1 Overall Metrics Comparison

In [None]:
# One bar chart per headline metric, comparing every evaluated model
if benchmark_results:
    # Prepare data for plotting
    models = list(benchmark_results.keys())
    metrics = ['overall_accuracy', 'mean_accuracy', 'mean_iou', 'f1_weighted']
    metric_names = ['Overall Accuracy', 'Mean Accuracy', 'Mean IoU', 'F1-Score (Weighted)']

    # 2x2 grid, flattened so that each metric takes the next axes
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for ax, metric, metric_name in zip(axes, metrics, metric_names):
        values = [benchmark_results[model][metric] for model in models]

        bars = ax.bar(models, values, color=['#FF6B6B', '#4ECDC4', '#45B7D1'][:len(models)])

        # Value label on top of each bar. The local is named `bar_height` so it
        # does not overwrite the notebook-level `height` (the input height that
        # the model cells pass as input_shape).
        for bar, value in zip(bars, values):
            bar_height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., bar_height,
                    f'{value:.4f}', ha='center', va='bottom', fontsize=10)

        ax.set_title(metric_name, fontsize=12, fontweight='bold')
        ax.set_ylabel('Score', fontsize=10)
        ax.set_ylim(0, 1.05)
        ax.grid(axis='y', alpha=0.3)

    plt.suptitle('Model Performance Comparison', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

### 5.2 Per-Class Accuracy Comparison

In [None]:
# Grouped bar chart: per-class accuracy of every evaluated model
if benchmark_results:
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))

    x = np.arange(num_classes)
    n_models = len(benchmark_results)
    # Named `bar_width` so it does not overwrite the notebook-level `width`
    # (the input width that the model cells pass as input_shape). It is capped
    # so a group of bars never spills into the neighbouring class; with up to
    # three models it stays 0.25.
    bar_width = min(0.25, 0.8 / n_models)
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

    for i, (model_name, results) in enumerate(benchmark_results.items()):
        # Centre the group of bars on the class tick
        offset = (i - n_models/2 + 0.5) * bar_width
        bars = ax.bar(x + offset, results['per_class_accuracy'], bar_width,
                      label=model_name, color=colors[i % len(colors)])

        # Add value labels
        for bar, value in zip(bars, results['per_class_accuracy']):
            if value > 0.05:  # Only show label if bar is visible
                ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                        f'{value:.2f}', ha='center', va='bottom', fontsize=8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title('Per-Class Accuracy Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([class_names[i] for i in range(num_classes)], rotation=45, ha='right')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim(0, 1.1)

    plt.tight_layout()
    plt.show()

### 5.3 Per-Class IoU Comparison

In [None]:
# Grouped bar chart: per-class IoU of every evaluated model
if benchmark_results:
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))

    x = np.arange(num_classes)
    n_models = len(benchmark_results)
    # Named `bar_width` so it does not overwrite the notebook-level `width`
    # (the input width that the model cells pass as input_shape). It is capped
    # so a group of bars never spills into the neighbouring class; with up to
    # three models it stays 0.25.
    bar_width = min(0.25, 0.8 / n_models)
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

    for i, (model_name, results) in enumerate(benchmark_results.items()):
        # Centre the group of bars on the class tick
        offset = (i - n_models/2 + 0.5) * bar_width
        bars = ax.bar(x + offset, results['per_class_iou'], bar_width,
                      label=model_name, color=colors[i % len(colors)], alpha=0.8)

        # Add value labels
        for bar, value in zip(bars, results['per_class_iou']):
            if value > 0.05:  # Only show label if bar is visible
                ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                        f'{value:.2f}', ha='center', va='bottom', fontsize=8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('IoU', fontsize=12)
    ax.set_title('Per-Class IoU Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([class_names[i] for i in range(num_classes)], rotation=45, ha='right')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim(0, 1.1)

    plt.tight_layout()
    plt.show()

### 5.4 Confusion Matrices

In [None]:
# Row-normalised confusion matrix of every evaluated model, side by side
if benchmark_results:
    n_models = len(benchmark_results)
    # squeeze=False always returns a 2-D axes array, so a single model needs
    # no special case
    fig, axes = plt.subplots(1, n_models, figsize=(8*n_models, 6), squeeze=False)
    axes = axes[0]

    # Tick labels: first four characters of each class name
    short_names = [class_names[i][:4] for i in range(num_classes)]

    for ax, (model_name, results) in zip(axes, benchmark_results.items()):
        cm = results['confusion_matrix']

        # Normalise each row by its total; rows without any sample stay at 0
        # rather than being divided by zero
        row_totals = cm.sum(axis=1, keepdims=True).astype(float)
        cm_normalized = np.divide(cm, row_totals,
                                  out=np.zeros(cm.shape, dtype=float),
                                  where=row_totals > 0)

        # Plot
        im = ax.imshow(cm_normalized, cmap='Blues', aspect='auto')
        ax.set_title(f'{model_name} Confusion Matrix', fontsize=12, fontweight='bold')
        ax.set_xlabel('Predicted Label', fontsize=10)
        ax.set_ylabel('True Label', fontsize=10)

        # Set ticks
        ax.set_xticks(range(num_classes))
        ax.set_yticks(range(num_classes))
        ax.set_xticklabels(short_names, rotation=45)
        ax.set_yticklabels(short_names)

        # Add colorbar
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Annotate every cell; white text on the darker cells keeps it legible
        for row in range(num_classes):
            for col in range(num_classes):
                cell = cm_normalized[row, col]
                ax.text(col, row, f'{cell:.2f}',
                        ha="center", va="center",
                        color="white" if cell > 0.5 else "black",
                        fontsize=8)

    plt.suptitle('Normalized Confusion Matrices', fontsize=14, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.show()

### 5.5 Radar Chart Comparison

In [None]:
# Radar chart: one polygon per model across the six headline metrics
if benchmark_results:
    from math import pi

    # Spoke labels and the matching result keys, in the same order
    categories = ['Overall\nAccuracy', 'Mean\nAccuracy', 'Mean\nIoU',
                  'F1-Score\n(Micro)', 'F1-Score\n(Macro)', 'F1-Score\n(Weighted)']
    metrics_keys = ['overall_accuracy', 'mean_accuracy', 'mean_iou',
                    'f1_micro', 'f1_macro', 'f1_weighted']

    # One evenly spaced angle per spoke; the first is repeated to close the loop
    N = len(categories)
    angles = [n / float(N) * 2 * pi for n in range(N)]
    angles.append(angles[0])

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

    for idx, (model_name, results) in enumerate(benchmark_results.items()):
        model_color = colors[idx % len(colors)]

        # Metric values in spoke order, closed the same way as the angles
        values = [results[metric] for metric in metrics_keys]
        values.append(values[0])

        ax.plot(angles, values, 'o-', linewidth=2, label=model_name, color=model_color)
        ax.fill(angles, values, alpha=0.25, color=model_color)

    # Spoke labels (the duplicated closing angle is skipped) and radial scale
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, size=10)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], size=8)
    ax.grid(True)

    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=10)
    plt.title('Multi-Metric Model Comparison', size=14, fontweight='bold', y=1.08)
    plt.tight_layout()
    plt.show()

## 6. Statistical Summary

In [None]:
# Per-class tables plus the best model for each headline metric
if benchmark_results:
    # One row per model with formatted per-class accuracy and IoU. All rows are
    # collected first and the table is built once, instead of concatenating
    # onto an initially empty frame inside the loop.
    per_model_rows = []
    for model_name, results in benchmark_results.items():
        model_data = {}
        for i in range(num_classes):
            model_data[f"{class_names[i]}_Acc"] = f"{results['per_class_accuracy'][i]:.3f}"
            model_data[f"{class_names[i]}_IoU"] = f"{results['per_class_iou'][i]:.3f}"
        per_model_rows.append(model_data)

    class_comparison = pd.DataFrame(per_model_rows, index=list(benchmark_results))

    print("\n" + "="*80)
    print("PER-CLASS DETAILED METRICS")
    print("="*80)

    # Print accuracy columns
    acc_cols = [col for col in class_comparison.columns if col.endswith('_Acc')]
    print("\nAccuracy per Class:")
    print(class_comparison[acc_cols].to_string())

    # Print IoU columns
    iou_cols = [col for col in class_comparison.columns if col.endswith('_IoU')]
    print("\nIoU per Class:")
    print(class_comparison[iou_cols].to_string())

    # Best model determination
    print("\n" + "="*80)
    print("BEST MODEL ANALYSIS")
    print("="*80)

    metrics_to_check = ['overall_accuracy', 'mean_iou', 'f1_weighted']

    # Winner per metric; on a tie the model evaluated first is reported
    for metric in metrics_to_check:
        best_name = max(benchmark_results, key=lambda name: benchmark_results[name][metric])
        best_value = benchmark_results[best_name][metric]
        print(f"\nBest {metric.replace('_', ' ').title()}: {best_name} ({best_value:.4f})")

    # Overall best model: highest unweighted mean of the metrics above
    # (a mean of scores, not a mean of ranks)
    model_scores = {
        model_name: np.mean([results[metric] for metric in metrics_to_check])
        for model_name, results in benchmark_results.items()
    }

    best_overall = max(model_scores.items(), key=lambda x: x[1])
    print(f"\n🏆 Overall Best Model: {best_overall[0]} (Average Score: {best_overall[1]:.4f})")

## 7. Save Results

In [None]:
# Persist the benchmark: full metrics as JSON, the summary table as CSV
if benchmark_results:
    import json
    from datetime import datetime

    # Result entries copied as plain floats / lists of floats so that they are
    # JSON-serialisable
    scalar_keys = ('overall_accuracy', 'mean_accuracy', 'mean_iou',
                   'f1_micro', 'f1_macro', 'f1_weighted')
    list_keys = ('per_class_accuracy', 'per_class_iou')

    save_results = {}
    for model_name, results in benchmark_results.items():
        entry = {key: float(results[key]) for key in scalar_keys}
        for key in list_keys:
            entry[key] = [float(x) for x in results[key]]
        save_results[model_name] = entry

    # Metrics wrapped together with run metadata
    save_data = dict(
        timestamp=datetime.now().isoformat(),
        test_set_size=len(test_loader.dataset),
        num_classes=num_classes,
        class_names=class_names,
        results=save_results,
    )

    # Write the JSON report
    output_path = './benchmark_results.json'
    with open(output_path, 'w') as f:
        json.dump(save_data, f, indent=2)

    print(f"\n✅ Results saved to {output_path}")

    # Write the comparison table built in the results-comparison cell
    csv_path = './benchmark_comparison.csv'
    comparison_df.to_csv(csv_path, index=False)
    print(f"✅ Comparison table saved to {csv_path}")

## Summary

This benchmark evaluation provides a comprehensive comparison of all trained models on the same test dataset. The analysis includes:

1. **Overall Metrics**: Accuracy, IoU, and F1-scores
2. **Per-Class Performance**: Detailed accuracy and IoU for each class (the crop types as well as artificial surfaces, water, roads, and other)
3. **Visual Comparisons**: Bar charts, radar plots, and confusion matrices
4. **Statistical Analysis**: Best model identification across different metrics

The results are saved for future reference and can be used to select the best model for deployment.