## Imports

In [None]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torchvision import transforms
from PIL import Image
from pathlib import Path
from typing import Tuple, List
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import timm

# Report the versions of the two libraries the model code depends on, so a
# run's log shows which PyTorch / timm build produced it.
print(f"PyTorch version: {torch.__version__}")
print(f"timm version: {timm.__version__}")


## CONFIGURATION - ALL HYPERPARAMETERS

In [None]:
# Root of the pre-split dataset; the loading cell reads DATA_ROOT/train,
# DATA_ROOT/val and DATA_ROOT/test.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
DATA_ROOT = Path("/Users/alimran/Desktop/CSE465/Split_Dataset")
# Images are resized to IMG_SIZE x IMG_SIZE by the transforms.
IMG_SIZE = 224
BATCH_SIZE = 32
# Passed to every DataLoader as num_workers (0 = load in the main process).
NUM_WORKERS = 0
# Maximum number of epochs; also used as T_max of the cosine LR schedule.
EPOCHS = 50
# AdamW learning rate and weight decay.
LR = 1e-4
WEIGHT_DECAY = 5e-4
# Dropout probability applied to backbone features before both heads.
DROPOUT = 0.3
# Max gradient norm for clip_grad_norm_; a falsy value disables clipping.
GRADIENT_CLIP = 1.0
SEED = 42
# Multiplier on the health loss term in the combined loss
# (species loss + LOSS_WEIGHT_HEALTH * health loss).
LOSS_WEIGHT_HEALTH = 1.0
# Mixed precision is only used when this is True AND the device is CUDA.
USE_AMP = True
# Early stopping: number of consecutive epochs without a better validation
# health accuracy before training stops.
PATIENCE = 5

# Model Config
# Output sizes of the two classification heads; they should match the number
# of entries in SPECIES_MAP and HEALTH_MAP respectively.
NUM_SPECIES = 3
NUM_HEALTH = 4
# Forwarded to timm.create_model(pretrained=...).
PRETRAINED = False

# Swin Model Selection
# Available models: 'swin_tiny_patch4_window7_224', 'swin_small_patch4_window7_224', 'swin_base_patch4_window7_224'
SWIN_MODEL = 'swin_tiny_patch4_window7_224'  # Swin Tiny

# Model name
# Used as the stem of the best-checkpoint file and of the history / LR-plot files.
MODEL_NAME = f"best_multitask_{SWIN_MODEL}"

# Device setup
# Only CUDA or CPU is selected here; other accelerators are not considered.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Set seeds for reproducibility
# Seeds Python, NumPy and PyTorch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic algorithms and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Echo the most important settings into the run log.
print(f"✓ Configuration loaded")
print(f"  - Model: {SWIN_MODEL}")
print(f"  - Pretrained: {PRETRAINED}")
print(f"  - Data: {DATA_ROOT}")
print(f"  - Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Epochs: {EPOCHS}")
print(f"  - Learning rate: {LR}")


## LABEL MAPS

In [None]:
# Lower-case folder-name tokens mapped to the integer class ids used by the
# two classification heads.
SPECIES_MAP = {"eggplant": 0, "potato": 1, "tomato": 2}
HEALTH_MAP = {"bacterial": 0, "fungal": 1, "healthy": 2, "virus": 3}

def parse_joint_label(folder_name: str) -> Tuple[int, int]:
    """Translate a ``<species>_<health>`` folder name into ``(species_id, health_id)``.

    Surrounding whitespace is ignored and both tokens are matched
    case-insensitively. Only the first underscore separates the tokens, so
    everything after it is treated as the health token.

    Raises:
        ValueError: the name contains no underscore.
        KeyError: the species or health token is not in the label maps.
    """
    name = folder_name.strip()
    sp, separator, he = name.partition("_")
    if not separator:
        raise ValueError(f"Folder name not joint label: {name}")

    species_key = sp.lower()
    health_key = he.lower()

    if species_key not in SPECIES_MAP:
        raise KeyError(f"Unknown species: {sp} (available: {list(SPECIES_MAP.keys())})")
    if health_key not in HEALTH_MAP:
        raise KeyError(f"Unknown health status: {he} (available: {list(HEALTH_MAP.keys())})")

    return SPECIES_MAP[species_key], HEALTH_MAP[health_key]

## DATASET

In [None]:
class JointLeafDataset(Dataset):
    """Image dataset whose labels are encoded in ``<species>_<health>`` folder names.

    Every immediate sub-directory of ``split_root`` is parsed with
    ``parse_joint_label``; folders that do not parse are reported and skipped.
    Images are collected recursively from each accepted folder.

    Each item is ``(image, species_id, health_id)`` where the ids are scalar
    ``torch.long`` tensors and ``image`` is whatever ``transform`` returns
    (or the RGB PIL image when no transform is given).

    Args:
        split_root: Directory of one split (``str`` or ``Path``).
        transform: Optional callable applied to the loaded RGB PIL image.

    Raises:
        RuntimeError: ``split_root`` is not a directory, has no
            sub-directories, or no image was found in any accepted folder.
    """

    # Lower-case file suffixes that are treated as images.
    IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp"})

    def __init__(self, split_root: Path, transform=None):
        self.split_root = Path(split_root)
        self.samples: List[Tuple[str, int, int]] = []
        self.transform = transform

        # is_dir() also rejects a path that exists but is a regular file,
        # which would otherwise fail later inside iterdir().
        if not self.split_root.is_dir():
            raise RuntimeError(
                f"Directory does not exist: {split_root}\n"
                f"Please check if the path is correct."
            )

        # Get all subdirectories
        subdirs = [d for d in self.split_root.iterdir() if d.is_dir()]

        if not subdirs:
            raise RuntimeError(
                f"No subdirectories found in: {split_root}\n"
                f"Expected structure: {split_root}/species_health/images.jpg\n"
                f"Example: {split_root}/tomato_healthy/img001.jpg"
            )

        # Use the normalised Path: the raw argument may be a plain string,
        # which has no ``.name`` attribute.
        print(f"   Found {len(subdirs)} subdirectories in {self.split_root.name}")

        for folder in sorted(subdirs):
            try:
                sp_id, he_id = parse_joint_label(folder.name)
            except (ValueError, KeyError) as e:
                print(f"   ⚠ Skipping folder '{folder.name}': {e}")
                continue

            # Sorted so that sample indices are stable across runs and
            # filesystems; is_file() keeps out directories that merely carry
            # an image-like suffix.
            images_in_folder = sorted(
                str(p)
                for p in folder.rglob("*")
                if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
            )

            print(f"   - {folder.name}: {len(images_in_folder)} images")

            self.samples.extend((img_path, sp_id, he_id) for img_path in images_in_folder)

        if not self.samples:
            raise RuntimeError(
                f"No images found under {split_root}\n"
                f"Found directories: {[d.name for d in subdirs]}\n"
                f"Expected directory names: species_health (e.g., 'tomato_healthy' or 'Tomato_Healthy')\n"
                f"Supported image formats: .jpg, .jpeg, .png, .bmp\n"
                f"Please check:\n"
                f"  1. Directory names follow 'species_health' format\n"
                f"  2. Images are in correct format\n"
                f"  3. Images are inside the subdirectories"
            )

    def __len__(self):
        """Return the number of collected image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, species_id, health_id)``."""
        path, sp_id, he_id = self.samples[idx]

        try:
            # The context manager releases the file handle promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: an unreadable file is reported and
            # replaced by a black placeholder so one bad image does not abort
            # an epoch. The placeholder keeps the sample's original labels.
            print(f"Error loading image {path}: {e}")
            img = Image.new('RGB', (IMG_SIZE, IMG_SIZE))

        if self.transform is not None:
            img = self.transform(img)

        return img, torch.tensor(sp_id, dtype=torch.long), torch.tensor(he_id, dtype=torch.long)


## TRANSFORMS

In [None]:
# Per-channel normalisation constants (the values commonly used with
# ImageNet-trained backbones).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_and_normalize():
    """Build the resize -> tensor -> normalize pipeline shared by all splits."""
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
    return transforms.Compose(steps)


# Training currently uses exactly the same deterministic preprocessing as
# evaluation (no augmentation); the two are kept as separate objects.
train_tf = _resize_and_normalize()

eval_tf = _resize_and_normalize()

## DATASETS & LOADERS

In [None]:
_RULE = "=" * 80


def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )


# --- Build one dataset per split (only the transform differs) ---------------
print("\n" + _RULE)
print("Loading datasets...")
print(_RULE)
train_ds = JointLeafDataset(DATA_ROOT / "train", transform=train_tf)
print()
val_ds = JointLeafDataset(DATA_ROOT / "val", transform=eval_tf)
print()
test_ds = JointLeafDataset(DATA_ROOT / "test", transform=eval_tf)

# --- Summary of split sizes --------------------------------------------------
_n_train = len(train_ds)
_n_val = len(val_ds)
_n_test = len(test_ds)
print("\n" + _RULE)
print(f"Dataset Summary:")
print(f"  Train: {_n_train} images")
print(f"  Val:   {_n_val} images")
print(f"  Test:  {_n_test} images")
print(f"  Total: {_n_train + _n_val + _n_test} images")
print(_RULE)

# --- Smoke test: a single sample must load and transform ---------------------
print("\nTesting sample loading...")
try:
    sample_img, sample_sp, sample_he = train_ds[0]
    print(f"✓ Sample loaded successfully: shape={sample_img.shape}, species={sample_sp}, health={sample_he}")
except Exception as e:
    print(f"✗ Failed to load sample: {e}")
    raise

# --- Loaders: only the training split is shuffled ----------------------------
train_loader = _build_loader(train_ds, shuffle=True)
val_loader = _build_loader(val_ds, shuffle=False)
test_loader = _build_loader(test_ds, shuffle=False)

# --- Smoke test: pull at most three batches through the training loader ------
print("\nTesting DataLoader...")
try:
    for i, (imgs, sp, he) in zip(range(3), train_loader):
        print(f"✓ Batch {i}: imgs {imgs.shape}, species {sp.shape}, health {he.shape}")
    print("✓ DataLoader test passed!\n")
except Exception as e:
    print(f"✗ DataLoader failed: {e}")
    import traceback
    traceback.print_exc()
    raise

## MODEL DEFINITION

In [None]:
class MultiTaskSwin(nn.Module):
    """
    Multi-task Swin Transformer model for species and health classification
    Uses Swin as backbone with two separate classification heads

    The backbone is created through timm with ``num_classes=0`` so that it
    returns features instead of class logits. The same (dropped-out) features
    feed two independent linear heads.

    Args:
        num_species: Output size of the species head.
        num_health: Output size of the health head.
        model_name: timm model identifier for the backbone.
        pretrained: Forwarded to ``timm.create_model``.
        dropout: Dropout probability applied to the backbone features.

    Attributes:
        model_name: Identifier of the backbone that was actually loaded
            (differs from the requested one when the fallback was used).
    """

    # Backbone used when the requested model cannot be created.
    FALLBACK_MODEL = 'swin_tiny_patch4_window7_224'

    def __init__(self, num_species=NUM_SPECIES, num_health=NUM_HEALTH,
                 model_name=SWIN_MODEL, pretrained=PRETRAINED, dropout=DROPOUT):
        super().__init__()

        # Load Swin model from timm
        print(f"Loading Swin model: {model_name}")
        try:
            self.backbone = self._create_backbone(model_name, pretrained)
        except Exception as e:
            # Retrying the exact same call cannot succeed, so only fall back
            # when a *different* model was requested; otherwise surface the
            # original error unchanged.
            if model_name == self.FALLBACK_MODEL:
                raise
            print(f"✗ Error loading {model_name}: {e}")
            print(f"Falling back to swin_tiny_patch4_window7_224...")
            model_name = self.FALLBACK_MODEL
            self.backbone = self._create_backbone(model_name, pretrained)
            in_dim = self.backbone.num_features
            print(f"✓ Loaded fallback model: {model_name} with feature dim: {in_dim}")
        else:
            print(f"✓ Successfully loaded {model_name} from timm")
            # Swin variants have different feature dimensions, so read it
            # from the backbone instead of hard-coding it.
            in_dim = self.backbone.num_features
            print(f"✓ Feature dimension: {in_dim}")

        # Remember which architecture is really in use, so callers can detect
        # that the fallback replaced the requested model.
        self.model_name = model_name

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

        # Multi-task heads
        self.head_species = nn.Linear(in_dim, num_species)
        self.head_health = nn.Linear(in_dim, num_health)

        print(f"✓ Multi-task heads created:")
        print(f"  - Species head: {in_dim} → {num_species}")
        print(f"  - Health head: {in_dim} → {num_health}")

    @staticmethod
    def _create_backbone(model_name, pretrained):
        """Create a headless timm backbone (``num_classes=0`` returns features only)."""
        return timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
        )

    def forward(self, x):
        """Return ``(logits_species, logits_health)`` for the image batch ``x``."""
        # Extract features from backbone
        feats = self.backbone(x)

        # Apply dropout
        feats = self.dropout(feats)

        # Get predictions from both heads
        logits_species = self.head_species(feats)
        logits_health = self.head_health(feats)

        return logits_species, logits_health

# Instantiate the network from the notebook-wide configuration and move it to
# the selected device.
_MODEL_RULE = "=" * 80
print("\n" + _MODEL_RULE)
print("INITIALIZING MULTI-TASK SWIN MODEL")
print(_MODEL_RULE)
model = MultiTaskSwin(
    num_species=NUM_SPECIES,
    num_health=NUM_HEALTH,
    model_name=SWIN_MODEL,
    pretrained=PRETRAINED,
    dropout=DROPOUT,
)
model = model.to(device)
print(f"✓ Model successfully initialized on {device}")
print(_MODEL_RULE + "\n")


## OPTIMIZER, SCHEDULER, LOSS FUNCTIONS

In [None]:
# AdamW over every model parameter (backbone and both heads share one LR).
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine annealing over EPOCHS scheduler steps; the training loop steps it
# once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# One unweighted cross-entropy loss per task; run_epoch combines them.
criterion_species = nn.CrossEntropyLoss()
criterion_health = nn.CrossEntropyLoss()
# Gradient scaler for mixed precision; only enabled when USE_AMP is set and
# the device is CUDA.
# NOTE(review): newer PyTorch releases prefer torch.amp.GradScaler("cuda", ...)
# over torch.cuda.amp.GradScaler - confirm against the installed version.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")


## TRAINING UTILITIES

In [None]:
def accuracy(logits, targets):
    """Return, as a Python float, the share of rows whose arg-max class equals the target."""
    predicted = logits.argmax(dim=1)
    hits = (predicted == targets).float()
    return hits.mean().item()

def run_epoch(loader, model, optimizer=None, train=True, epoch=0):
    """Run one pass over ``loader`` in training or evaluation mode.

    The combined loss is ``species_loss + LOSS_WEIGHT_HEALTH * health_loss``.
    Mixed precision is used only when ``USE_AMP`` is set and the device is
    CUDA. The function relies on the module-level ``device``, ``scaler``,
    ``criterion_species``, ``criterion_health``, ``USE_AMP``,
    ``LOSS_WEIGHT_HEALTH`` and ``GRADIENT_CLIP``.

    Args:
        loader: Iterable with ``len()`` yielding
            ``(images, species_targets, health_targets)`` batches.
        model: Module returning ``(logits_species, logits_health)``.
        optimizer: Optimizer stepped after each batch; required when
            ``train`` is true, ignored otherwise.
        train: True to train (gradients, optimizer steps), False to evaluate.
        epoch: Currently unused; kept so existing call sites stay valid.

    Returns:
        dict with sample-weighted means under ``"loss"``, ``"acc_species"``
        and ``"acc_health"``, plus the number of samples seen under ``"n"``.

    Raises:
        ValueError: ``train`` is true but no optimizer was given.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import nullcontext

    # Fail before any work is done instead of with an AttributeError after
    # the first forward pass.
    if train and optimizer is None:
        raise ValueError("run_epoch(train=True) requires an optimizer")

    # train(True) == train(), train(False) == eval()
    model.train(train)

    use_amp = USE_AMP and device.type == "cuda"

    running = {
        "loss": 0.0,
        "acc_species": 0.0,
        "acc_health": 0.0,
        "n": 0
    }
    total_batches = len(loader)
    # Report progress roughly ten times per epoch (and always on the last batch).
    log_every = max(1, total_batches // 10)

    for batch_idx, (imgs, y_species, y_health) in enumerate(loader):
        imgs = imgs.to(device, non_blocking=True)
        y_species = y_species.to(device, non_blocking=True)
        y_health = y_health.to(device, non_blocking=True)

        # A single forward/loss path: autocast is swapped for a no-op
        # context when mixed precision is not in use.
        amp_ctx = torch.amp.autocast('cuda') if use_amp else nullcontext()
        with torch.set_grad_enabled(train), amp_ctx:
            logits_species, logits_health = model(imgs)
            loss = criterion_species(logits_species, y_species) + \
                   LOSS_WEIGHT_HEALTH * criterion_health(logits_health, y_health)

        if train:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                scaler.scale(loss).backward()
                if GRADIENT_CLIP:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if GRADIENT_CLIP:
                    clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                optimizer.step()

        # Calculate accuracies
        acc_sp = accuracy(logits_species, y_species)
        acc_he = accuracy(logits_health, y_health)

        # Update running statistics, weighted by batch size so that a
        # smaller final batch does not skew the epoch averages.
        bs = imgs.size(0)
        running["loss"] += loss.item() * bs
        running["acc_species"] += acc_sp * bs
        running["acc_health"] += acc_he * bs
        running["n"] += bs

        # Print progress
        if (batch_idx + 1) % log_every == 0 or (batch_idx + 1) == total_batches:
            avg_loss = running["loss"] / running["n"]
            avg_sp = running["acc_species"] / running["n"]
            avg_he = running["acc_health"] / running["n"]
            print(f"  [{batch_idx + 1}/{total_batches}] loss: {avg_loss:.4f}, "
                  f"sp: {avg_sp:.3f}, he: {avg_he:.3f}")

    # Calculate averages (max() guards against an empty loader)
    for k in ["loss", "acc_species", "acc_health"]:
        running[k] /= max(1, running["n"])

    return running

## TRAINING LOOP

In [None]:
# Per-epoch metrics; every list gains one entry per completed epoch.
history = {
    "train_loss": [],
    "val_loss": [],
    "train_acc_species": [],
    "val_acc_species": [],
    "train_acc_health": [],
    "val_acc_health": []
}

# Start below any attainable accuracy so that the first completed epoch
# always writes a checkpoint. Starting at 0.0 with a strict ">" comparison
# would leave no checkpoint (or a stale one from an earlier run) on disk if
# validation health accuracy never rose above zero.
best_val_health = float("-inf")
best_epoch = 0
epochs_without_improvement = 0

print("="*80)
print("STARTING TRAINING")
print("="*80)

for epoch in range(EPOCHS):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{EPOCHS} | LR: {optimizer.param_groups[0]['lr']:.2e}")
    print(f"{'='*80}")

    # Train
    print("Training...")
    train_stats = run_epoch(train_loader, model, optimizer, train=True, epoch=epoch)

    # Validate
    print("Validating...")
    val_stats = run_epoch(val_loader, model, optimizer=None, train=False, epoch=epoch)

    # Update scheduler (once per epoch)
    scheduler.step()

    # Store history
    for metric in ("loss", "acc_species", "acc_health"):
        history[f"train_{metric}"].append(train_stats[metric])
        history[f"val_{metric}"].append(val_stats[metric])

    # Print summary
    print(f"\n{'EPOCH SUMMARY':^80}")
    print("-"*80)
    print(f"  Train - Loss: {train_stats['loss']:.4f} | Species: {train_stats['acc_species']:.3f} | "
          f"Health: {train_stats['acc_health']:.3f}")
    print(f"  Val   - Loss: {val_stats['loss']:.4f} | Species: {val_stats['acc_species']:.3f} | "
          f"Health: {val_stats['acc_health']:.3f}")

    # Save best model and check for improvement (based on health accuracy)
    if val_stats["acc_health"] > best_val_health:
        best_val_health = val_stats["acc_health"]
        best_epoch = epoch
        epochs_without_improvement = 0
        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "val_health": best_val_health
        }, f"{MODEL_NAME}.pt")
        print(f"  ★ New best model saved! Val Health Acc: {best_val_health:.4f}")
    else:
        epochs_without_improvement += 1
        print(f"  No improvement. Epochs without improvement: {epochs_without_improvement}/{PATIENCE}")

    # Early stopping
    if epochs_without_improvement >= PATIENCE:
        print(f"\n{'⚠ EARLY STOPPING TRIGGERED':^80}")
        print(f"No improvement for {PATIENCE} epochs. Stopping training.")
        break

    print(f"{'='*80}\n")

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"Best epoch: {best_epoch+1} with val_health={best_val_health:.4f}")
print(f"Total epochs trained: {epoch+1}/{EPOCHS}")


## TEST THE BEST MODEL

In [None]:
# Restore the best checkpoint written during training and evaluate it once on
# the held-out test split.
print("\nTesting best model on test set...")
checkpoint = torch.load(f"{MODEL_NAME}.pt", map_location=device)
model.load_state_dict(checkpoint["model"])
print(f"Loaded best model from epoch {checkpoint['epoch']+1} with val_health={checkpoint['val_health']:.3f}")

# Running the test phase
test_stats = run_epoch(test_loader, model, optimizer=None, train=False)
print(f"\n{'TEST SET RESULTS':^80}")
print("-"*80)
print(f"  Loss: {test_stats['loss']:.4f}")
print(f"  Species Accuracy: {test_stats['acc_species']:.4f} ({test_stats['acc_species']*100:.2f}%)")
print(f"  Health Accuracy:  {test_stats['acc_health']:.4f} ({test_stats['acc_health']*100:.2f}%)")
print("-"*80 + "\n")

# Save final model with proper naming
# The weights saved here are the best-checkpoint weights loaded above, so the
# recorded epoch is taken from that checkpoint. Using the training loop's
# `epoch` variable would label them with the last epoch trained, and would
# fail with NameError when this cell runs without the training cell.
final_model_name = f"final_multitask_{SWIN_MODEL}"
torch.save({
    "model": model.state_dict(),
    "epoch": checkpoint["epoch"],
    "test_stats": test_stats,
    "model_config": {
        "model_name": SWIN_MODEL,
        "num_species": NUM_SPECIES,
        "num_health": NUM_HEALTH,
        "dropout": DROPOUT
    },
    "spec": {"species_map": SPECIES_MAP, "health_map": HEALTH_MAP}
}, f"{final_model_name}.pt")
print(f"✓ Saved final model as '{final_model_name}.pt'\n")


## SAVE INDIVIDUAL TRAINING PLOTS

In [None]:
def _save_history_plot(filename, title, ylabel, x_values, series, *,
                       figsize=(10, 6), markersize=6, legend_kwargs=None,
                       ylim=None, hline_kwargs=None):
    """Draw one or more metric curves over epochs and save the figure to ``filename``.

    Args:
        filename: Output image path.
        title: Axes title.
        ylabel: Y-axis label.
        x_values: Epoch numbers shared by all curves.
        series: Iterable of ``(values, fmt, label)`` tuples, one per curve.
        figsize: Figure size in inches.
        markersize: Marker size for every curve.
        legend_kwargs: Keyword arguments for ``ax.legend``
            (default ``fontsize=11``).
        ylim: Optional y-axis limits.
        hline_kwargs: Optional keyword arguments for a horizontal reference
            line (``ax.axhline``), drawn after the curves.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for values, fmt, label in series:
        ax.plot(x_values, values, fmt, label=label, linewidth=2, markersize=markersize)
    if hline_kwargs is not None:
        ax.axhline(**hline_kwargs)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(**(legend_kwargs if legend_kwargs is not None else {"fontsize": 11}))
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(ylim)
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"✓ Saved {filename}")
    # Close this specific figure so figures do not accumulate in memory.
    plt.close(fig)


print("\nGenerating individual training plots...")

# Epochs are numbered from 1 on every plot.
epochs_range = range(1, len(history['train_loss']) + 1)
_ACC_YLIM = [0, 1.05]

# 1. Training Loss
_save_history_plot(
    'plot_train_loss.png', 'Training Loss Over Epochs', 'Loss', epochs_range,
    [(history['train_loss'], 'b-o', 'Train Loss')],
)

# 2. Validation Loss
_save_history_plot(
    'plot_val_loss.png', 'Validation Loss Over Epochs', 'Loss', epochs_range,
    [(history['val_loss'], 'r-o', 'Validation Loss')],
)

# 3. Train vs Val Loss Comparison
_save_history_plot(
    'plot_loss_comparison.png', 'Training vs Validation Loss', 'Loss', epochs_range,
    [
        (history['train_loss'], 'b-o', 'Train Loss'),
        (history['val_loss'], 'r-o', 'Val Loss'),
    ],
)

# 4. Species Accuracy
_save_history_plot(
    'plot_species_accuracy.png', 'Species Classification Accuracy', 'Accuracy', epochs_range,
    [
        (history['train_acc_species'], 'b-o', 'Train Species Acc'),
        (history['val_acc_species'], 'r-o', 'Val Species Acc'),
    ],
    ylim=_ACC_YLIM,
)

# 5. Health Accuracy (with the best validation value as a reference line)
_save_history_plot(
    'plot_health_accuracy.png', 'Health/Disease Classification Accuracy', 'Accuracy', epochs_range,
    [
        (history['train_acc_health'], 'b-o', 'Train Health Acc'),
        (history['val_acc_health'], 'r-o', 'Val Health Acc'),
    ],
    ylim=_ACC_YLIM,
    hline_kwargs=dict(y=best_val_health, color='g', linestyle='--', linewidth=2,
                      label=f'Best Val Health: {best_val_health:.3f}'),
)

# 6. All Accuracies Together
_save_history_plot(
    'plot_all_metrics.png', 'All Metrics Over Training', 'Accuracy', epochs_range,
    [
        (history['train_acc_species'], 'b-o', 'Train Species'),
        (history['val_acc_species'], 'b--s', 'Val Species'),
        (history['train_acc_health'], 'g-o', 'Train Health'),
        (history['val_acc_health'], 'g--s', 'Val Health'),
    ],
    figsize=(12, 7),
    markersize=5,
    legend_kwargs={"fontsize": 10, "ncol": 2},
    ylim=_ACC_YLIM,
)

print("\n" + "="*80)
print("All training plots saved:")
print("  - plot_train_loss.png")
print("  - plot_val_loss.png")
print("  - plot_loss_comparison.png")
print("  - plot_species_accuracy.png")
print("  - plot_health_accuracy.png")
print("  - plot_all_metrics.png")
print("="*80 + "\n")

# File-name stem shared by the history CSV and the LR-schedule plot.
_artifact_stem = MODEL_NAME.lower().replace("-", "_")

# Save history to CSV
history_df = pd.DataFrame(history)
history_df.to_csv(f'history_{_artifact_stem}.csv', index=False)
print(f"✓ Saved training history to 'history_{_artifact_stem}.csv'")

# Plot Learning Rate Schedule
# The curve is recomputed from the cosine-annealing formula rather than
# recorded during training; entry i is the LR in effect during epoch i + 1.
fig, ax = plt.subplots(figsize=(10, 6))
lrs = [LR * (0.5 * (1 + np.cos(np.pi * i / EPOCHS))) for i in range(len(history["train_loss"]))]
# Plot against the 1-based epoch numbers so the x axis agrees with the other plots.
ax.plot(epochs_range, lrs, marker='o', markersize=4, linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Learning Rate', fontsize=12)
ax.set_title('Learning Rate Schedule (Cosine Annealing)', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'lr_schedule_{_artifact_stem}.png', dpi=150, bbox_inches='tight')
print(f"✓ Saved LR schedule plot to 'lr_schedule_{_artifact_stem}.png'")
plt.close(fig)


## COMPREHENSIVE TESTING FUNCTION

In [None]:
def _save_confusion_matrix(y_true, y_pred, class_ids, class_names, *, title, cmap, filename):
    """Render a confusion matrix over ``class_ids`` as a heat-map and save it to ``filename``."""
    fig, ax = plt.subplots(figsize=(8, 6))
    # Passing labels= fixes the matrix shape to len(class_ids), so the tick
    # labels stay aligned even when a class is missing from the data.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap=cmap,
        ax=ax,
        xticklabels=class_names,
        yticklabels=class_names
    )
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close(fig)


def comprehensive_test(model, test_loader, device, species_map, health_map):
    """
    Perform comprehensive testing with metrics and visualizations

    Runs ``model`` over ``test_loader`` in eval mode without gradients,
    prints overall accuracies and per-class classification reports for both
    tasks, and saves one confusion-matrix image per task
    (``confusion_matrix_species.png`` and ``confusion_matrix_health.png``).

    Args:
        model: Module returning ``(logits_species, logits_health)``.
        test_loader: Iterable yielding
            ``(images, species_targets, health_targets)`` batches.
        device: Device the images are moved to before the forward pass.
        species_map: Mapping of species name -> class id.
        health_map: Mapping of health name -> class id.

    Returns:
        dict with ``species_accuracy``, ``health_accuracy`` and the NumPy
        arrays ``species_preds``, ``species_true``, ``health_preds``,
        ``health_true``.
    """
    model.eval()

    # Storage for predictions and ground truth
    all_species_preds = []
    all_species_true = []
    all_health_preds = []
    all_health_true = []

    print("Running comprehensive test evaluation...")

    with torch.no_grad():
        for imgs, y_species, y_health in test_loader:
            imgs = imgs.to(device, non_blocking=True)

            # Get predictions
            logits_species, logits_health = model(imgs)

            # Targets are only compared on the host, so they are not moved
            # to the device first.
            all_species_preds.extend(logits_species.argmax(dim=1).cpu().tolist())
            all_species_true.extend(y_species.cpu().tolist())
            all_health_preds.extend(logits_health.argmax(dim=1).cpu().tolist())
            all_health_true.extend(y_health.cpu().tolist())

    # Convert to numpy arrays
    all_species_preds = np.array(all_species_preds)
    all_species_true = np.array(all_species_true)
    all_health_preds = np.array(all_health_preds)
    all_health_true = np.array(all_health_true)

    # Reverse mapping for labels: class id -> display name
    species_labels = {v: k.capitalize() for k, v in species_map.items()}
    health_labels = {v: k.capitalize() for k, v in health_map.items()}

    # Fixed, ordered class ids and their names. They are passed explicitly to
    # the sklearn calls so that a class absent from the test data cannot
    # cause a target_names length mismatch or shift the matrix axes.
    species_ids = sorted(species_labels)
    health_ids = sorted(health_labels)
    species_names = [species_labels[i] for i in species_ids]
    health_names = [health_labels[i] for i in health_ids]

    # -------------------------------
    # Print Metrics
    # -------------------------------
    print("\n" + "="*80)
    print("COMPREHENSIVE TEST RESULTS")
    print("="*80)

    # Overall accuracies
    species_acc = accuracy_score(all_species_true, all_species_preds)
    health_acc = accuracy_score(all_health_true, all_health_preds)

    print(f"\n{'OVERALL ACCURACIES':^80}")
    print("-"*80)
    print(f"  Species Classification:  {species_acc:.4f} ({species_acc*100:.2f}%)")
    print(f"  Health Classification:   {health_acc:.4f} ({health_acc*100:.2f}%)")

    # Species Classification Report
    print(f"\n{'SPECIES CLASSIFICATION REPORT':^80}")
    print("-"*80)
    print(classification_report(
        all_species_true,
        all_species_preds,
        labels=species_ids,
        target_names=species_names,
        digits=4,
        zero_division=0
    ))

    # Health Classification Report
    print(f"\n{'HEALTH/DISEASE CLASSIFICATION REPORT':^80}")
    print("-"*80)
    print(classification_report(
        all_health_true,
        all_health_preds,
        labels=health_ids,
        target_names=health_names,
        digits=4,
        zero_division=0
    ))

    # -------------------------------
    # Visualizations
    # -------------------------------

    # 1. Species Confusion Matrix
    _save_confusion_matrix(
        all_species_true, all_species_preds, species_ids, species_names,
        title=f'Species Classification\nAccuracy: {species_acc:.2%}',
        cmap='Blues',
        filename='confusion_matrix_species.png',
    )
    print(f"\nSaved species confusion matrix to 'confusion_matrix_species.png'")

    # 2. Health Confusion Matrix
    _save_confusion_matrix(
        all_health_true, all_health_preds, health_ids, health_names,
        title=f'Health/Disease Classification\nAccuracy: {health_acc:.2%}',
        cmap='Greens',
        filename='confusion_matrix_health.png',
    )
    print(f"Saved health confusion matrix to 'confusion_matrix_health.png'")

    print("\n" + "="*80)
    print("Testing complete! Generated visualizations:")
    print("  - confusion_matrix_species.png")
    print("  - confusion_matrix_health.png")
    print("="*80 + "\n")

    return {
        'species_accuracy': species_acc,
        'health_accuracy': health_acc,
        'species_preds': all_species_preds,
        'species_true': all_species_true,
        'health_preds': all_health_preds,
        'health_true': all_health_true
    }

## RUN COMPREHENSIVE TEST EVALUATION

In [None]:
_EVAL_RULE = "=" * 80
print(_EVAL_RULE)
print("Running Comprehensive Test Evaluation")
print(_EVAL_RULE + "\n")

# Reload the best checkpoint so this cell can be run on its own, regardless
# of what state earlier cells left the model in.
checkpoint = torch.load(f"{MODEL_NAME}.pt", map_location=device)
model.load_state_dict(checkpoint["model"])
_ckpt_epoch_number = checkpoint['epoch'] + 1
_ckpt_val_health = checkpoint['val_health']
print(f"Loaded best model from epoch {_ckpt_epoch_number} with val_health={_ckpt_val_health:.3f}")

# Full evaluation: metrics, classification reports and confusion matrices.
test_results = comprehensive_test(model, test_loader, device, SPECIES_MAP, HEALTH_MAP)


## SAMPLE PREDICTIONS VISUALIZATION (10 PER CLASS)

In [None]:
print("\n" + "="*80)
print("Generating Sample Predictions Visualization")
print("="*80 + "\n")

# Label mappings for visualization (class id -> display name), derived from
# the label maps so they cannot drift out of sync with the training labels.
species_labels = {idx: name.capitalize() for name, idx in SPECIES_MAP.items()}
health_labels = {idx: name.capitalize() for name, idx in HEALTH_MAP.items()}

# Number of samples to display per class
amount = 10

# Collect samples from validation set, grouped by TRUE species id.
_species_ids = sorted(species_labels)
sample_images_by_class = {idx: [] for idx in _species_ids}
sample_predictions_by_class = {idx: [] for idx in _species_ids}
sample_ground_truth_by_class = {idx: [] for idx in _species_ids}

model.eval()
with torch.no_grad():
    for images, species_batch, health_batch in val_loader:
        outputs = model(images.to(device))
        species_preds = outputs[0].argmax(1).tolist()
        health_preds = outputs[1].argmax(1).tolist()
        species_true = species_batch.tolist()
        health_true = health_batch.tolist()

        for i, species_class in enumerate(species_true):
            if len(sample_images_by_class[species_class]) >= amount:
                continue
            # clone() so the stored image does not keep the whole batch alive
            sample_images_by_class[species_class].append(images[i].clone())
            sample_predictions_by_class[species_class].append({
                'species': species_preds[i],
                'health': health_preds[i]
            })
            sample_ground_truth_by_class[species_class].append({
                'species': species_true[i],
                'health': health_true[i]
            })

        # Stop as soon as every class has enough samples.
        if all(len(samples) >= amount for samples in sample_images_by_class.values()):
            break

# Visualize
# Constants for undoing the input normalisation before display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

_num_rows = len(_species_ids)
# squeeze=False keeps `axes` two-dimensional even for a single row or column.
fig, axes = plt.subplots(_num_rows, amount, figsize=(3*amount, 3*_num_rows), squeeze=False)

for row, species_idx in enumerate(_species_ids):
    _available = len(sample_images_by_class[species_idx])
    if _available < amount:
        print(f"  Note: only {_available} validation sample(s) found for "
              f"{species_labels[species_idx]}; remaining cells are left blank")

    for col in range(amount):
        ax = axes[row, col]
        ax.axis('off')

        # The validation split may hold fewer than `amount` images of a
        # class; leave the cell empty instead of raising IndexError.
        if col >= _available:
            continue

        img = sample_images_by_class[species_idx][col]
        pred = sample_predictions_by_class[species_idx][col]
        gt = sample_ground_truth_by_class[species_idx][col]

        # Denormalize and display (CHW tensor -> HWC array in [0, 1])
        img_display = img.numpy().transpose(1, 2, 0)
        img_display = std * img_display + mean
        img_display = np.clip(img_display, 0, 1)

        ax.imshow(img_display)

        # Green title only when BOTH tasks are predicted correctly.
        both_correct = (pred['species'] == gt['species']) and (pred['health'] == gt['health'])

        # Create title
        pred_sp = species_labels[pred['species']]
        pred_he = health_labels[pred['health']]
        gt_sp = species_labels[gt['species']]
        gt_he = health_labels[gt['health']]

        title = f"Pred: {pred_sp}, {pred_he}\nTrue: {gt_sp}, {gt_he}"
        color = 'green' if both_correct else 'red'
        ax.set_title(title, fontsize=8, color=color, fontweight='bold')

    # Add species label beside the row; for three rows this gives the
    # positions 0.8 / 0.5 / 0.2, and it scales for other row counts.
    _row_y = 0.5 + ((_num_rows - 1) / 2 - row) * (0.9 / _num_rows)
    fig.text(0.02, _row_y, species_labels[species_idx],
             fontsize=12, fontweight='bold', va='center', rotation=90)

plt.suptitle(f'Sample Predictions - {amount} Samples per Class\n(Green=Correct, Red=Incorrect)',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0.05, 0, 1, 0.99])
plt.savefig('sample_predictions.png', dpi=150, bbox_inches='tight')
# Report the number of images actually drawn rather than the nominal grid size.
_total_shown = sum(len(samples) for samples in sample_images_by_class.values())
print(f"Saved sample predictions ({_total_shown} total: up to {amount} per class) to 'sample_predictions.png'")
plt.close(fig)


## FINAL SUMMARY

In [None]:
# Closing report: headline results, the hyperparameters in effect and every
# file the notebook is expected to have written.
_summary_rule = "=" * 80
_summary_stem = MODEL_NAME.lower().replace('-', '_')

_summary_hyperparameters = [
    ("Batch Size", BATCH_SIZE),
    ("Learning Rate", LR),
    ("Weight Decay", WEIGHT_DECAY),
    ("Dropout", DROPOUT),
    ("Epochs", EPOCHS),
    ("Patience", PATIENCE),
    ("Image Size", IMG_SIZE),
    ("Num Species", NUM_SPECIES),
    ("Num Health", NUM_HEALTH),
    ("Pretrained", PRETRAINED),
]

_summary_files = [
    f"{MODEL_NAME}.pt (best model checkpoint)",
    f"final_multitask_{SWIN_MODEL}.pt (final model)",
    "confusion_matrix_species.png",
    "confusion_matrix_health.png",
    "sample_predictions.png",
    "plot_train_loss.png",
    "plot_val_loss.png",
    "plot_loss_comparison.png",
    "plot_species_accuracy.png",
    "plot_health_accuracy.png",
    "plot_all_metrics.png",
    f"lr_schedule_{_summary_stem}.png",
    f"history_{_summary_stem}.csv",
]

print("\n" + _summary_rule)
print("ALL TASKS COMPLETED")
print(_summary_rule)
print(f"\nModel: {MODEL_NAME}")
print(f"Best validation health accuracy: {best_val_health:.4f} at epoch {best_epoch+1}")
print(f"Final test health accuracy: {test_stats['acc_health']:.4f}")

print(f"\nHyperparameters:")
for _summary_label, _summary_value in _summary_hyperparameters:
    print(f"  - {_summary_label}: {_summary_value}")

print("\nGenerated files:")
for _summary_file in _summary_files:
    print(f"  - {_summary_file}")
print(_summary_rule)
