# Knowledge Distillation - DenseNet201 (Teacher) to EfficientNet-B0 (Student)

- **Teacher**: DenseNet201 (18.1M params)
- **Student**: EfficientNet-B0 (4.0M params)
- **Compression**: ~4.5x fewer parameters

In [None]:
# Import required libraries
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import densenet201, DenseNet201_Weights, efficientnet_b0, EfficientNet_B0_Weights
from PIL import Image
from typing import Tuple, List
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import random
import warnings
# Install a catch-all "ignore" filter so every warning category is silenced.
# NOTE(review): this also hides deprecation/FutureWarnings from torch/torchvision.
warnings.simplefilter('ignore')

print("✓ All libraries imported successfully")

In [None]:
# Configuration
DATA_ROOT = Path("/Users/alimran/Desktop/CSE465/Split_Dataset")
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 0  # load in the main process; worker subprocesses misbehave inside Jupyter
EPOCHS = 50
LR = 1e-4
NUM_SPECIES = 3  # Eggplant, Potato, Tomato
NUM_HEALTH = 4   # Bacterial, Fungal, Healthy, Virus
DROPOUT = 0.3
SEED = 42

# KD-specific parameters
TEMPERATURE = 4.0  # softening temperature applied to teacher and student logits
ALPHA = 0.7        # weight of the distillation term; (1 - ALPHA) weights the hard-label term
PATIENCE = 7       # early-stopping patience, in epochs


def _pick_device():
    """Return CUDA if available, else Apple MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus all CUDA devices when present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


device = _pick_device()
_seed_everything(SEED)

_RULE = "=" * 80
print(_RULE)
print("KNOWLEDGE DISTILLATION CONFIGURATION")
print(_RULE)
print(f"Dataset:            Eggplant, Potato, Tomato (3 species)")
print(f"Health Classes:     Bacterial, Fungal, Healthy, Virus (4 classes)")
print(f"Data root:          {DATA_ROOT}")
print(f"Device:             {device}")
print(f"Batch Size:         {BATCH_SIZE}")
print(f"Learning Rate:      {LR}")
print(f"Epochs:             {EPOCHS}")
print(f"Patience:           {PATIENCE}")
print(f"\nKD Parameters:")
print(f"Temperature:        {TEMPERATURE}")
print(f"Alpha (KD weight):  {ALPHA}")
print(f"Random Seed:        {SEED}")
print(_RULE)

## Visualization Code

In [None]:
# Enable inline plotting for Jupyter
%matplotlib inline

# Set publication-quality style
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({
    'figure.dpi': 100,     # modest DPI for on-screen notebook rendering
    'savefig.dpi': 300,    # figures written to disk use publication resolution
    'font.family': 'serif',
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'legend.fontsize': 9,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})

class TrainingLogger:
    """Track per-epoch training/validation metrics and export them as plots or CSV.

    `history` maps each column name to a list holding one value per call to
    `log_epoch`, in a fixed column order: epoch, the train/val metric pairs, lr.
    """

    # Keys read from the train/val stats dicts given to `log_epoch`; each one
    # becomes a 'train_<key>' and a 'val_<key>' column of `history`.
    _STAT_KEYS = ('loss', 'acc_species', 'acc_health', 'acc_both')

    def __init__(self):
        self.history = {
            'epoch': [],
            'train_loss': [],
            'val_loss': [],
            'train_acc_species': [],
            'val_acc_species': [],
            'train_acc_health': [],
            'val_acc_health': [],
            'train_acc_both': [],
            'val_acc_both': [],
            'lr': []
        }

    def log_epoch(self, epoch, train_stats, val_stats, lr):
        """Append the metrics of one epoch.

        Args:
            epoch: epoch number to record.
            train_stats: dict with keys 'loss', 'acc_species', 'acc_health', 'acc_both'.
            val_stats: dict with the same keys, for the validation split.
            lr: learning rate recorded for this epoch.
        """
        self.history['epoch'].append(epoch)
        for key in self._STAT_KEYS:
            self.history[f'train_{key}'].append(train_stats[key])
            self.history[f'val_{key}'].append(val_stats[key])
        self.history['lr'].append(lr)

    def _plot_train_val(self, ax, metric, label, ylabel, title, ylim=None):
        """Draw the train and val curves of `metric` on `ax` with the shared styling."""
        epochs = self.history['epoch']
        ax.plot(epochs, self.history[f'train_{metric}'], 'o-', label=f'Train {label}',
                linewidth=2, markersize=4, color='#2E86AB', alpha=0.8)
        ax.plot(epochs, self.history[f'val_{metric}'], 's-', label=f'Val {label}',
                linewidth=2, markersize=4, color='#A23B72', alpha=0.8)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontweight='bold', loc='left')
        if ylim is not None:
            ax.set_ylim(ylim)

    @staticmethod
    def _finish(fig, save_path, what, show):
        """Lay out and save `fig` at 300 dpi, report the path, then show or close it."""
        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"✓ {what} saved to: {save_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    def plot_convergence(self, save_path='training_convergence.png', model_name='Model', show=True):
        """
        Create publication-quality convergence plots (loss, species, disease, joint accuracy).

        Args:
            save_path: Path to save the figure
            model_name: Name of the model for the title
            show: Whether to display the plot in notebook (default: True)
        """
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle(f'{model_name} Training Convergence', fontsize=14, fontweight='bold', y=0.995)

        acc_ylim = (0, 1.05)
        self._plot_train_val(axes[0, 0], 'loss', 'Loss', 'Loss', '(a) Multi-task Loss')
        self._plot_train_val(axes[0, 1], 'acc_species', 'Species', 'Accuracy',
                             '(b) Species Classification', acc_ylim)
        self._plot_train_val(axes[1, 0], 'acc_health', 'Disease', 'Accuracy',
                             '(c) Disease Detection', acc_ylim)
        ax4 = axes[1, 1]
        self._plot_train_val(ax4, 'acc_both', 'Both', 'Accuracy',
                             '(d) Joint Accuracy (Primary Metric)', acc_ylim)

        # Mark the epoch with the highest validation joint accuracy.
        # NOTE(review): the training loop checkpoints on val *health* accuracy,
        # so this marker need not coincide with the saved model's epoch.
        val_both = self.history['val_acc_both']
        if val_both:
            best_idx = int(np.argmax(val_both))
            best_epoch = self.history['epoch'][best_idx]
            best_val_both = val_both[best_idx]
            ax4.axvline(x=best_epoch, color='red', linestyle='--',
                        linewidth=1.5, alpha=0.6, label=f'Best (Epoch {best_epoch})')
            ax4.plot(best_epoch, best_val_both, 'r*', markersize=15,
                     markeredgecolor='darkred', markeredgewidth=1.5)
            ax4.annotate(f'{best_val_both:.3f}',
                         xy=(best_epoch, best_val_both),
                         xytext=(10, -15), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0', color='red'),
                         fontsize=9, fontweight='bold')

        # Legends are added last so panel (d) also lists the "Best" marker.
        for ax in axes.flat:
            ax.legend(framealpha=0.9)
            ax.grid(True, alpha=0.3)

        self._finish(fig, save_path, 'Convergence plot', show)

    def plot_learning_rate(self, save_path='learning_rate_schedule.png', show=True,
                           title='Learning Rate Schedule (Cosine Annealing)'):
        """Plot the logged learning rate per epoch on a log-scale y axis.

        Args:
            save_path: Path to save the figure.
            show: Whether to display the plot in the notebook.
            title: Axes title; the default names the cosine-annealing schedule
                used in this notebook, override it for other schedulers.
        """
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(self.history['epoch'], self.history['lr'], 'o-', linewidth=2,
                markersize=5, color='#F18F01', label='Learning Rate')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        ax.set_title(title, fontweight='bold')
        ax.legend(framealpha=0.9)
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')
        self._finish(fig, save_path, 'Learning rate plot', show)

    def save_history(self, save_path='training_history.csv'):
        """Write `history` to CSV: one row per logged epoch, no index column."""
        pd.DataFrame(self.history).to_csv(save_path, index=False)
        print(f"✓ Training history saved to: {save_path}")


print("✓ TrainingLogger class loaded successfully!")

In [None]:
# Label mappings (updated for CSE465 project)
SPECIES_MAP = {"eggplant": 0, "potato": 1, "tomato": 2}
HEALTH_MAP = {"bacterial": 0, "fungal": 1, "healthy": 2, "virus": 3}

def parse_joint_label(folder_name: str) -> Tuple[int, int]:
    """Parse a joint class-folder name such as 'Eggplant_Healthy' into label ids.

    The text before the first underscore is the species; everything after it is
    the health status. Both parts are matched case-insensitively, ignoring
    surrounding whitespace, against SPECIES_MAP and HEALTH_MAP.

    Returns:
        (species_id, health_id)

    Raises:
        ValueError: if the name contains no underscore.
        KeyError: if the species or health part is not a known label.
    """
    name = folder_name.strip()
    species, sep, health = name.partition("_")
    if not sep:
        raise ValueError(f"Folder name not joint label: {name}")

    species_key = species.strip().lower()
    health_key = health.strip().lower()
    if species_key not in SPECIES_MAP:
        raise KeyError(f"Unknown species {species!r} in folder {name!r}; "
                       f"expected one of {sorted(SPECIES_MAP)}")
    if health_key not in HEALTH_MAP:
        raise KeyError(f"Unknown health status {health!r} in folder {name!r}; "
                       f"expected one of {sorted(HEALTH_MAP)}")
    return SPECIES_MAP[species_key], HEALTH_MAP[health_key]

# Dataset class
class JointLeafDataset(Dataset):
    """Image dataset yielding (image, species_label, health_label) triples.

    `split_root` must contain one sub-directory per joint class, named
    '<Species>_<Health>' (parsed by parse_joint_label). Every .jpg/.jpeg/.png/.bmp
    file anywhere beneath a class directory becomes one sample.

    Hidden entries (names starting with '.') and macOS '__MACOSX' folders are
    skipped. They are not class folders (their names would fail label parsing),
    and '._*' AppleDouble files carry an image suffix but cannot be decoded.
    """

    IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp"})

    @staticmethod
    def _is_ignored(name: str) -> bool:
        """True for hidden files/directories and macOS archive-metadata folders."""
        return name.startswith(".") or name == "__MACOSX"

    def __init__(self, split_root: Path, transform=None):
        self.split_root = Path(split_root)
        self.samples: List[Tuple[str, int, int]] = []
        self.transform = transform

        class_dirs = sorted(d for d in self.split_root.iterdir()
                            if d.is_dir() and not self._is_ignored(d.name))
        for folder in class_dirs:
            sp_id, he_id = parse_joint_label(folder.name)
            # sorted() makes the sample order independent of filesystem listing order.
            for p in sorted(folder.rglob("*")):
                if not p.is_file() or p.suffix.lower() not in self.IMAGE_EXTENSIONS:
                    continue
                if any(self._is_ignored(part) for part in p.relative_to(folder).parts):
                    continue
                self.samples.append((str(p), sp_id, he_id))

        if len(self.samples) == 0:
            raise RuntimeError(f"No images found under {split_root}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (transformed image, species label tensor, health label tensor) for sample `idx`."""
        path, sp_id, he_id = self.samples[idx]

        try:
            # The context manager releases the file handle once the RGB copy exists.
            with Image.open(path) as im:
                img = im.convert('RGB')
        except Exception as e:
            # Best-effort: keep the epoch running but make the failure visible.
            # NOTE(review): the black placeholder still carries this sample's labels.
            print(f"Error loading {path}: {e}")
            img = Image.new('RGB', (IMG_SIZE, IMG_SIZE))

        if self.transform is not None:
            img = self.transform(img)

        return img, torch.tensor(sp_id, dtype=torch.long), torch.tensor(he_id, dtype=torch.long)

# Data transforms: fixed-size bilinear resize + ImageNet mean/std normalization,
# matching the statistics the pretrained backbones were trained with.
# The same deterministic preprocessing (no augmentation) is shared by every split.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Load datasets
print("Loading datasets...")
train_dataset, val_dataset, test_dataset = (
    JointLeafDataset(DATA_ROOT / split_name, transform=transform)
    for split_name in ("train", "val", "test")
)

print(f"Training samples:   {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples:       {len(test_dataset)}")


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the shared batch size/worker settings.

    Pinned host memory is requested only on CUDA, where it speeds up host-to-GPU copies.
    """
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
                      num_workers=NUM_WORKERS, pin_memory=(device.type == "cuda"))


# Create data loaders (only the training split is shuffled)
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print("✓ Datasets and loaders ready")

In [None]:
# Model architectures (updated for CSE465)

class MultiTaskDenseNet201(nn.Module):
    """DenseNet201 (TEACHER) with separate species and health classification heads"""
    def __init__(self, num_species=3, num_health=4, pretrained=False, dropout=0.3):
        super().__init__()
        if pretrained:
            weights = DenseNet201_Weights.IMAGENET1K_V1
            self.backbone = densenet201(weights=weights)
        else:
            self.backbone = densenet201(weights=None)
        
        # DenseNet201 has classifier as a single Linear layer, not a Sequential
        in_dim = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Identity()
        
        self.dropout = nn.Dropout(dropout)
        self.head_species = nn.Linear(in_dim, num_species)
        self.head_health = nn.Linear(in_dim, num_health)
    
    def forward(self, x):
        feats = self.backbone(x)
        feats = self.dropout(feats)
        logits_species = self.head_species(feats)
        logits_health = self.head_health(feats)
        return logits_species, logits_health


class MultiTaskEfficientNetB0(nn.Module):
    """EfficientNet-B0 (STUDENT) with separate species and health classification heads"""
    def __init__(self, num_species=3, num_health=4, pretrained=False, dropout=0.3):
        super().__init__()
        if pretrained:
            weights = EfficientNet_B0_Weights.IMAGENET1K_V1
            self.backbone = efficientnet_b0(weights=weights)
        else:
            self.backbone = efficientnet_b0(weights=None)
        
        # EfficientNetB0 has classifier as a Sequential with dropout and linear layer
        in_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        
        self.dropout = nn.Dropout(dropout)
        self.head_species = nn.Linear(in_dim, num_species)
        self.head_health = nn.Linear(in_dim, num_health)
    
    def forward(self, x):
        feats = self.backbone(x)
        feats = self.dropout(feats)
        logits_species = self.head_species(feats)
        logits_health = self.head_health(feats)
        return logits_species, logits_health

print("✓ Model classes defined")

In [None]:
# Load pre-trained models
# NOTE: weights_only=False unpickles arbitrary Python objects; only load checkpoints you trust.


def _restore_checkpoint(model, checkpoint_path):
    """Load the 'model' state_dict stored at `checkpoint_path` into `model` and move it to `device`.

    Returns the full checkpoint dict so the caller can read its metadata.
    """
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])
    model.to(device)
    return ckpt


def _num_params(model):
    """Total number of parameter elements in `model` (trainable and frozen)."""
    return sum(tensor.numel() for tensor in model.parameters())


print("="*80)
print("Loading Pre-Trained Models")
print("="*80)

# ========== TEACHER (DenseNet201) ==========
print("\n[1/2] Loading Teacher Model (DenseNet201)...")
teacher_path = '/Users/alimran/Desktop/CSE465/best_DenseNet201.pt'

teacher = MultiTaskDenseNet201(num_species=NUM_SPECIES, num_health=NUM_HEALTH, dropout=DROPOUT)
teacher_checkpoint = _restore_checkpoint(teacher, teacher_path)
teacher.eval()                 # evaluation mode (dropout off) for producing soft targets
teacher.requires_grad_(False)  # freeze every teacher parameter

teacher_epoch = teacher_checkpoint.get('epoch', 'N/A')
teacher_val_acc = teacher_checkpoint.get('val_health', 0.0)

print(f"✓ Teacher loaded successfully")
print(f"  Model: DenseNet201")
print(f"  Path: {teacher_path}")
print(f"  Trained Epoch: {teacher_epoch}")
print(f"  Val Health Acc: {teacher_val_acc:.4f}")

# ========== STUDENT (EfficientNet-B0) ==========
print("\n[2/2] Loading Student Model (EfficientNet-B0)...")
student_path = '/Users/alimran/Desktop/CSE465/best_multitask_efficientnet_b0 (1).pt'

student = MultiTaskEfficientNetB0(num_species=NUM_SPECIES, num_health=NUM_HEALTH, dropout=DROPOUT)
student_checkpoint = _restore_checkpoint(student, student_path)
student.train()  # the student is fine-tuned with KD next

student_epoch = student_checkpoint.get('epoch', 'N/A')
student_val_acc = student_checkpoint.get('val_health', 0.0)

print(f"✓ Student loaded successfully")
print(f"  Model: EfficientNet-B0")
print(f"  Path: {student_path}")
print(f"  Trained Epoch: {student_epoch}")
print(f"  Val Health Acc (before KD): {student_val_acc:.4f}")

# ========== MODEL STATISTICS ==========
teacher_params = _num_params(teacher)
student_params = _num_params(student)

print("\n" + "="*80)
print("MODEL STATISTICS")
print("="*80)
print(f"Teacher Parameters:  {teacher_params:,}")
print(f"Student Parameters:  {student_params:,}")
print(f"Compression Ratio:   {teacher_params/student_params:.2f}x")
print(f"Size Reduction:      {(1 - student_params/teacher_params)*100:.1f}%")
print("="*80)

In [None]:
# Knowledge Distillation Loss Function
def kd_loss(student_outputs, teacher_outputs, labels_species, labels_health,
            temperature=4.0, alpha=0.7):
    """
    Multi-task knowledge-distillation loss (Hinton-style soft targets plus hard labels).

    For each task the student is pulled toward the teacher's temperature-softened
    distribution via KL divergence (scaled by T^2 to keep gradient magnitudes
    comparable across temperatures), and toward the ground truth via
    cross-entropy. Both task terms are averaged.

    Args:
        student_outputs: tuple (species_logits, health_logits) from student
        teacher_outputs: tuple (species_logits, health_logits) from teacher
        labels_species: ground truth species labels
        labels_health: ground truth health labels
        temperature: softening temperature for distillation
        alpha: weight for distillation loss (1-alpha for hard target loss)

    Returns:
        Scalar tensor: alpha * mean_task_KD + (1 - alpha) * mean_task_CE
    """
    def _soft_kl(student_logits, teacher_logits):
        # KL(teacher || student) on temperature-softened distributions, scaled by T^2.
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        return F.kl_div(student_log_probs, teacher_probs,
                        reduction='batchmean') * (temperature ** 2)

    (s_species, s_health), (t_species, t_health) = student_outputs, teacher_outputs

    soft_term = (_soft_kl(s_species, t_species) + _soft_kl(s_health, t_health)) / 2
    hard_term = (F.cross_entropy(s_species, labels_species)
                 + F.cross_entropy(s_health, labels_health)) / 2

    return alpha * soft_term + (1 - alpha) * hard_term

print("✓ KD loss function defined")

In [None]:
# Training setup
optimizer = torch.optim.AdamW(student.parameters(), lr=LR, weight_decay=5e-4)
# Cosine decay from LR down to LR/100 over T_max=EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR/100)

# Initialize TrainingLogger
logger = TrainingLogger()

# Training history (also stored inside each saved KD checkpoint)
history = {
    "train_loss": [], "val_loss": [],
    "train_acc_species": [], "val_acc_species": [],
    "train_acc_health": [], "val_acc_health": [],
    "train_acc_both": [], "val_acc_both": []
}

# A KD epoch is checkpointed only if its val health accuracy beats this value,
# which starts at the score recorded in the pre-trained student's checkpoint.
best_val_health = student_val_acc  # Start from pre-trained performance
best_epoch = -1  # -1 means "pre-KD weights"; KD epochs are numbered from 0
epochs_without_improvement = 0

# Ensure the KD-student checkpoint exists before training starts. Otherwise a
# run where no epoch beats `best_val_health` never writes the file, and loading
# it for the final comparison fails. At this point the pre-KD weights are the
# best seen, so they are a valid starting "best" checkpoint. An existing file is
# never overwritten here (e.g. when this cell is re-run after training).
KD_STUDENT_CKPT = "/Users/alimran/Desktop/CSE465/best_kd_student_efficientnetb0.pt"
if not Path(KD_STUDENT_CKPT).exists():
    torch.save({
        "model": student.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": best_epoch,
        "val_health": best_val_health,
        "val_species": student_checkpoint.get("val_species"),
        "val_both": student_checkpoint.get("val_both"),
        "history": history
    }, KD_STUDENT_CKPT)
    print(f"Baseline (pre-KD) student checkpoint written to: {KD_STUDENT_CKPT}")

print("\n" + "="*80)
print("TRAINING SETUP COMPLETE")
print("="*80)
print(f"Optimizer:           AdamW")
print(f"Scheduler:           CosineAnnealingLR")
print(f"Starting Val Health: {best_val_health:.4f} (from pre-trained student)")
print(f"Target:              Improve via Knowledge Distillation")
print("="*80)

In [None]:
# Training loop with Knowledge Distillation
KD_STUDENT_CKPT = "/Users/alimran/Desktop/CSE465/best_kd_student_efficientnetb0.pt"
_METRIC_KEYS = ('loss', 'acc_species', 'acc_health', 'acc_both')


def _run_kd_epoch(loader, train):
    """Run the student once over `loader`; return epoch-averaged KD loss and accuracies.

    train=True: student in train mode, parameters updated with the KD loss
    (gradient norm clipped to 1.0). train=False: student in eval mode with
    autograd disabled. The frozen teacher always runs without gradients.

    Returns:
        dict with keys 'loss', 'acc_species', 'acc_health', 'acc_both'
        (the format TrainingLogger.log_epoch reads).
    """
    student.train(train)
    running_loss = 0.0
    n_species = n_health = n_both = n_seen = 0

    with torch.set_grad_enabled(train):
        for imgs, y_species, y_health in tqdm(loader, desc="Train" if train else "Val"):
            imgs = imgs.to(device, non_blocking=True)
            y_species = y_species.to(device, non_blocking=True)
            y_health = y_health.to(device, non_blocking=True)
            batch_size = imgs.size(0)

            with torch.no_grad():
                teacher_outputs = teacher(imgs)
            student_outputs = student(imgs)
            loss = kd_loss(student_outputs, teacher_outputs, y_species, y_health,
                           temperature=TEMPERATURE, alpha=ALPHA)

            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)  # Gradient clipping
                optimizer.step()

            logits_species, logits_health = student_outputs
            hit_species = logits_species.argmax(dim=1) == y_species
            hit_health = logits_health.argmax(dim=1) == y_health
            n_species += hit_species.sum().item()
            n_health += hit_health.sum().item()
            n_both += (hit_species & hit_health).sum().item()

            running_loss += loss.item() * batch_size
            n_seen += batch_size

    return {
        'loss': running_loss / n_seen,
        'acc_species': n_species / n_seen,
        'acc_health': n_health / n_seen,
        'acc_both': n_both / n_seen,
    }


print("\n" + "="*80)
print("STARTING KNOWLEDGE DISTILLATION TRAINING")
print("="*80)

saved_this_run = False  # True once any epoch in this run writes a new best checkpoint
epoch = -1  # keeps the summary below valid even when EPOCHS == 0

for epoch in range(EPOCHS):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{EPOCHS} | LR: {optimizer.param_groups[0]['lr']:.2e}")
    print(f"{'='*80}")

    # ========== TRAINING PHASE ==========
    print("Training...")
    train_stats = _run_kd_epoch(train_loader, train=True)

    # ========== VALIDATION PHASE ==========
    print("Validating...")
    val_stats = _run_kd_epoch(val_loader, train=False)

    train_loss_avg, train_acc_species, train_acc_health, train_acc_both = (
        train_stats[k] for k in _METRIC_KEYS)
    val_loss_avg, val_acc_species, val_acc_health, val_acc_both = (
        val_stats[k] for k in _METRIC_KEYS)

    for split, stats in (("train", train_stats), ("val", val_stats)):
        for k in _METRIC_KEYS:
            history[f"{split}_{k}"].append(stats[k])

    # Log to TrainingLogger (LR read before scheduler.step, i.e. the LR used this epoch)
    logger.log_epoch(epoch+1, train_stats, val_stats, optimizer.param_groups[0]['lr'])

    # Print epoch summary
    print(f"\n{'EPOCH SUMMARY':^80}")
    print("-"*80)
    print(f"  Train - Loss: {train_loss_avg:.4f} | Species: {train_acc_species:.3f} | "
          f"Health: {train_acc_health:.3f} | Both: {train_acc_both:.3f}")
    print(f"  Val   - Loss: {val_loss_avg:.4f} | Species: {val_acc_species:.3f} | "
          f"Health: {val_acc_health:.3f} | Both: {val_acc_both:.3f}")

    # Save best model (based on health accuracy)
    if val_acc_health > best_val_health:
        best_val_health = val_acc_health
        best_epoch = epoch
        epochs_without_improvement = 0
        torch.save({
            "model": student.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "val_health": best_val_health,
            "val_species": val_acc_species,
            "val_both": val_acc_both,
            "history": history
        }, KD_STUDENT_CKPT)
        saved_this_run = True
        print(f"  ★ New best model saved! Val Health Acc: {best_val_health:.4f}")
    else:
        epochs_without_improvement += 1
        print(f"  No improvement. Epochs without improvement: {epochs_without_improvement}/{PATIENCE}")

    # Early stopping
    if epochs_without_improvement >= PATIENCE:
        print(f"\n{'⚠ EARLY STOPPING TRIGGERED':^80}")
        print(f"No improvement for {PATIENCE} epochs. Stopping training.")
        break

    print(f"{'='*80}")

    # Update scheduler (once per epoch)
    scheduler.step()

print("\n" + "="*80)
print("KNOWLEDGE DISTILLATION TRAINING COMPLETE")
print("="*80)
if saved_this_run:
    print(f"Best epoch: {best_epoch+1} with val_health={best_val_health:.4f}")
else:
    print(f"No KD epoch beat the starting val_health={best_val_health:.4f}; "
          f"the KD checkpoint was not updated in this run.")
print(f"Total epochs trained: {epoch+1}/{EPOCHS}")
# Accuracies are fractions, so *100 gives percentage points, not a relative percent.
print(f"Improvement over pre-trained: {(best_val_health - student_val_acc)*100:+.2f} percentage points")
if saved_this_run:
    print(f"Model saved to: {KD_STUDENT_CKPT}")
print("="*80)

## Plot Saving

## Final Testing & Comparison

In [None]:
# Test all three models: Teacher, Pre-trained Student, KD Student

print("\n" + "="*80)
print("COMPREHENSIVE TESTING - TEACHER vs STUDENT (Before/After KD)")
print("="*80)

def evaluate_model(model, loader, model_name, *, device=None, show_progress=True):
    """Compute species, health and joint ("both correct") accuracy of a multi-task model.

    Puts `model` in eval mode and runs it without gradients over every batch of `loader`.

    Args:
        model: module whose forward returns (species_logits, health_logits).
        loader: iterable of (images, species_labels, health_labels) batches.
        model_name: label shown in the progress bar and in error messages.
        device: device input batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        show_progress: wrap `loader` in a tqdm progress bar.

    Returns:
        dict with keys 'species', 'health', 'both' — fractions of correctly
        classified samples.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct_species = correct_health = correct_both = 0
    total = 0
    batches = tqdm(loader, desc=f"Testing {model_name}") if show_progress else loader

    with torch.no_grad():
        for imgs, y_species, y_health in batches:
            imgs = imgs.to(device, non_blocking=True)
            y_species = y_species.to(device, non_blocking=True)
            y_health = y_health.to(device, non_blocking=True)

            logits_species, logits_health = model(imgs)
            hit_species = logits_species.argmax(dim=1) == y_species
            hit_health = logits_health.argmax(dim=1) == y_health

            correct_species += hit_species.sum().item()
            correct_health += hit_health.sum().item()
            correct_both += (hit_species & hit_health).sum().item()
            total += imgs.size(0)

    if total == 0:
        raise ValueError(f"evaluate_model({model_name!r}): loader yielded no samples")

    return {
        'species': correct_species / total,
        'health': correct_health / total,
        'both': correct_both / total
    }

def _load_multitask_student(ckpt_path):
    """Build a MultiTaskEfficientNetB0 and load the ``"model"`` state dict from ``ckpt_path``.

    The network is moved to ``device`` and put in eval mode.

    Returns:
        ``(network, checkpoint_dict)`` so callers can still inspect the checkpoint.
    """
    net = MultiTaskEfficientNetB0(num_species=NUM_SPECIES, num_health=NUM_HEALTH, dropout=DROPOUT)
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    net.load_state_dict(ckpt["model"])
    net.to(device)
    net.eval()
    return net, ckpt


# Student before KD (baseline) and the KD-trained student
pretrained_student, pretrained_checkpoint = _load_multitask_student(
    '/Users/alimran/Desktop/CSE465/best_multitask_efficientnet_b0 (1).pt')
kd_student, kd_checkpoint = _load_multitask_student(
    '/Users/alimran/Desktop/CSE465/best_kd_student_efficientnetb0.pt')

# Evaluate all three models, in a fixed order, on the same test loader
_eval_runs = (
    ("\n[1/3] Testing Teacher (DenseNet201)...", teacher, "Teacher"),
    ("\n[2/3] Testing Pre-trained Student (EfficientNet-B0)...", pretrained_student, "Pre-trained Student"),
    ("\n[3/3] Testing KD Student (EfficientNet-B0 + Knowledge Distillation)...", kd_student, "KD Student"),
)
_run_results = []
for _banner, _net, _tag in _eval_runs:
    print(_banner)
    _run_results.append(evaluate_model(_net, test_loader, _tag))
teacher_results, pretrained_results, kd_results = _run_results

# Side-by-side comparison table
_rule = "=" * 80
_row_fmt = "{:<40} {:.4f}      {:.4f}      {:.4f}      {:,}"
print("\n" + _rule)
print("TEST SET RESULTS COMPARISON")
print(_rule)
print(f"{'Model':<40} {'Species':<12} {'Health':<12} {'Both':<12} {'Params':<15}")
print("-" * 80)
for _label, _res, _n_params in (
    ("Teacher (DenseNet201)", teacher_results, teacher_params),
    ("Student (EfficientNet-B0)", pretrained_results, student_params),
    ("KD Student (EfficientNet-B0 + KD)", kd_results, student_params),
):
    print(_row_fmt.format(_label, _res['species'], _res['health'], _res['both'], _n_params))
print(_rule)

# Differences in accuracy, scaled to percentage points
_pre_health = pretrained_results['health']
_teacher_health = teacher_results['health']
health_improvement = (kd_results['health'] - _pre_health) * 100
both_improvement = (kd_results['both'] - pretrained_results['both']) * 100
teacher_gap_before = (_teacher_health - _pre_health) * 100
teacher_gap_after = (_teacher_health - kd_results['health']) * 100

print("\nKEY INSIGHTS:")
print("-" * 80)
print(f"1. KD Improvement (Health Acc):    {health_improvement:+.2f}%")
print(f"2. KD Improvement (Both Acc):      {both_improvement:+.2f}%")
print(f"3. Gap to Teacher (Before KD):     {teacher_gap_before:.2f}%")
print(f"4. Gap to Teacher (After KD):      {teacher_gap_after:.2f}%")
print(f"5. Model Size Reduction:           {(1 - student_params/teacher_params)*100:.1f}%")
print(f"6. Compression Ratio:              {teacher_params/student_params:.2f}x smaller")
print(_rule)

In [None]:
# -------------------------------
# Save Individual Training Plots
# -------------------------------
print("\nGenerating Knowledge Distillation training plots...")

epochs_range = range(1, len(history['train_loss']) + 1)


def _save_kd_plot(x, curves, title, ylabel, filename, *, figsize=(10, 6), markersize=6,
                  legend_kwargs=None, ylim=None, hline=None):
    """Plot per-epoch curves on a new figure, save it as PNG, show it and close it.

    Args:
        x: shared x values (epoch numbers) for every curve.
        curves: sequence of ``(y_values, fmt, label)`` passed to ``ax.plot``.
        title: axes title.
        ylabel: y-axis label.
        filename: output path for ``plt.savefig`` (dpi=150, tight bbox).
        figsize: figure size in inches.
        markersize: marker size used for every curve.
        legend_kwargs: keyword arguments for ``ax.legend`` (default ``fontsize=11``).
        ylim: optional y-axis limits, e.g. ``[0, 1.05]`` for accuracies.
        hline: optional ``(y, color, label)`` dashed horizontal reference line.

    Returns:
        ``filename``, so callers can list what was written.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for y_values, fmt, label in curves:
        ax.plot(x, y_values, fmt, label=label, linewidth=2, markersize=markersize)
    if hline is not None:
        line_y, line_color, line_label = hline
        ax.axhline(y=line_y, color=line_color, linestyle='--', linewidth=2, label=line_label)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(**(legend_kwargs if legend_kwargs is not None else {'fontsize': 11}))
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(ylim)
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"✓ Saved {filename}")
    plt.show()
    # Close this specific figure rather than whatever pyplot considers "current"
    plt.close(fig)
    return filename


_ACC_YLIM = [0, 1.05]
# Reference line only when a best validation health accuracy was recorded
_has_best = best_val_health > 0

kd_plot_files = [
    # 1. Training Loss
    _save_kd_plot(epochs_range, [(history['train_loss'], 'b-o', 'Train Loss')],
                  'Knowledge Distillation - Training Loss Over Epochs', 'Loss',
                  'kd_plot_train_loss.png'),
    # 2. Validation Loss
    _save_kd_plot(epochs_range, [(history['val_loss'], 'r-o', 'Validation Loss')],
                  'Knowledge Distillation - Validation Loss Over Epochs', 'Loss',
                  'kd_plot_val_loss.png'),
    # 3. Train vs Val Loss Comparison
    _save_kd_plot(epochs_range,
                  [(history['train_loss'], 'b-o', 'Train Loss'),
                   (history['val_loss'], 'r-o', 'Val Loss')],
                  'Knowledge Distillation - Training vs Validation Loss', 'Loss',
                  'kd_plot_loss_comparison.png'),
    # 4. Species Accuracy
    _save_kd_plot(epochs_range,
                  [(history['train_acc_species'], 'b-o', 'Train Species Acc'),
                   (history['val_acc_species'], 'r-o', 'Val Species Acc')],
                  'Knowledge Distillation - Species Classification Accuracy', 'Accuracy',
                  'kd_plot_species_accuracy.png', ylim=_ACC_YLIM),
    # 5. Health Accuracy
    _save_kd_plot(epochs_range,
                  [(history['train_acc_health'], 'b-o', 'Train Health Acc'),
                   (history['val_acc_health'], 'r-o', 'Val Health Acc')],
                  'Knowledge Distillation - Health/Disease Classification Accuracy', 'Accuracy',
                  'kd_plot_health_accuracy.png', ylim=_ACC_YLIM),
    # 6. Joint (Both) Accuracy
    _save_kd_plot(epochs_range,
                  [(history['train_acc_both'], 'b-o', 'Train Both Acc'),
                   (history['val_acc_both'], 'r-o', 'Val Both Acc')],
                  'Knowledge Distillation - Joint Classification Accuracy', 'Accuracy',
                  'kd_plot_joint_accuracy.png', ylim=_ACC_YLIM,
                  hline=((best_val_health, 'g', f'Best Val Health: {best_val_health:.3f}')
                         if _has_best else None)),
    # 7. All Accuracies Together
    _save_kd_plot(epochs_range,
                  [(history['train_acc_species'], 'b-o', 'Train Species'),
                   (history['val_acc_species'], 'b--s', 'Val Species'),
                   (history['train_acc_health'], 'g-o', 'Train Health'),
                   (history['val_acc_health'], 'g--s', 'Val Health'),
                   (history['train_acc_both'], 'r-o', 'Train Both'),
                   (history['val_acc_both'], 'r--s', 'Val Both')],
                  'Knowledge Distillation - All Metrics Over Training', 'Accuracy',
                  'kd_plot_all_metrics.png', figsize=(12, 7), markersize=5,
                  legend_kwargs={'fontsize': 10, 'ncol': 2}, ylim=_ACC_YLIM,
                  hline=((best_val_health, 'orange', f'Best: {best_val_health:.3f}')
                         if _has_best else None)),
]

print("\n" + "="*80)
print("All KD training plots saved:")
for _plot_file in kd_plot_files:
    print(f"  - {_plot_file}")
print("="*80)

## Comprehensive Testing with Metrics

In [None]:
# Import metrics libraries
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score, 
    precision_recall_fscore_support,
    f1_score
)

# -------------------------------
# Comprehensive Testing Function
# -------------------------------
def _save_confusion_heatmap(y_true, y_pred, class_ids, class_names, cmap, title, filename, saved_msg):
    """Draw a confusion-matrix heatmap over ``class_ids``, save it to ``filename``,
    print ``saved_msg``, then show and close the figure."""
    fig, ax = plt.subplots(figsize=(8, 6))
    # labels= pins the matrix to len(class_ids) rows/cols in the same order as
    # class_names, even when a class is absent from both y_true and y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax,
                xticklabels=class_names, yticklabels=class_names)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(saved_msg)
    plt.show()
    plt.close(fig)


def comprehensive_test(model, test_loader, device, species_map, health_map):
    """
    Evaluate a two-head (species, health) model, print metrics and save plots.

    Prints overall/per-class reports and per-(species, health) joint accuracy,
    and writes 'kd_confusion_matrix_species.png', 'kd_confusion_matrix_health.png'
    and 'kd_joint_accuracy_heatmap.png' to the working directory.

    Args:
        model: module whose forward returns ``(logits_species, logits_health)``.
        test_loader: iterable of ``(imgs, y_species, y_health)`` batches.
        device: device each batch is moved to.
        species_map: ``{class_name: class_id}`` for species (names are capitalized for display).
        health_map: ``{class_name: class_id}`` for health status.

    Returns:
        dict with 'species_accuracy', 'health_accuracy', 'joint_accuracy' and the
        numpy arrays 'species_preds', 'species_true', 'health_preds', 'health_true'.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # Per-sample predictions / ground truth accumulated over all batches
    all_species_preds = []
    all_species_true = []
    all_health_preds = []
    all_health_true = []
    all_both_correct = []

    print("Running comprehensive test evaluation...")

    with torch.no_grad():
        for imgs, y_species, y_health in test_loader:
            imgs = imgs.to(device, non_blocking=True)
            y_species = y_species.to(device, non_blocking=True)
            y_health = y_health.to(device, non_blocking=True)

            logits_species, logits_health = model(imgs)
            preds_species = logits_species.argmax(dim=1)
            preds_health = logits_health.argmax(dim=1)

            all_species_preds.extend(preds_species.cpu().numpy())
            all_species_true.extend(y_species.cpu().numpy())
            all_health_preds.extend(preds_health.cpu().numpy())
            all_health_true.extend(y_health.cpu().numpy())

            # A sample is jointly correct only if both heads are right
            both_correct = (preds_species == y_species) & (preds_health == y_health)
            all_both_correct.extend(both_correct.cpu().numpy())

    if not all_both_correct:
        raise ValueError("comprehensive_test: test_loader yielded no samples")

    all_species_preds = np.array(all_species_preds)
    all_species_true = np.array(all_species_true)
    all_health_preds = np.array(all_health_preds)
    all_health_true = np.array(all_health_true)
    all_both_correct = np.array(all_both_correct)

    # Reverse mapping (class id -> display name) and a fixed id order used everywhere
    species_labels = {v: k.capitalize() for k, v in species_map.items()}
    health_labels = {v: k.capitalize() for k, v in health_map.items()}
    species_ids = sorted(species_labels)
    health_ids = sorted(health_labels)
    species_names = [species_labels[i] for i in species_ids]
    health_names = [health_labels[i] for i in health_ids]

    # -------------------------------
    # Print Metrics
    # -------------------------------
    print("\n" + "="*80)
    print("COMPREHENSIVE TEST RESULTS - KNOWLEDGE DISTILLATION STUDENT")
    print("="*80)

    species_acc = accuracy_score(all_species_true, all_species_preds)
    health_acc = accuracy_score(all_health_true, all_health_preds)
    both_acc = all_both_correct.mean()

    print(f"\n{'OVERALL ACCURACIES':^80}")
    print("-"*80)
    print(f"  Species Classification:  {species_acc:.4f} ({species_acc*100:.2f}%)")
    print(f"  Health Classification:   {health_acc:.4f} ({health_acc*100:.2f}%)")
    print(f"  Both Correct (Joint):    {both_acc:.4f} ({both_acc*100:.2f}%)")

    # labels= keeps target_names aligned with class ids even if a class is
    # missing from this split (otherwise sklearn raises on a length mismatch).
    print(f"\n{'SPECIES CLASSIFICATION REPORT':^80}")
    print("-"*80)
    print(classification_report(
        all_species_true,
        all_species_preds,
        labels=species_ids,
        target_names=species_names,
        digits=4
    ))

    print(f"\n{'HEALTH/DISEASE CLASSIFICATION REPORT':^80}")
    print("-"*80)
    print(classification_report(
        all_health_true,
        all_health_preds,
        labels=health_ids,
        target_names=health_names,
        digits=4
    ))

    # Joint accuracy and sample count per (species, health) pair, indexed by
    # position in species_ids / health_ids. NaN marks pairs with no test samples
    # so they are not mistaken for 0% accuracy.
    joint_acc_matrix = np.full((len(species_ids), len(health_ids)), np.nan)
    joint_counts = np.zeros((len(species_ids), len(health_ids)), dtype=int)
    for r, sp_id in enumerate(species_ids):
        for c, he_id in enumerate(health_ids):
            mask = (all_species_true == sp_id) & (all_health_true == he_id)
            joint_counts[r, c] = mask.sum()
            if joint_counts[r, c] > 0:
                joint_acc_matrix[r, c] = all_both_correct[mask].mean()

    print(f"\n{'PER-CLASS JOINT ACCURACY':^80}")
    print("-"*80)
    for r, sp_id in enumerate(species_ids):
        for c, he_id in enumerate(health_ids):
            count = joint_counts[r, c]
            if count > 0:
                joint_acc = joint_acc_matrix[r, c]
                sp_name = species_labels[sp_id]
                he_name = health_labels[he_id]
                print(f"  {sp_name:8s} + {he_name:12s}: {joint_acc:.4f} ({joint_acc*100:.2f}%) [{count:4d} samples]")

    # -------------------------------
    # Visualizations
    # -------------------------------
    _save_confusion_heatmap(
        all_species_true, all_species_preds, species_ids, species_names, 'Blues',
        f'KD Student - Species Classification\nAccuracy: {species_acc:.2%}',
        'kd_confusion_matrix_species.png',
        f"\n✓ Saved species confusion matrix to 'kd_confusion_matrix_species.png'",
    )
    _save_confusion_heatmap(
        all_health_true, all_health_preds, health_ids, health_names, 'Greens',
        f'KD Student - Health/Disease Classification\nAccuracy: {health_acc:.2%}',
        'kd_confusion_matrix_health.png',
        f"✓ Saved health confusion matrix to 'kd_confusion_matrix_health.png'",
    )

    # Joint accuracy heatmap; NaN (no-sample) cells are masked and left blank
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(joint_acc_matrix, annot=True, fmt='.3f', cmap='RdYlGn',
                mask=np.isnan(joint_acc_matrix),
                xticklabels=health_names,
                yticklabels=species_names,
                vmin=0, vmax=1, ax=ax, cbar_kws={'label': 'Accuracy'})
    ax.set_title('KD Student - Joint Classification Accuracy by Class', fontsize=12, fontweight='bold')
    ax.set_ylabel('Species')
    ax.set_xlabel('Health Status')

    plt.tight_layout()
    plt.savefig('kd_joint_accuracy_heatmap.png', dpi=150, bbox_inches='tight')
    print(f"✓ Saved joint accuracy heatmap to 'kd_joint_accuracy_heatmap.png'")
    plt.show()
    plt.close(fig)

    print("\n" + "="*80)
    print("Testing complete! Generated visualizations:")
    print("  - kd_confusion_matrix_species.png")
    print("  - kd_confusion_matrix_health.png")
    print("  - kd_joint_accuracy_heatmap.png")
    print("="*80 + "\n")

    return {
        'species_accuracy': species_acc,
        'health_accuracy': health_acc,
        'joint_accuracy': both_acc,
        'species_preds': all_species_preds,
        'species_true': all_species_true,
        'health_preds': all_health_preds,
        'health_true': all_health_true
    }

print("✓ Comprehensive testing function defined")

In [None]:
# -------------------------------
# Load Test Dataset and Run Comprehensive Testing
# -------------------------------

# Load test dataset
# NOTE(review): confirm `transform` is the deterministic evaluation transform
# (no random augmentation) — it is defined in an earlier cell.
print("Loading test dataset...")
test_dataset = JointLeafDataset(DATA_ROOT / "test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"))
print(f"Test samples: {len(test_dataset)}")

# Load the best model: the same path the KD training loop passes to torch.save.
# weights_only=False matches the other checkpoint loads in this notebook (the file
# also stores optimizer state and training history); only load files you produced.
print("\nLoading best KD student model...")
checkpoint = torch.load("/Users/alimran/Desktop/CSE465/best_kd_student_efficientnetb0.pt",
                        map_location=device, weights_only=False)
student.load_state_dict(checkpoint["model"])
# The checkpoint is selected on val_health, so report that alongside val_both.
print(f"Loaded best model from epoch {checkpoint['epoch']+1} with "
      f"val_health={checkpoint['val_health']:.3f}, val_both={checkpoint['val_both']:.3f}")

# Run comprehensive testing
test_results = comprehensive_test(
    model=student,
    test_loader=test_loader,
    device=device,
    species_map=SPECIES_MAP,
    health_map=HEALTH_MAP
)

## Sample Predictions: 10 Validation Images per Species (3 Species)

In [None]:
# Visualize up to `amount` validation-set predictions of the best KD student per species.
# Label names are derived from the dataset's SPECIES_MAP / HEALTH_MAP (the same maps
# passed to comprehensive_test), so every class id the model can output has a name.
species_labels = {v: k.capitalize() for k, v in SPECIES_MAP.items()}
health_labels = {v: k.capitalize() for k, v in HEALTH_MAP.items()}

# Load best model: the checkpoint written by the KD training loop.
# weights_only=False because the file also stores optimizer state and training
# history; only load checkpoints you produced yourself.
checkpoint = torch.load("/Users/alimran/Desktop/CSE465/best_kd_student_efficientnetb0.pt",
                        map_location=device, weights_only=False)
student.load_state_dict(checkpoint["model"])
student.eval()

# Number of samples to display per class
amount = 10

species_ids = sorted(species_labels)

# Collect up to `amount` samples per species from the validation set
sample_images_by_class = {sp: [] for sp in species_ids}
sample_predictions_by_class = {sp: [] for sp in species_ids}
sample_ground_truth_by_class = {sp: [] for sp in species_ids}

with torch.no_grad():
    for images, species_batch, health_batch in val_loader:
        images = images.to(device)

        outputs = student(images)
        species_preds = outputs[0].argmax(1)
        health_preds = outputs[1].argmax(1)

        for i in range(len(images)):
            species_class = species_batch[i].item()
            bucket = sample_images_by_class.get(species_class)
            # Skip ids outside SPECIES_MAP and species whose quota is already full
            if bucket is None or len(bucket) >= amount:
                continue
            bucket.append(images[i].cpu())
            sample_predictions_by_class[species_class].append({
                'species': species_preds[i].item(),
                'health': health_preds[i].item()
            })
            sample_ground_truth_by_class[species_class].append({
                'species': species_class,
                'health': health_batch[i].item()
            })

        if all(len(samples) >= amount for samples in sample_images_by_class.values()):
            break

# Per-channel mean/std used to undo input normalization for display
# (assumes the loader's transform normalized with these values — confirm).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

n_rows = len(species_ids)
# squeeze=False keeps `axes` 2-D even when there is a single row or column
fig, axes = plt.subplots(n_rows, amount, figsize=(3*amount, 3*n_rows), squeeze=False)

for row, species_idx in enumerate(species_ids):
    row_images = sample_images_by_class[species_idx]
    for col in range(amount):
        ax = axes[row, col]
        ax.axis('off')
        # The validation split may hold fewer than `amount` images of this species
        if col >= len(row_images):
            continue

        img = row_images[col]
        pred = sample_predictions_by_class[species_idx][col]
        gt = sample_ground_truth_by_class[species_idx][col]

        # Denormalize (CHW -> HWC) and display
        img_display = img.numpy().transpose(1, 2, 0)
        img_display = np.clip(std * img_display + mean, 0, 1)
        ax.imshow(img_display)

        both_correct = (pred['species'] == gt['species']) and (pred['health'] == gt['health'])

        title = (f"Pred: {species_labels[pred['species']]}, {health_labels[pred['health']]}\n"
                 f"True: {species_labels[gt['species']]}, {health_labels[gt['health']]}")
        color = 'green' if both_correct else 'red'
        ax.set_title(title, fontsize=8, color=color, fontweight='bold')

    # Row label, spaced 0.9/n_rows apart and centred at 0.5 (0.8 / 0.5 / 0.2 for 3 rows)
    fig.text(0.02, 0.5 + ((n_rows - 1) / 2 - row) * (0.9 / n_rows), species_labels[species_idx],
             fontsize=12, fontweight='bold', va='center', rotation=90)

n_shown = sum(len(imgs) for imgs in sample_images_by_class.values())

plt.suptitle(f'KD Student Sample Predictions - {amount} Samples per Class\n(Green=Correct, Red=Incorrect)',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0.05, 0, 1, 0.99])
plt.savefig('kd_sample_predictions.png', dpi=150, bbox_inches='tight')
print(f"✓ Saved sample predictions ({n_shown} total, up to {amount} per class) to 'kd_sample_predictions.png'")
plt.show()
plt.close(fig)