In [1]:
# Notebook: Train EfficientNet-B0 on PlantVillage Dataset

# -----------------------------
# Fix Python path (for imports)
# -----------------------------
import sys
import os
sys.path.append(os.path.abspath(".."))  # Add project root to Python path

# -----------------------------
# Imports
# -----------------------------
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
from src.vision.utils import prepare_dataloaders
from src.vision.model import CropDiseaseClassifier

# -----------------------------
# Config
# -----------------------------
DATA_DIR = "../data/PlantVillage"  # PlantVillage dataset folder
BATCH_SIZE = 32
EPOCHS = 10
LR = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SAVE_PATH = "../best_model.pt"  # Save model in project root

# -----------------------------
# Load Data
# -----------------------------
train_loader, val_loader, class_to_idx = prepare_dataloaders(DATA_DIR, batch_size=BATCH_SIZE)
num_classes = len(class_to_idx)
print(f"Number of classes: {num_classes}")

# -----------------------------
# Initialize Model
# -----------------------------
model = CropDiseaseClassifier(num_classes=num_classes, pretrained=True).to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=LR)

# -----------------------------
# Training Loop
# -----------------------------
best_val_acc = 0
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / total
    train_acc = correct / total

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = correct / total
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"Saved best model to {MODEL_SAVE_PATH}!")




Number of classes: 16


Epoch 1/10: 100%|██████████| 1032/1032 [06:28<00:00,  2.66it/s]


Epoch 1: Train Loss=1.3786, Train Acc=0.4815, Val Acc=0.4758
Saved best model to ../best_model.pt!


Epoch 2/10: 100%|██████████| 1032/1032 [04:32<00:00,  3.78it/s]


Epoch 2: Train Loss=1.2624, Train Acc=0.4840, Val Acc=0.4732


Epoch 3/10: 100%|██████████| 1032/1032 [03:37<00:00,  4.74it/s]


Epoch 3: Train Loss=1.2380, Train Acc=0.4919, Val Acc=0.4852
Saved best model to ../best_model.pt!


Epoch 4/10: 100%|██████████| 1032/1032 [04:21<00:00,  3.94it/s]


Epoch 4: Train Loss=1.2231, Train Acc=0.4928, Val Acc=0.4818


Epoch 5/10: 100%|██████████| 1032/1032 [04:23<00:00,  3.91it/s]


Epoch 5: Train Loss=1.2174, Train Acc=0.4874, Val Acc=0.4712


Epoch 6/10: 100%|██████████| 1032/1032 [03:48<00:00,  4.51it/s]


Epoch 6: Train Loss=1.2123, Train Acc=0.4985, Val Acc=0.4687


Epoch 7/10: 100%|██████████| 1032/1032 [03:43<00:00,  4.61it/s]


Epoch 7: Train Loss=1.2097, Train Acc=0.4968, Val Acc=0.4671


Epoch 8/10: 100%|██████████| 1032/1032 [03:45<00:00,  4.57it/s]


Epoch 8: Train Loss=1.2042, Train Acc=0.4963, Val Acc=0.4714


Epoch 9/10: 100%|██████████| 1032/1032 [03:41<00:00,  4.67it/s]


Epoch 9: Train Loss=1.2024, Train Acc=0.4979, Val Acc=0.4623


Epoch 10/10: 100%|██████████| 1032/1032 [03:43<00:00,  4.61it/s]


Epoch 10: Train Loss=1.1991, Train Acc=0.4964, Val Acc=0.4701
