# Hornet and Wasp Classification Project

This notebook implements and compares three different deep learning models for classifying hornets and wasps:
- Model 1: Custom CNN from scratch
- Model 2: Transfer Learning with a pre-trained ResNet50
- Model 3: Vision Transformer (ViT)

**Dataset:**
- Vespa_crabro: 954 train, 100 validation images
- Vespa_velutina: 1,102 train, 100 validation images
- Vespula_sp: 1,032 train, 97 validation images
- Total: 3,088 training, 297 validation images

## 1. Setup and Data Exploration

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models
from torchvision.datasets import ImageFolder

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import precision_recall_fscore_support

import os
from PIL import Image
import random
from pathlib import Path
import time
from collections import defaultdict

# Seed every random number generator used in this notebook (same seed for all)
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(42)

# Select the compute device: CUDA when it is available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (in GiB) of GPU 0
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

In [None]:
# Data paths
train_dir = 'dataset/data3000/data/train/images'
val_dir = 'dataset/data3000/data/val/images'

# Class names are the sub-directories of the training folder. Plain files that
# may sit next to them (e.g. .DS_Store) are skipped: they are not classes and
# would make the os.listdir(class_path) calls below fail.
class_names = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
print(f"Classes: {class_names}")
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")

# Count images per class for both splits (every directory entry is counted)
for split, split_dir in [('Train', train_dir), ('Validation', val_dir)]:
    print(f"\n{split} set distribution:")
    total = 0
    for class_name in class_names:
        class_path = os.path.join(split_dir, class_name)
        count = len(os.listdir(class_path))
        total += count
        print(f"  {class_name}: {count} images")
    print(f"  Total: {total} images")

In [None]:
# Visualize sample images from each class: one row per class, five images per row
samples_per_class = 5
fig, axes = plt.subplots(
    len(class_names), samples_per_class,
    figsize=(15, 3 * len(class_names)),
    squeeze=False,  # keep a 2-D axes array even when there is a single class
)
fig.suptitle('Sample Images from Each Class', fontsize=16)

for i, class_name in enumerate(class_names):
    class_path = os.path.join(train_dir, class_name)
    # Sorted so the same files are shown on every run
    image_files = sorted(os.listdir(class_path))[:samples_per_class]

    for j in range(samples_per_class):
        ax = axes[i, j]
        if j < len(image_files):
            # Context manager closes the file handle once the pixels are drawn
            with Image.open(os.path.join(class_path, image_files[j])) as img:
                ax.imshow(img)
            ax.set_title(f'{class_name}' if j == 0 else '')
        # Also hides cells that stay empty when a class has fewer images
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Analyze image dimensions on a small sample (up to 10 files per class)
dimensions = []
for class_name in class_names:
    class_path = os.path.join(train_dir, class_name)
    for img_file in os.listdir(class_path)[:10]:
        # Only the size is needed; the context manager releases the file
        # handle immediately instead of leaving it open.
        with Image.open(os.path.join(class_path, img_file)) as img:
            dimensions.append(img.size)  # (width, height)

dimensions_df = pd.DataFrame(dimensions, columns=['width', 'height'])
print("Image dimensions statistics:")
print(dimensions_df.describe())

# Plot dimension distribution: one histogram per axis of the image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.hist(dimensions_df['width'], bins=20, alpha=0.7, label='Width')
ax1.set_xlabel('Width (pixels)')
ax1.set_ylabel('Frequency')
ax1.set_title('Image Width Distribution')

ax2.hist(dimensions_df['height'], bins=20, alpha=0.7, label='Height')
ax2.set_xlabel('Height (pixels)')
ax2.set_ylabel('Frequency')
ax2.set_title('Image Height Distribution')

plt.tight_layout()
plt.show()

## 2. Data Loading and Preprocessing

In [None]:
# Preprocessing pipelines.
# Both resize to 224x224 and end with tensor conversion + channel-wise
# normalization; only the training pipeline adds random augmentation in between.
def _tensor_and_normalize():
    """Return fresh ToTensor/Normalize steps shared by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

_train_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
]

train_transform = transforms.Compose(
    [transforms.Resize((224, 224))] + _train_augmentations + _tensor_and_normalize()
)
val_transform = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize()
)

# Datasets: one sub-folder per class under each split directory
train_dataset = ImageFolder(train_dir, transform=train_transform)
val_dataset = ImageFolder(val_dir, transform=val_transform)

# Loaders: only the training data is shuffled
batch_size = 32
_loader_options = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

n_train, n_val = len(train_dataset), len(val_dataset)
print(f"Training samples: {n_train}")
print(f"Validation samples: {n_val}")
print(f"Batch size: {batch_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Class to index mapping: {train_dataset.class_to_idx}")

In [None]:
# Visualize augmented images
def show_augmented_images(dataset, num_samples=5):
    """
    Show one randomly chosen image next to several transformed versions of it.

    The top row holds the untouched file; the bottom row holds ``num_samples``
    draws of ``dataset[idx]``, so any random augmentation in the dataset's
    transform yields a different picture in every column.

    Args:
        dataset: Dataset exposing ``imgs`` (a list whose entries start with the
            file path) and returning ``(tensor, label)`` from ``dataset[idx]``.
        num_samples: Number of transformed versions to draw (bottom row).
    """
    # squeeze=False keeps `axes` two-dimensional even for num_samples == 1
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)

    # Get a random sample
    idx = random.randint(0, len(dataset) - 1)

    # Inverse of Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]);
    # assumes the dataset's transform normalizes with exactly these constants.
    denorm = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                                  std=[1/0.229, 1/0.224, 1/0.225])

    # Show original (file handle is closed as soon as it has been drawn)
    with Image.open(dataset.imgs[idx][0]) as original_img:
        axes[0, 0].imshow(original_img)
    axes[0, 0].set_title('Original')

    # Show transformed versions: every dataset[idx] call re-runs the transform
    for i in range(num_samples):
        img_tensor, _ = dataset[idx]
        img_display = denorm(img_tensor).clamp(0, 1)
        axes[1, i].imshow(img_display.permute(1, 2, 0))  # CHW -> HWC for imshow
        axes[1, i].set_title('Processed' if i == 0 else f'Augmented {i}')
        axes[1, i].axis('off')

    # Hide the frame of the whole first row (original image and unused cells)
    for ax in axes[0]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

show_augmented_images(train_dataset)

## 3. Model Definitions

### Model 1: Custom CNN from Scratch

In [None]:
class CustomCNN(nn.Module):
    """
    Five-block convolutional classifier trained from scratch.

    Each block is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2), with
    channel widths 32, 64, 128, 256, 512. The feature map is then brought to
    7x7, flattened and passed through three fully connected layers with
    dropout (p=0.5) after the first two.

    Args:
        num_classes: Size of the output layer.

    Input:  ``(N, 3, H, W)``; a 224x224 image yields a 7x7 map after the five
            poolings. Other sizes are accepted as long as H and W are at least
            32 (so the map is still non-empty after five halvings).
    Output: ``(N, num_classes)`` raw logits.
    """

    def __init__(self, num_classes=3):
        super(CustomCNN, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, padding=1)

        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(128)
        self.batch_norm4 = nn.BatchNorm2d(256)
        self.batch_norm5 = nn.BatchNorm2d(512)

        # Brings any feature map to the 7x7 grid fc1 expects. It has no
        # parameters (state dicts are unaffected) and leaves a 7x7 map,
        # i.e. the 224x224 case, unchanged.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # Fully connected layers
        self.fc1 = nn.Linear(512 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``."""
        # Blocks 1-5: conv -> batch norm -> ReLU -> 2x2 max pool
        x = self.pool(F.relu(self.batch_norm1(self.conv1(x))))
        x = self.pool(F.relu(self.batch_norm2(self.conv2(x))))
        x = self.pool(F.relu(self.batch_norm3(self.conv3(x))))
        x = self.pool(F.relu(self.batch_norm4(self.conv4(x))))
        x = self.pool(F.relu(self.batch_norm5(self.conv5(x))))

        # Flatten per sample. Flattening from dim 1 keeps the batch dimension
        # fixed; a view(-1, 512*7*7) would silently regroup samples whenever
        # the feature map is not exactly 7x7.
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)

        # Fully connected head
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

# Smoke test for the custom CNN: build it on the target device, report its
# size, then push one random 224x224 RGB image through it.
model_cnn = CustomCNN(num_classes=num_classes)
model_cnn = model_cnn.to(device)
cnn_param_count = sum(param.numel() for param in model_cnn.parameters())
print(f"Custom CNN parameters: {cnn_param_count:,}")

# Dummy batch of one image; reused by the later model checks
dummy_input = torch.randn(1, 3, 224, 224).to(device)
output = model_cnn(dummy_input)
print(f"Output shape: {output.shape}")

### Model 2: Transfer Learning with ResNet50

In [None]:
class ResNetTransfer(nn.Module):
    """
    ResNet50 backbone with a new two-layer classification head.

    The original ``fc`` layer is replaced by
    Dropout(0.5) -> Linear(in_features, 512) -> ReLU -> Dropout(0.3) ->
    Linear(512, num_classes). No backbone layer is frozen here.

    Args:
        num_classes: Size of the output layer.
        pretrained: When true, start from pre-trained backbone weights;
            otherwise the backbone is randomly initialised.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super(ResNetTransfer, self).__init__()

        # Load ResNet50. Newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; IMAGENET1K_V1 is the checkpoint the legacy flag
        # maps to. Older releases without the enum keep the original call.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
        else:
            self.backbone = models.resnet50(pretrained=pretrained)

        # Replace the final layer
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return the logits produced by the backbone and its replaced head."""
        return self.backbone(x)

# Smoke test for the transfer-learning model: instantiate, report size,
# and run the shared dummy batch through it.
model_resnet = ResNetTransfer(num_classes=num_classes)
model_resnet = model_resnet.to(device)
resnet_param_count = sum(param.numel() for param in model_resnet.parameters())
print(f"ResNet50 parameters: {resnet_param_count:,}")

# Forward pass on the dummy input created earlier in the notebook
output = model_resnet(dummy_input)
print(f"Output shape: {output.shape}")

### Model 3: Vision Transformer (ViT)

In [None]:
class SimpleViT(nn.Module):
    """
    Minimal Vision Transformer classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a learnable class token is prepended,
    learnable positional embeddings are added, and the sequence is processed
    by a stack of ``nn.TransformerEncoderLayer``. The class token's output is
    layer-normalised, passed through dropout and mapped to class logits.

    Args:
        num_classes: Size of the output layer.
        img_size: Side length used to size the positional embedding; inputs
            must produce ``(img_size // patch_size) ** 2`` patches.
        patch_size: Side length of one square patch.
        embed_dim: Token embedding width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
    """

    def __init__(self, num_classes=3, img_size=224, patch_size=16, embed_dim=768, num_heads=12, num_layers=12):
        super(SimpleViT, self).__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim

        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Class token and positional embedding.
        # Initialised with small noise (std 0.02): unit-variance noise would be
        # much larger than the freshly initialised patch embeddings it is
        # added to and drown them out at the start of training.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for images ``x``."""
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, embed_dim, H//patch_size, W//patch_size)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

        # Add class token (one copy per sample, placed first in the sequence)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embedding (broadcast over the batch)
        x = x + self.pos_embed

        # Transformer
        x = self.transformer(x)

        # Classification
        x = self.norm(x[:, 0])  # Use class token
        x = self.dropout(x)
        x = self.head(x)

        return x

# Smoke test for the Vision Transformer. A reduced configuration
# (embed_dim=384, 6 heads, 6 layers) is used to keep training time down.
model_vit = SimpleViT(num_classes=num_classes, embed_dim=384, num_heads=6, num_layers=6)
model_vit = model_vit.to(device)
vit_param_count = sum(param.numel() for param in model_vit.parameters())
print(f"ViT parameters: {vit_param_count:,}")

# Forward pass on the dummy input created earlier in the notebook
output = model_vit(dummy_input)
print(f"Output shape: {output.shape}")

## 4. Training Functions

In [None]:
def _run_epoch(model, loader, criterion, run_device, optimizer=None):
    """Run one pass over ``loader``; updates weights only when ``optimizer`` is given.

    Returns ``(mean_batch_loss, accuracy)`` as Python floats.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_samples = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)
            loss_sum += loss.item()
            n_batches += 1
            n_correct += (preds == labels).sum().item()
            # Count what was actually seen instead of relying on a dataset
            # object that lives outside this function.
            n_samples += labels.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError
    return loss_sum / max(n_batches, 1), n_correct / max(n_samples, 1)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_name):
    """
    Train a model and return training history

    Every epoch runs one training pass over ``train_loader`` and one
    evaluation pass over ``val_loader``. The weights of the epoch with the
    highest validation accuracy are restored into ``model`` before returning.

    Args:
        model: Network to optimise; modified in place.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable used as ``criterion(outputs, labels)``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        scheduler: Object whose ``step(val_loss)`` is called once per epoch.
        num_epochs: Number of epochs to run.
        model_name: Label used in the progress output.

    Returns:
        ``(history, best_val_acc)`` where ``history`` maps 'train_loss',
        'train_acc', 'val_loss' and 'val_acc' to per-epoch lists of floats and
        ``best_val_acc`` is the highest validation accuracy as a float.
    """
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_acc = 0.0
    best_model_state = None

    # Run on whatever device the model already lives on; the notebook-level
    # `device` is only consulted for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    print(f"Training {model_name}...")
    print("-" * 50)

    for epoch in range(num_epochs):
        # Training phase
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, run_device, optimizer)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        # Validation phase
        val_epoch_loss, val_epoch_acc = _run_epoch(model, val_loader, criterion, run_device)
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc)

        # Save best model. The tensors are cloned: state_dict() hands out
        # references to the live weights, so a shallow dict copy would keep
        # changing as training continues. The first epoch is always stored so
        # there is something to restore even if accuracy never exceeds 0.
        if best_model_state is None or val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }

        # Step scheduler
        scheduler.step(val_epoch_loss)

        # Print progress (first epoch, then every fifth)
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}')
            print(f'Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}')
            print(f'Best Val Acc: {best_val_acc:.4f}')
            print()

    # Load best model (nothing to restore when no epoch was run)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    print(f"Training completed! Best validation accuracy: {best_val_acc:.4f}")

    return history, best_val_acc

def evaluate_model(model, data_loader, class_names):
    """
    Evaluate model and return detailed metrics

    Args:
        model: Trained network; switched to eval mode.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        class_names: Class labels; position ``i`` names class index ``i``.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1' (the last three
        weighted by class support), 'classification_report' (dict form),
        'confusion_matrix', 'predictions', 'labels' and 'probabilities'
        (per-sample softmax outputs).
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    # Run on whatever device the model already lives on; the notebook-level
    # `device` is only consulted for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            outputs = model(inputs)
            probabilities = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())

    # Pin the label set to every known class index. Without it the report
    # rejects `target_names` and the confusion matrix shrinks whenever a class
    # is missing from both the labels and the predictions.
    label_ids = list(range(len(class_names)))

    # Calculate metrics (zero_division=0 scores an undefined ratio as 0
    # without emitting a warning)
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average='weighted', zero_division=0
    )

    # Classification report
    class_report = classification_report(
        all_labels, all_preds, labels=label_ids, target_names=class_names,
        output_dict=True, zero_division=0
    )

    # Confusion matrix (rows: true class, columns: predicted class)
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'classification_report': class_report,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs
    }

## 5. Training All Models

In [None]:
# Training parameters
num_epochs = 25
learning_rate = 0.001

# Initialize models (fresh instances, independent of the smoke-test models)
models_dict = {
    'Custom CNN': CustomCNN(num_classes=num_classes).to(device),
    'ResNet50': ResNetTransfer(num_classes=num_classes).to(device),
    'Vision Transformer': SimpleViT(num_classes=num_classes, embed_dim=384, num_heads=6, num_layers=6).to(device)
}

# Training results storage
results = {}
training_histories = {}

# Train each model
for model_name, model in models_dict.items():
    print(f"\n{'='*60}")
    print(f"TRAINING {model_name.upper()}")
    print(f"{'='*60}")

    # Setup optimizer and scheduler
    criterion = nn.CrossEntropyLoss()

    if model_name == 'ResNet50':
        # Use different learning rates for backbone and classifier.
        # backbone.parameters() also yields the head's parameters, and an
        # optimizer rejects a parameter that appears in two groups, so the
        # head is filtered out of the backbone group by identity.
        head_params = list(model.backbone.fc.parameters())
        head_param_ids = {id(param) for param in head_params}
        backbone_params = [
            param for param in model.backbone.parameters()
            if id(param) not in head_param_ids
        ]
        optimizer = optim.Adam([
            {'params': head_params, 'lr': learning_rate},
            {'params': backbone_params, 'lr': learning_rate * 0.1}
        ])
    else:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Halve the learning rate after 5 epochs without validation-loss improvement
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    # Train model
    start_time = time.time()
    history, best_val_acc = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_name
    )
    training_time = time.time() - start_time

    # Evaluate model
    print("Evaluating on validation set...")
    metrics = evaluate_model(model, val_loader, class_names)

    # Store results
    results[model_name] = {
        'model': model,
        'metrics': metrics,
        'training_time': training_time,
        'best_val_acc': best_val_acc
    }
    training_histories[model_name] = history

    print(f"Training time: {training_time:.2f} seconds")
    print(f"Final validation accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-score: {metrics['f1']:.4f}")

print("\n" + "="*60)
print("ALL MODELS TRAINED SUCCESSFULLY!")
print("="*60)

## 6. Model Comparison and Visualization

In [None]:
# Summary table: one row per trained model, all values pre-formatted as text
def _comparison_row(name, entry):
    """Build the formatted table row for one entry of `results`."""
    scores = entry['metrics']
    param_count = sum(p.numel() for p in entry['model'].parameters())
    return {
        'Model': name,
        'Accuracy': f"{scores['accuracy']:.4f}",
        'Precision': f"{scores['precision']:.4f}",
        'Recall': f"{scores['recall']:.4f}",
        'F1-Score': f"{scores['f1']:.4f}",
        'Training Time (s)': f"{entry['training_time']:.1f}",
        'Parameters': f"{param_count:,}"
    }

comparison_data = [_comparison_row(name, entry) for name, entry in results.items()]

comparison_df = pd.DataFrame(comparison_data)
print("Model Comparison Summary:")
print(comparison_df.to_string(index=False))

In [None]:
# Training curves: a 2x2 grid (loss on top, accuracy below), one line per model
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Histories Comparison', fontsize=16)

metrics_to_plot = ['train_loss', 'val_loss', 'train_acc', 'val_acc']
titles = ['Training Loss', 'Validation Loss', 'Training Accuracy', 'Validation Accuracy']

# axes.flat walks the grid row by row, matching the order of the lists above
for ax, metric, title in zip(axes.flat, metrics_to_plot, titles):
    for model_name, history in training_histories.items():
        ax.plot(history[metric], label=model_name, linewidth=2)

    y_label = title.split()[1]  # second word of the title: 'Loss' or 'Accuracy'
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices side by side, one heatmap per model
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Confusion Matrices Comparison', fontsize=16)

for idx, (model_name, result) in enumerate(results.items()):
    panel = axes[idx]
    model_metrics = result['metrics']
    cm = model_metrics['confusion_matrix']

    # Integer counts per cell; class names label both axes
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=panel,
    )
    panel.set_title(f'{model_name}\nAccuracy: {model_metrics["accuracy"]:.4f}')
    panel.set_xlabel('Predicted')
    panel.set_ylabel('Actual')

plt.tight_layout()
plt.show()

In [None]:
# Grouped bar chart of the four overall metrics, one bar group per metric
metrics_names = ['accuracy', 'precision', 'recall', 'f1']
x = np.arange(len(metrics_names))
width = 0.25

fig, ax = plt.subplots(figsize=(12, 6))

# Metric values per model, in the order the models were trained
scores_by_model = {
    name: [entry['metrics'][metric] for metric in metrics_names]
    for name, entry in results.items()
}

# Bars: each model is shifted by one bar width inside a group
for i, (model_name, values) in enumerate(scores_by_model.items()):
    ax.bar(x + i * width, values, width, label=model_name, alpha=0.8)

ax.set_xlabel('Metrics')
ax.set_ylabel('Score')
ax.set_title('Model Performance Comparison')
ax.set_xticks(x + width)
ax.set_xticklabels([m.capitalize() for m in metrics_names])
ax.legend()
ax.grid(True, axis='y', alpha=0.3)

# Value labels just above every bar
for i, values in enumerate(scores_by_model.values()):
    for j, v in enumerate(values):
        ax.text(j + i * width, v + 0.01, f'{v:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# Per-class precision / recall / F1: one panel per metric, bars grouped by class
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Per-Class Performance Comparison', fontsize=16)

metrics_to_show = ['precision', 'recall', 'f1-score']

# Underscores become line breaks so long class names fit under their group
tick_labels = [name.replace('_', '\n') for name in class_names]

for ax, metric in zip(axes, metrics_to_show):
    x = np.arange(len(class_names))
    width = 0.25

    for model_idx, (model_name, result) in enumerate(results.items()):
        report = result['metrics']['classification_report']
        values = []
        for class_name in class_names:
            values.append(report[class_name][metric])
        ax.bar(x + model_idx * width, values, width, label=model_name, alpha=0.8)

    metric_label = metric.capitalize()
    ax.set_xlabel('Classes')
    ax.set_ylabel(metric_label)
    ax.set_title(f'{metric_label} by Class')
    ax.set_xticks(x + width)
    ax.set_xticklabels(tick_labels, fontsize=10)
    ax.legend()
    ax.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Prediction Examples with Confidence

In [None]:
def predict_with_confidence(models_dict, image_path, transform, class_names, top_k=3):
    """
    Make predictions with all models and show confidence scores

    Args:
        models_dict: Mapping of model name to model.
        image_path: Path of the image file to classify.
        transform: Callable turning a PIL image into a ``(C, H, W)`` tensor.
        class_names: Class labels; position ``i`` names class index ``i``.
        top_k: Number of highest-probability classes to report. Values larger
            than the number of model outputs are reduced to that number.

    Returns:
        ``(image, predictions)``: the RGB PIL image and a dict mapping each
        model name to 'top_classes', 'top_probs', 'predicted_class' and
        'confidence'.
    """
    # Load and preprocess image. convert() returns a separate RGB image, so
    # the file handle can be closed right away.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')
    image_batch = transform(image).unsqueeze(0)  # add the batch dimension

    predictions = {}

    for model_name, model in models_dict.items():
        model.eval()

        # Send the input to the device this model lives on; the notebook-level
        # `device` is only consulted for a model without parameters.
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else device

        with torch.no_grad():
            outputs = model(image_batch.to(model_device))
            probabilities = F.softmax(outputs, dim=1)[0]

            # Get top k predictions; topk raises if k exceeds the class count
            k = min(top_k, probabilities.numel())
            top_probs, top_indices = torch.topk(probabilities, k)

            # Plain Python numbers: safe for list indexing and formatting
            index_list = top_indices.tolist()
            prob_list = top_probs.tolist()

            predictions[model_name] = {
                'top_classes': [class_names[idx] for idx in index_list],
                'top_probs': prob_list,
                'predicted_class': class_names[index_list[0]],
                'confidence': prob_list[0]
            }

    return image, predictions

# Prediction examples: the first two listed validation images of every class
sample_images = [
    os.path.join(val_dir, class_name, file_name)
    for class_name in class_names
    for file_name in os.listdir(os.path.join(val_dir, class_name))[:2]
]

# All trained models, keyed by name
trained_models = {name: result['model'] for name, result in results.items()}

# One row per image: the picture first, then one text panel per model.
# squeeze=False keeps `axes` two-dimensional even for a single image.
fig, axes = plt.subplots(
    len(sample_images), 4,
    figsize=(20, 5 * len(sample_images)),
    squeeze=False,
)

for img_idx, img_path in enumerate(sample_images):
    row = axes[img_idx]

    # Run every model on this image
    image, predictions = predict_with_confidence(
        trained_models, img_path, val_transform, class_names
    )

    # The parent folder name is the ground-truth class
    true_class = os.path.basename(os.path.dirname(img_path))

    # Left-most panel: the image itself
    row[0].imshow(image)
    row[0].set_title(f'True Class: {true_class}', fontweight='bold')
    row[0].axis('off')

    # Remaining panels: one prediction summary per model
    for model_idx, (model_name, pred) in enumerate(predictions.items()):
        ax = row[model_idx + 1]

        lines = [f"Prediction: {pred['predicted_class']}\nConfidence: {pred['confidence']:.1%}\n\n"]
        lines.append("Top 3 predictions:\n")
        for class_name, prob in zip(pred['top_classes'], pred['top_probs']):
            lines.append(f"• {class_name}: {prob:.1%}\n")
        pred_text = "".join(lines)

        # Green box for a correct prediction, red for a wrong one
        if pred['predicted_class'] == true_class:
            color = 'green'
        else:
            color = 'red'

        ax.text(0.05, 0.95, pred_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor=color, alpha=0.1))
        ax.set_title(f'{model_name}', fontweight='bold')
        ax.axis('off')

plt.tight_layout()
plt.show()

## 8. Final Summary and Recommendations

In [None]:
# Final report: best model, ranking, recommendations and per-model details
def _accuracy_of(item):
    """Sort key: validation accuracy of a (name, result) pair."""
    return item[1]['metrics']['accuracy']

def _training_time_of(item):
    """Sort key: training time in seconds of a (name, result) pair."""
    return item[1]['training_time']

# Find the best performing model
best_model = max(results.items(), key=_accuracy_of)
best_model_name, best_result = best_model
best_metrics = best_result['metrics']

banner = "=" * 60
print(banner)
print("FINAL ANALYSIS SUMMARY")
print(banner)
print()
print(f"🏆 BEST PERFORMING MODEL: {best_model_name}")
print(f"   • Accuracy: {best_metrics['accuracy']:.4f} ({best_metrics['accuracy']:.1%})")
print(f"   • Precision: {best_metrics['precision']:.4f}")
print(f"   • Recall: {best_metrics['recall']:.4f}")
print(f"   • F1-Score: {best_metrics['f1']:.4f}")
print(f"   • Training Time: {best_result['training_time']:.1f} seconds")
print()

# Models ordered from highest to lowest accuracy
print("📊 COMPLETE RANKING:")
ranked_models = sorted(results.items(), key=_accuracy_of, reverse=True)
for i, (model_name, result) in enumerate(ranked_models, 1):
    print(f"   {i}. {model_name}: {result['metrics']['accuracy']:.1%} accuracy")
print()

fastest_model_name = min(results.items(), key=_training_time_of)[0]
print("💡 RECOMMENDATIONS:")
print(f"   • For highest accuracy: Use {ranked_models[0][0]}")
print(f"   • For fastest training: Use {fastest_model_name}")
print(f"   • For production deployment: Consider {best_model_name} for best balance of accuracy and reliability")
print()

# Per-model details: best and worst per-class F1 plus parameter count
print("🔍 DETAILED INSIGHTS:")
for model_name, result in results.items():
    report = result['metrics']['classification_report']
    class_f1_scores = [report[cls]['f1-score'] for cls in class_names]
    param_count = sum(p.numel() for p in result['model'].parameters())
    print(f"\n{model_name}:")
    print(f"   • Overall accuracy: {result['metrics']['accuracy']:.1%}")
    print(f"   • Best class performance: {max(class_f1_scores):.3f} F1-score")
    print(f"   • Worst class performance: {min(class_f1_scores):.3f} F1-score")
    print(f"   • Parameters: {param_count:,}")

print("\n" + banner)
print("HORNET/WASP CLASSIFICATION PROJECT COMPLETED SUCCESSFULLY!")
print(banner)

In [None]:
# Persist the comparison table as CSV for future reference
comparison_csv_path = 'model_comparison_results.csv'
comparison_df.to_csv(comparison_csv_path, index=False)
print(f"Model comparison results saved to '{comparison_csv_path}'")

# Optional: store the weights of the most accurate model. The file name is
# derived from the model name (spaces -> underscores, lower case).
best_model_path = f'best_model_{best_model_name.replace(" ", "_").lower()}.pth'
torch.save(best_result['model'].state_dict(), best_model_path)
print(f"Best model ({best_model_name}) saved to '{best_model_path}'")