# Thesis Reproduction Notebook

This notebook reproduces experiments for:
**Self-Supervised Plant Disease Detection under Domain Shifts: Comparative Evaluation of CNN, Transformer, and Hybrid Architectures**

It generates (or helps reproduce) **Tables 6.1–6.4** and key figures (training curves, confusion matrix, Grad-CAM).

> Notes:
> - Deterministic seeds are set via `utils.repro.set_global_seed`.
> - PlantDoc results should be averaged over **4 seeds** as in the thesis.


In [None]:
# Setup
import os
import torch
from utils.repro import set_global_seed

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')

# Seed once, up front, through the project helper. What `deterministic=True`
# toggles is defined in utils.repro, not in this notebook.
SEED = 42
set_global_seed(SEED, deterministic=True)


In [None]:

# Reproducible PlantVillage split loading (80/10/10 stratified, seed=42)
from utils.dataset_prep import load_plantvillage_from_splits
from utils.repro import set_global_seed
from utils.augmentation import train_transform_cnn, train_transform_vit, val_transform_cnn, val_transform_vit
import os

# Seed again here with the same value the setup cell uses (42).
set_global_seed(42, deterministic=True)

# NOTE(review): paths in this cell are relative to "data", whereas the
# "Dataset paths" cell uses '../data' — confirm the intended working directory.
SPLIT_JSON = os.path.join("data", "splits", "plantvillage_split_seed42.json")

# All three splits get the CNN transforms here; ViT/hybrid inputs are meant to
# be wrapped separately later on.
# NOTE(review): the later cells visible in this notebook build their datasets
# with ImageFolder from PLANTVILLAGE_DIR and never reference pv_train_set /
# pv_val_set / pv_test_set — confirm which pipeline the reported tables use.
_pv_splits = load_plantvillage_from_splits(
    data_dir="data",
    split_json=SPLIT_JSON,
    transform_train=train_transform_cnn,
    transform_val=val_transform_cnn,
    transform_test=val_transform_cnn,
)
pv_train_set, pv_val_set, pv_test_set, CLASS_NAMES = _pv_splits

# Quick sanity summary of what was loaded.
print(f"Loaded PlantVillage splits from: {SPLIT_JSON}")
print(f"Classes: {len(CLASS_NAMES)}")
print(f"Train/Val/Test: {len(pv_train_set)} {len(pv_val_set)} {len(pv_test_set)}")


In [None]:
# Dataset paths
# Root folder for all datasets, given relative to the working directory.
DATA_DIR = '../data'

# Every other location is a direct child of DATA_DIR.
PLANTVILLAGE_DIR, PLANTDOC_DIR, SPLITS_DIR = (
    os.path.join(DATA_DIR, _sub) for _sub in ('PlantVillage', 'PlantDoc', 'splits')
)

# Echo the two dataset roots so a wrong working directory is spotted early.
print(f'PlantVillage: {PLANTVILLAGE_DIR}')
print(f'PlantDoc: {PLANTDOC_DIR}')


In [None]:
# Load datasets
from torchvision import datasets
from torch.utils.data import DataLoader, Subset

from utils.augmentation import train_transform_cnn, val_transform_cnn, train_transform_vit, val_transform_vit
from utils.dataset_prep import HybridDataset

# Assumes PlantVillage already has train/valid/test folders (Kaggle split).
# If you use raw PlantVillage, run scripts/create_splits.py and modify dataset loading accordingly.

def _pv_split(split, transform=None):
    """Return an ImageFolder over PLANTVILLAGE_DIR/<split> using `transform`."""
    return datasets.ImageFolder(os.path.join(PLANTVILLAGE_DIR, split), transform=transform)


# Untransformed views of the three splits; these feed HybridDataset below and
# provide the class list.
pv_train, pv_val, pv_test = (_pv_split(_name) for _name in ('train', 'valid', 'test'))

num_classes = len(pv_train.classes)
print(f'PlantVillage classes: {num_classes}')

# CNN pipeline: train transform on the train split, val transform on valid/test.
pv_train_cnn = _pv_split('train', train_transform_cnn)
pv_val_cnn = _pv_split('valid', val_transform_cnn)
pv_test_cnn = _pv_split('test', val_transform_cnn)

# ViT pipeline: same split folders, ViT-specific transforms.
pv_train_vit = _pv_split('train', train_transform_vit)
pv_val_vit = _pv_split('valid', val_transform_vit)
pv_test_vit = _pv_split('test', val_transform_vit)

# Hybrid pipeline: wrap the untransformed splits and hand HybridDataset both
# transforms, so it can yield a (cnn, vit) pair per image.
hybrid_train = HybridDataset(
    pv_train, transform_cnn=train_transform_cnn, transform_vit=train_transform_vit
)
hybrid_val = HybridDataset(
    pv_val, transform_cnn=val_transform_cnn, transform_vit=val_transform_vit
)
hybrid_test = HybridDataset(
    pv_test, transform_cnn=val_transform_cnn, transform_vit=val_transform_vit
)


In [None]:
# Dataloaders
# Batch sizes per input pipeline (the hybrid loaders reuse the CNN size).
BATCH_CNN = 64
BATCH_VIT = 32

# Options shared by every loader built in this cell.
_LOADER_KW = dict(num_workers=4, pin_memory=True)


def _split_loaders(train_ds, val_ds, test_ds, batch_size):
    """Return (train, val, test) DataLoaders; only the train loader shuffles."""
    return tuple(
        DataLoader(_ds, batch_size=batch_size, shuffle=_shuffle, **_LOADER_KW)
        for _ds, _shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
    )


train_loader_cnn, val_loader_cnn, test_loader_cnn = _split_loaders(
    pv_train_cnn, pv_val_cnn, pv_test_cnn, BATCH_CNN
)

train_loader_vit, val_loader_vit, test_loader_vit = _split_loaders(
    pv_train_vit, pv_val_vit, pv_test_vit, BATCH_VIT
)

train_loader_hybrid, val_loader_hybrid, test_loader_hybrid = _split_loaders(
    hybrid_train, hybrid_val, hybrid_test, BATCH_CNN
)


In [None]:
# Models
from models.resnet50_model import ResNet50Classifier
from models.vit_dino_model import ViTClassifier
from models.hybrid_cnn_vit_model import HybridCNNViTModel

# The three architectures are built with the same constructor arguments.
_MODEL_KW = dict(num_classes=num_classes, pretrained=True)

# Instantiation order (ResNet, ViT, hybrid) is kept as-is.
resnet = ResNet50Classifier(**_MODEL_KW)
vit = ViTClassifier(**_MODEL_KW)
hybrid = HybridCNNViTModel(**_MODEL_KW)

# Report the ViT backbone if the wrapper exposes `model_name`; otherwise 'unknown'.
_vit_backbone = getattr(vit, 'model_name', 'unknown')
print('ViT backbone:', _vit_backbone)


In [None]:
# Train (Table 6.1 - full data)
from scripts.train import train_model

# Epoch budget shared by all three full-data runs.
_FULL_DATA_EPOCHS = 50

# Each call is unpacked into two values; the second is later handed to
# plot_training_curves as the run's history.
resnet_trained, resnet_hist = train_model(
    resnet, train_loader_cnn, val_loader_cnn,
    model_type='resnet', epochs=_FULL_DATA_EPOCHS,
)
vit_trained, vit_hist = train_model(
    vit, train_loader_vit, val_loader_vit,
    model_type='vit', epochs=_FULL_DATA_EPOCHS,
)
hybrid_trained, hybrid_hist = train_model(
    hybrid, train_loader_hybrid, val_loader_hybrid,
    model_type='hybrid', epochs=_FULL_DATA_EPOCHS,
)


In [None]:
# Plot training curves (Figure 6.1 style)
from utils.plot_utils import plot_training_curves

# One figure per model, drawn in the same order the models were trained.
for _history, _title in (
    (resnet_hist, 'ResNet50 Training Curves'),
    (vit_hist, 'ViT (DINOv2-B) Training Curves'),
    (hybrid_hist, 'Hybrid CNN-ViT Training Curves'),
):
    plot_training_curves(_history, title=_title)


In [None]:
# Evaluate on PlantVillage test (Table 6.1)
from scripts.evaluate import evaluate_model

# Score each trained model on the PlantVillage test loader of its own pipeline.
resnet_metrics = evaluate_model(resnet_trained, test_loader_cnn, model_type='resnet')
vit_metrics = evaluate_model(vit_trained, test_loader_vit, model_type='vit')
hybrid_metrics = evaluate_model(hybrid_trained, test_loader_hybrid, model_type='hybrid')

# Names are left-padded to a common width so the colons line up in the output.
for _name, _metrics in (
    ('ResNet50', resnet_metrics),
    ('DINOv2-B', vit_metrics),
    ('Hybrid', hybrid_metrics),
):
    print(f'{_name:<9}: acc, prec, rec, f1 =', _metrics)


In [None]:
# Low-label experiments (Table 6.2): 25% and 10% labeled subsets
from utils.stratified_split import stratified_subsample

# Draw the 25% and 10% index lists once (25% first) and reuse them for every
# model family. The CNN, ViT and hybrid training sets are all built over the
# same PlantVillage 'train' folder.
# NOTE(review): this assumes HybridDataset keeps the base dataset's indexing — confirm.
idx_25, idx_10 = (
    stratified_subsample(pv_train_cnn, _frac, seed=SEED) for _frac in (0.25, 0.10)
)

train_cnn_25, train_cnn_10 = (Subset(pv_train_cnn, _idx) for _idx in (idx_25, idx_10))
train_vit_25, train_vit_10 = (Subset(pv_train_vit, _idx) for _idx in (idx_25, idx_10))
train_hybrid_25, train_hybrid_10 = (Subset(hybrid_train, _idx) for _idx in (idx_25, idx_10))


def _shuffled_loader(subset, batch_size):
    """Return a shuffling training DataLoader with this notebook's worker settings."""
    return DataLoader(
        subset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
    )


loader_cnn_25 = _shuffled_loader(train_cnn_25, BATCH_CNN)
loader_cnn_10 = _shuffled_loader(train_cnn_10, BATCH_CNN)

loader_vit_25 = _shuffled_loader(train_vit_25, BATCH_VIT)
loader_vit_10 = _shuffled_loader(train_vit_10, BATCH_VIT)

loader_hybrid_25 = _shuffled_loader(train_hybrid_25, BATCH_CNN)
loader_hybrid_10 = _shuffled_loader(train_hybrid_10, BATCH_CNN)

# A fresh model per regime, created in the order: ResNet 25/10, ViT 25/10,
# hybrid 25/10.
resnet_25, resnet_10 = (
    ResNet50Classifier(num_classes=num_classes, pretrained=True) for _ in range(2)
)
vit_25, vit_10 = (
    ViTClassifier(num_classes=num_classes, pretrained=True) for _ in range(2)
)
hybrid_25, hybrid_10 = (
    HybridCNNViTModel(num_classes=num_classes, pretrained=True) for _ in range(2)
)

# Low-label runs get a longer epoch budget than the full-data runs.
_LOW_LABEL_EPOCHS = 100

# Validation always uses the full validation loaders; only the training data shrinks.
resnet_25, _ = train_model(
    resnet_25, loader_cnn_25, val_loader_cnn, model_type='resnet', epochs=_LOW_LABEL_EPOCHS
)
resnet_10, _ = train_model(
    resnet_10, loader_cnn_10, val_loader_cnn, model_type='resnet', epochs=_LOW_LABEL_EPOCHS
)

vit_25, _ = train_model(
    vit_25, loader_vit_25, val_loader_vit, model_type='vit', epochs=_LOW_LABEL_EPOCHS
)
vit_10, _ = train_model(
    vit_10, loader_vit_10, val_loader_vit, model_type='vit', epochs=_LOW_LABEL_EPOCHS
)

hybrid_25, _ = train_model(
    hybrid_25, loader_hybrid_25, val_loader_hybrid, model_type='hybrid', epochs=_LOW_LABEL_EPOCHS
)
hybrid_10, _ = train_model(
    hybrid_10, loader_hybrid_10, val_loader_hybrid, model_type='hybrid', epochs=_LOW_LABEL_EPOCHS
)

# Report test metrics for every (model, regime) pair on the full test loaders.
for _tag, _model, _loader, _kind in (
    ('ResNet25:', resnet_25, test_loader_cnn, 'resnet'),
    ('ResNet10:', resnet_10, test_loader_cnn, 'resnet'),
    ('ViT25:', vit_25, test_loader_vit, 'vit'),
    ('ViT10:', vit_10, test_loader_vit, 'vit'),
    ('Hybrid25:', hybrid_25, test_loader_hybrid, 'hybrid'),
    ('Hybrid10:', hybrid_10, test_loader_hybrid, 'hybrid'),
):
    print(_tag, evaluate_model(_model, _loader, model_type=_kind))


In [None]:
# Robustness to corruptions (Table 6.3)
from torchvision import datasets
from scripts.evaluate import CorruptDataset, evaluate_hybrid_corruption

# Untransformed test split; each CorruptDataset below wraps it together with a
# corruption flag pair and a model-specific transform.
base_test = datasets.ImageFolder(os.path.join(PLANTVILLAGE_DIR, 'test'))

# The two corruption settings evaluated for Table 6.3.
_BLUR = dict(blur=True, noise=False)
_NOISE = dict(blur=False, noise=True)

blur_cnn = CorruptDataset(base_test, transform=val_transform_cnn, **_BLUR)
noise_cnn = CorruptDataset(base_test, transform=val_transform_cnn, **_NOISE)
blur_vit = CorruptDataset(base_test, transform=val_transform_vit, **_BLUR)
noise_vit = CorruptDataset(base_test, transform=val_transform_vit, **_NOISE)


def _first_metric(model, dataset, batch_size, model_type):
    """Evaluate `model` on a default-configured loader and return metric [0].

    Index 0 is taken to be accuracy, matching the 'acc, prec, rec, f1' order
    printed for Table 6.1.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    return evaluate_model(model, loader, model_type=model_type)[0]


# Evaluation order is kept: ResNet blur/noise, ViT blur/noise, hybrid blur/noise.
acc_resnet_blur = _first_metric(resnet_trained, blur_cnn, BATCH_CNN, 'resnet')
acc_resnet_noise = _first_metric(resnet_trained, noise_cnn, BATCH_CNN, 'resnet')
acc_vit_blur = _first_metric(vit_trained, blur_vit, BATCH_VIT, 'vit')
acc_vit_noise = _first_metric(vit_trained, noise_vit, BATCH_VIT, 'vit')

# The hybrid model goes through its own helper, which is given the data root
# rather than a dataset object.
acc_hybrid_blur = evaluate_hybrid_corruption(hybrid_trained, data_dir=DATA_DIR, **_BLUR)
acc_hybrid_noise = evaluate_hybrid_corruption(hybrid_trained, data_dir=DATA_DIR, **_NOISE)

print('ResNet blur/noise:', acc_resnet_blur, acc_resnet_noise)
print('ViT blur/noise:', acc_vit_blur, acc_vit_noise)
print('Hybrid blur/noise:', acc_hybrid_blur, acc_hybrid_noise)


In [None]:
# Cross-domain evaluation (Table 6.4): PlantDoc averaged over 4 seeds
from torchvision import datasets
from torch.utils.data import Subset
import numpy as np

from utils.repro import set_global_seed

# PlantDoc assumed at ../data/PlantDoc/plantdoc (ImageFolder)
pd_root = os.path.join(PLANTDOC_DIR, 'plantdoc')

seeds = [0, 1, 2, 3]

# Datasets and loaders do not depend on the seed (validation transforms,
# shuffle=False), so they are built once here instead of being re-created, and
# the image folder re-scanned, on every loop iteration.
# NOTE(review): ImageFolder derives label indices from PlantDoc's own folder
# names, while the models were trained on PlantVillage class indices — confirm
# the two label spaces line up (same class folders, same ordering).
pd_cnn = datasets.ImageFolder(pd_root, transform=val_transform_cnn)
pd_vit = datasets.ImageFolder(pd_root, transform=val_transform_vit)
# The hybrid wrapper needs an untransformed base dataset and applies both
# transforms itself.
pd_base = datasets.ImageFolder(pd_root)
pd_hyb = HybridDataset(pd_base, transform_cnn=val_transform_cnn, transform_vit=val_transform_vit)

loader_cnn = DataLoader(pd_cnn, batch_size=BATCH_CNN, shuffle=False, num_workers=4)
loader_vit = DataLoader(pd_vit, batch_size=BATCH_VIT, shuffle=False, num_workers=4)
loader_hyb = DataLoader(pd_hyb, batch_size=BATCH_CNN, shuffle=False, num_workers=4)

# One metrics entry per seed for each model family.
res_runs = []
vit_runs = []
hyb_runs = []

# NOTE(review): every iteration re-scores the same already-trained models
# (resnet_trained / vit_trained / hybrid_trained); only the global seed changes.
# Unless evaluate_model is itself stochastic, the four runs will be identical
# and the reported std will be 0 — confirm whether per-seed training was intended.
for s in seeds:
    set_global_seed(s, deterministic=True)

    res_runs.append(evaluate_model(resnet_trained, loader_cnn, model_type='resnet'))
    vit_runs.append(evaluate_model(vit_trained, loader_vit, model_type='vit'))
    hyb_runs.append(evaluate_model(hybrid_trained, loader_hyb, model_type='hybrid'))

# Aggregate across seeds (axis 0), leaving one mean/std per metric.
# np.std is called with its default ddof=0, i.e. the population standard deviation.
res_mean, res_std = np.mean(res_runs, axis=0), np.std(res_runs, axis=0)
vit_mean, vit_std = np.mean(vit_runs, axis=0), np.std(vit_runs, axis=0)
hyb_mean, hyb_std = np.mean(hyb_runs, axis=0), np.std(hyb_runs, axis=0)

print('ResNet mean±std:', res_mean, res_std)
print('ViT mean±std:', vit_mean, vit_std)
print('Hybrid mean±std:', hyb_mean, hyb_std)


In [None]:
# Confusion matrix (Figure 6.3) for Hybrid on PlantVillage test
import numpy as np
from utils.metrics import compute_confusion_matrix
from utils.plot_utils import plot_confusion_matrix

# Put the model in eval mode and collect predictions over the whole hybrid
# test loader.
hybrid_trained.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    # Each batch is ((cnn_input, vit_input), labels); the model takes both views.
    for (batch_cnn, batch_vit), batch_labels in test_loader_hybrid:
        scores = hybrid_trained(batch_cnn.to(DEVICE), batch_vit.to(DEVICE))
        # Predicted class = index of the largest score along dim 1.
        all_preds.extend(scores.argmax(dim=1).cpu().tolist())
        all_labels.extend(batch_labels.tolist())

# Pass the full label range explicitly so the matrix has one row/column per
# class index in range(num_classes).
class_ids = list(range(num_classes))
cm = compute_confusion_matrix(all_labels, all_preds, labels=class_ids)
plot_confusion_matrix(
    cm,
    class_names=pv_train.classes,
    normalize=True,
    title='Hybrid Confusion Matrix (Normalized)',
)


In [None]:
# Grad-CAM (Figure 6.4) example
import matplotlib.pyplot as plt
from PIL import Image
from utils.grad_cam import generate_gradcam, overlay_cam_on_image

# pick a sample from PlantVillage test
# Use the first entry of the PlantVillage test split as the example image.
sample_path, sample_label = pv_test.samples[0]
img = Image.open(sample_path).convert('RGB')

# Preprocess once per pipeline and add a leading batch dimension.
# NOTE(review): unlike the confusion-matrix cell, these tensors are not moved
# to DEVICE here — confirm generate_gradcam handles device placement itself.
img_cnn = val_transform_cnn(img).unsqueeze(0)
img_vit = val_transform_vit(img).unsqueeze(0)

cam_resnet = generate_gradcam(resnet_trained, img_cnn, model_type='resnet')
cam_vit = generate_gradcam(vit_trained, img_vit, model_type='vit')
# The "hybrid" map is an equal-weight average of the two single-model maps; it
# is not computed from hybrid_trained.
cam_hyb = 0.5 * cam_resnet + 0.5 * cam_vit

ov1, ov2, ov3 = (
    overlay_cam_on_image(img, _cam) for _cam in (cam_resnet, cam_vit, cam_hyb)
)

# Three panels side by side, axes hidden.
plt.figure(figsize=(12, 4))
for _pos, (_overlay, _title) in enumerate(
    (
        (ov1, 'ResNet Grad-CAM'),
        (ov2, 'ViT Grad-CAM'),
        (ov3, 'Hybrid Grad-CAM'),
    ),
    start=1,
):
    plt.subplot(1, 3, _pos)
    plt.imshow(_overlay)
    plt.axis('off')
    plt.title(_title)
plt.show()
