In [1]:
import os
import numpy as np
import random
import pandas
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score


from resnet import ResNet18, ResidualBlock
from mobilenet import MobileNetV1, MobileNetV2
from vggnet import VGG16
from data import ImgDataset

In [2]:
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed every RNG this notebook touches. torch.manual_seed also seeds
# the CUDA generators; Python's `random` is seeded as well because it is imported and
# re-seeded per DataLoader worker in seed_worker.
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
# For bit-exact GPU runs, additionally enable the two lines below (slower; see the
# PyTorch reproducibility notes for the CUBLAS_WORKSPACE_CONFIG requirement):
# torch.use_deterministic_algorithms(True)
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

print(f"Using device: {device}")

Using device: cuda


In [3]:
# Dataset root. Set the PLANTVILLAGE_ROOT environment variable to run on another
# machine; the default is the original author's local path.
IMG_PATH = os.environ.get(
    "PLANTVILLAGE_ROOT",
    "/home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset",
)
CSV_PATH = os.path.join(IMG_PATH, "dataframes")  # folder holding the split label CSVs
TRAIN_IMG = os.path.join(IMG_PATH, "train")
VAL_IMG = os.path.join(IMG_PATH, "val")
TEST_IMG = os.path.join(IMG_PATH, "test")

In [4]:
# Every split ships a "<split>_labels.csv" file inside CSV_PATH.
train_csv_path, val_csv_path, test_csv_path = (
    os.path.join(CSV_PATH, f"{_split}_labels.csv") for _split in ("train", "val", "test")
)

print(f"Path to training CSV: {train_csv_path}")
print(f"Path to validation CSV: {val_csv_path}")
print(f"Path to test CSV: {test_csv_path}")

# Read the three label tables in train -> val -> test order.
train_csv, val_csv, test_csv = map(
    pandas.read_csv, (train_csv_path, val_csv_path, test_csv_path)
)

Path to training CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/train_labels.csv
Path to validation CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/val_labels.csv
Path to test CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/test_labels.csv


In [5]:
# Per-channel (R, G, B) mean and std of the ImageNet training images.
#
# Ideal values would be computed over our own training set, but that needs a full pass
# over the data, so the ImageNet statistics are used as an approximation.
#
# Why normalise at all:
#   - T.ToTensor() converts a PIL image to a float tensor of shape (C, H, W) scaled to
#     [0, 1] (for 8-bit images). Without Normalize the network would see inputs that are
#     all positive and not centred around 0.
#
#   - T.Normalize computes (x - mean) / std per channel. With these stats a [0, 1] input
#     maps to roughly [-2.1, 2.6] (red: (0 - 0.485) / 0.229 ~ -2.12; blue:
#     (1 - 0.406) / 0.225 ~ 2.64), NOT to [-1, 1]. The point is that each channel
#     becomes approximately zero-centred with unit variance (to the extent our images
#     resemble ImageNet). This keeps activations and gradients well scaled and helps
#     convergence.

stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Three augmentation levels for the training data:
#   Mild:       random horizontal + vertical flips
#   Moderate:   flips + rotation up to +/-90 deg + random resized crop (80-100% of the
#               area, output 256x256) + colour jitter
#   Aggressive: flips + rotation up to +/-180 deg + random resized crop (30-100% of the
#               area, output 256x256) + stronger colour jitter + Gaussian blur (p=0.5)
#
# NOTE(review): the mild and val/test pipelines do not resize, while the others output
# 256x256. This assumes the source images are already 256x256; otherwise a batch would
# mix sizes and default collation fails. Confirm against the dataset.

mild_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),

    T.ToTensor(),
    T.Normalize(*stats)
])

moderate_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),

    # Corners exposed by the rotation are filled with RandomRotation's default fill (0).
    T.RandomRotation(90),
    T.RandomResizedCrop(256, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.02),

    T.ToTensor(),
    T.Normalize(*stats)
])

aggressive_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),

    T.RandomRotation(180),
    T.RandomResizedCrop(256, scale=(0.3, 1.0)),
    T.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.03),

    T.RandomApply([T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.5),

    T.ToTensor(),
    T.Normalize(*stats)
])

# Validation/test: no augmentation, only tensor conversion + the same normalisation.
val_test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(*stats)
])

In [6]:
# One class table built from all three splits, so every dataset receives the same
# (plant, disease) list even if a class is missing from one split.
# NOTE(review): assumes ImgDataset derives label indices from `classes` -- confirm.
all_data = pandas.concat([train_csv, val_csv, test_csv])
master_classes = (
    all_data[["plant", "disease"]]
    .drop_duplicates()
    .sort_values(["plant", "disease"], ignore_index=True)
)

# We divide the training dataset into 3 augmentation levels, separately per class:
#   - 60% of each class uses mild augmentation
#   - 30% uses moderate augmentation
#   - the remaining ~10% uses aggressive augmentation

# Deterministically shuffle the training dataframe so rows are assigned to augmentation
# levels at random rather than by their order in the CSV.
train_csv_shuffled = train_csv.sample(frac=1, random_state=seed).reset_index(drop=True)

# Lists to hold the dataframe chunks
mild_chunks = []
moderate_chunks = []
aggressive_chunks = []

# Group the SHUFFLED frame by (plant, disease) so each class is split individually.
# groupby preserves the input row order within each group, so the positional slices
# below form a random 60/30/10 partition of every class. (Grouping `train_csv` here
# would leave the shuffle above without effect.)
grouped = train_csv_shuffled.groupby(['plant', 'disease'])

for _, group in grouped:
    num_samples = len(group)
    mild_end = int(0.6 * num_samples)
    moderate_end = int(0.9 * num_samples)

    mild_chunks.append(group.iloc[:mild_end])
    moderate_chunks.append(group.iloc[mild_end:moderate_end])
    aggressive_chunks.append(group.iloc[moderate_end:])

train_mild_df = pandas.concat(mild_chunks)
train_moderate_df = pandas.concat(moderate_chunks)
train_aggressive_df = pandas.concat(aggressive_chunks)

# Every training row must land in exactly one augmentation level.
assert len(train_mild_df) + len(train_moderate_df) + len(train_aggressive_df) == len(train_csv)

train_ds_mild = ImgDataset(train_mild_df, TRAIN_IMG, transform=mild_transform, classes=master_classes)
train_ds_moderate = ImgDataset(train_moderate_df, TRAIN_IMG, transform=moderate_transform, classes=master_classes)
train_ds_aggressive = ImgDataset(train_aggressive_df, TRAIN_IMG, transform=aggressive_transform, classes=master_classes)

train_ds = torch.utils.data.ConcatDataset([train_ds_mild, train_ds_moderate, train_ds_aggressive])
val_ds = ImgDataset(val_csv, VAL_IMG, transform=val_test_transform, classes=master_classes)
test_ds = ImgDataset(test_csv, TEST_IMG, transform=val_test_transform, classes=master_classes)

num_classes = len(master_classes)

print(f"Total distinct classes: {len(master_classes)}")
print(train_ds_mild.classes)
print(f"Mild samples: {len(train_ds_mild)} samples with {len(train_ds_mild.classes)} classes,\n"
      f"Moderate: {len(train_ds_moderate)} samples with {len(train_ds_moderate.classes)} classes,\n"
      f"Aggressive: {len(train_ds_aggressive)} samples with {len(train_ds_aggressive.classes)} classes")

Total distinct classes: 15
         plant                             disease
0        Apple                          Apple_scab
1        Apple                           Black_rot
2        Apple                    Cedar_apple_rust
3        Apple                             healthy
4        Grape                           Black_rot
5        Grape                Esca_(Black_Measles)
6        Grape  Leaf_blight_(Isariopsis_Leaf_Spot)
7        Grape                             healthy
8        Peach                      Bacterial_spot
9        Peach                             healthy
10      Potato                        Early_blight
11      Potato                         Late_blight
12      Potato                             healthy
13  Strawberry                         Leaf_scorch
14  Strawberry                             healthy
Mild samples: 5714 samples with 15 classes,
Moderate: 2858 samples with 15 classes,
Aggressive: 959 samples with 15 classes


In [7]:
# Training hyper-parameters and per-model run settings.
MODEL_DIR = os.path.join("/home/banana9205/Desktop/Main/Uni/DATH/models")  # checkpoint folder

batch_size = 16       # samples per mini-batch
lr = 2e-4             # AdamW learning rate
weight_decay = 1e-3   # AdamW weight decay


def _make_config(name, model, epochs, patience):
    """Bundle one model's run settings into the dict layout used by the training/testing cells."""
    return {
        "name": name,
        "model": model,
        "epochs": epochs,
        "patience": patience,
        "device": device,
    }


# Models are instantiated here, in list order (ResNet18 -> MobileNetV1 -> MobileNetV2 -> VGG16),
# so their random initialisation consumes the seeded torch RNG in that order.
model_configs = [
    _make_config(
        "ResNet18",
        ResNet18(block=ResidualBlock, blocks_per_layer=[2, 2, 2, 2], n_channels=3, n_classes=num_classes),
        epochs=10,
        patience=3,
    ),
    _make_config("MobileNetV1", MobileNetV1(n_channels=3, n_classes=num_classes), epochs=15, patience=5),
    _make_config("MobileNetV2", MobileNetV2(n_channels=3, n_classes=num_classes), epochs=15, patience=5),
    _make_config("VGG16", VGG16(n_channels=3, n_classes=num_classes), epochs=10, patience=3),
]

In [8]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed Python's ``random`` and NumPy in a worker.

    Inside a DataLoader worker, ``torch.initial_seed()`` already returns a per-worker
    seed derived from the loader's base seed. It is reduced to 32 bits (the range
    NumPy accepts) and reused for the other two RNGs. ``worker_id`` is not needed.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# A single generator seeded with the global `seed` supplies the shuffling order and the
# per-worker base seeds for all three loaders, making data order reproducible.
g = torch.Generator()
g.manual_seed(seed)

# Settings shared by every loader; only `shuffle` differs between splits.
_loader_common = dict(
    batch_size=batch_size,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=g,
)

train_loader = DataLoader(train_ds, shuffle=True, **_loader_common)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_common)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_common)

In [9]:
def _run_train_epoch(model, loader, criterion, optimizer, run_device, desc):
    """Run one optimisation pass over ``loader``.

    Returns:
        tuple: (mean per-batch training loss, training accuracy in percent).
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(run_device), labels.to(run_device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return running_loss / len(loader), correct / total * 100


def _run_eval_epoch(model, loader, criterion, run_device, desc):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        tuple: (mean per-batch loss, accuracy in percent, support-weighted F1 in [0, 1]).
    """
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_f1 = f1_score(all_labels, all_preds, average='weighted')
    return running_loss / len(loader), correct / total * 100, val_f1


def train_and_evaluate(config):
    """
    Train one model with AdamW + cross-entropy, validating after every epoch.

    The checkpoint with the best validation (weighted) F1 is written to
    MODEL_DIR/<name>_best.pth. Training stops early after `patience` consecutive
    epochs without an F1 improvement. Uses the notebook-level train_loader,
    val_loader, lr, weight_decay and MODEL_DIR.

    Args: config (dict): A dictionary containing the following keys:
        - "name": Name of the model (str)
        - "model": The PyTorch model instance (nn.Module)
        - "epochs": Number of training epochs (int)
        - "patience": Patience for early stopping (int)
        - "device": (optional) torch.device to train on; defaults to the global `device`

    Returns:
        tuple: (val accuracy in percent, val weighted F1 in [0, 1]) of the best epoch.
    """
    name = config["name"]
    run_device = config.get("device", device)
    model = config["model"].to(run_device)
    num_epochs = config["epochs"]
    patience = config["patience"]

    print(f"\n{'='*40}\nTraining {name}\n{'='*40}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Create the checkpoint folder up front. Otherwise torch.save would only fail
    # after a full epoch of training has already been spent.
    os.makedirs(MODEL_DIR, exist_ok=True)
    save_path = os.path.join(MODEL_DIR, f"{name}_best.pth")

    best_val_acc = 0.0
    best_val_f1 = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        avg_train_loss, train_acc = _run_train_epoch(
            model, train_loader, criterion, optimizer, run_device,
            desc=f"[Train] Epoch {epoch+1}/{num_epochs}",
        )
        avg_val_loss, val_acc, val_f1 = _run_eval_epoch(
            model, val_loader, criterion, run_device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
        )

        print(f"Results: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"         Train Acc:  {train_acc:.2f}%   | Val Acc:  {val_acc:.2f}%   | Val F1:   {val_f1:.4f}")

        # Early stopping check. Always checkpoint the first epoch so a "<name>_best.pth"
        # exists for the testing cell even if the validation F1 is exactly 0.0.
        if epoch == 0 or val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_val_acc = val_acc
            epochs_no_improve = 0
            torch.save(model.state_dict(), save_path)
            print("Best model saved.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve}/{patience} epochs.")
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered for {name}!")
                break

    return best_val_acc, best_val_f1

In [10]:
# Main training loop. Put model names (e.g. "ResNet18") into `models_to_train` to
# (re)train them. Leaving it empty skips training, so the testing cell below loads
# whatever checkpoints already exist in MODEL_DIR.
models_to_train = []
results = {}

# Catch typos: an unknown name would otherwise be skipped silently.
_known_model_names = {cfg["name"] for cfg in model_configs}
_unknown_model_names = [n for n in models_to_train if n not in _known_model_names]
if _unknown_model_names:
    raise ValueError(f"Unknown model name(s) in models_to_train: {_unknown_model_names}")

for config in model_configs:
    if config["name"] not in models_to_train:
        continue
    val_acc, val_f1 = train_and_evaluate(config)
    results[config["name"]] = (val_acc, val_f1)

print("\nFinal Metrics of Models:")
for name, (acc, f1) in results.items():
    # Accuracy is a percentage; weighted F1 is a fraction in [0, 1], so no '%' sign.
    print(f"{name}: {acc:.2f}% | F1 {f1:.4f}")


Final Metrics of Models:


In [11]:
# Testing phase: reload each selected model's best checkpoint and score it on the test set.
models_to_test = ["ResNet18", "MobileNetV1", "MobileNetV2"]

print("\nTesting Best Models on Test Set:")
for config in model_configs:
    name = config["name"]
    if name not in models_to_test:
        continue
    run_device = config.get("device", device)
    model = config["model"].to(run_device)
    load_path = os.path.join(MODEL_DIR, f"{name}_best.pth")
    # map_location remaps tensors saved on another device (e.g. a checkpoint written on
    # a GPU and opened on a CPU-only machine) onto the device used here.
    model.load_state_dict(torch.load(load_path, map_location=run_device))
    model.eval()

    correct, total = 0, 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"Testing {name}"):
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = correct / total * 100
    test_f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f"{name} Test Accuracy: {test_acc:.2f}%, Test F1 Score: {test_f1:.4f}")


Testing Best Models on Test Set:


Testing ResNet18: 100%|██████████| 127/127 [00:04<00:00, 30.36it/s]


ResNet18 Test Accuracy: 97.19%, Test F1 Score: 0.9720


Testing MobileNetV1: 100%|██████████| 127/127 [00:03<00:00, 34.71it/s]


MobileNetV1 Test Accuracy: 89.85%, Test F1 Score: 0.8988


Testing MobileNetV2: 100%|██████████| 127/127 [00:04<00:00, 31.45it/s]

MobileNetV2 Test Accuracy: 94.38%, Test F1 Score: 0.9442





In [12]:
# Final model size and parameter count report
def get_model_size(model):
    """Return the in-memory size of ``model`` in MiB.

    Sums ``nelement() * element_size()`` over every parameter and every registered
    buffer, then divides the byte total by 1024**2.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024**2

print(f"{'Model':<15} | {'Params (M)':<12} | {'Size (MB)':<10}")

for config in model_configs:
    name = config["name"]
    if name in models_to_test:
        # Byte counts do not depend on the device; this just moves the weights off the GPU.
        # nn.Module.cpu() returns the module itself, so `model` is the same object.
        model = config["model"].cpu()
        size_mb = get_model_size(model)
        num_params = sum(param.numel() for param in model.parameters()) / 1e6
        print(f"{name:<15} | {num_params:<12.2f} | {size_mb:<10.2f}")

Model           | Params (M)   | Size (MB) 
ResNet18        | 11.18        | 42.70     
MobileNetV1     | 3.22         | 12.38     
MobileNetV2     | 2.24         | 8.69      
