In [1]:
# 02_train_model.ipynb - Cell 1
import os
from pathlib import Path
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm.notebook import tqdm

# Paths are resolved relative to the working directory (notebook lives in notebooks/).
ROOT = Path("..")
DATA_DIR = ROOT / "data"
TRAIN_DIR = DATA_DIR / "train"
VAL_DIR = DATA_DIR / "val"
TEST_DIR = DATA_DIR / "test"
MODEL_DIR = ROOT / "models"
# Create the output directory up front so saving a checkpoint later cannot fail on a missing folder.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Hyperparams
IMAGE_SIZE = 224   # side length in pixels of the square input fed to the network
BATCH_SIZE = 32
NUM_EPOCHS = 8
LR = 1e-4
SEED = 42

# Seed the RNGs used by shuffling, augmentation and layer initialisation so that
# repeated runs are comparable. This does not guarantee bitwise determinism
# (e.g. some GPU kernels are non-deterministic), it only fixes the random streams.
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)



Using device: cpu


In [2]:
# Cell 2 - data transforms and loaders

# One Normalize instance shared by both pipelines so train and eval inputs are
# guaranteed to be scaled identically.
# NOTE(review): these are the channel statistics conventionally used with
# ImageNet-pretrained torchvision models - confirm they match the backbone weights.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: random crop / flip / colour jitter as augmentation.
train_t = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    normalize,
])
# Evaluation pipeline: deterministic resize only, no augmentation.
val_t = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    normalize,
])

train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_t)
val_ds = datasets.ImageFolder(VAL_DIR, transform=val_t)
test_ds = datasets.ImageFolder(TEST_DIR, transform=val_t)

# ImageFolder assigns label indices per dataset from its own sub-folder names.
# If a split is missing (or has an extra) class folder, the integer labels no
# longer mean the same thing across splits and every metric is silently wrong.
assert val_ds.classes == train_ds.classes, "val/ class folders differ from train/"
assert test_ds.classes == train_ds.classes, "test/ class folders differ from train/"

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

classes = train_ds.classes
num_classes = len(classes)
print("Classes (count):", num_classes)
print(classes)


Classes (count): 15
['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']


In [3]:
# Cell 3 - model, loss, optimizer, helper fn

# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum API. Prefer the enum when this torchvision version provides it, and fall
# back to the legacy flag otherwise so older environments keep working.
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` maps to.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Replace the final fully-connected layer so the output size matches our label set.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
# All parameters are passed to the optimizer, i.e. the whole network is fine-tuned.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

def evaluate_model(net, loader):
    """Run `net` over every batch of `loader` and score its predictions.

    The network is switched to eval mode (and left in that mode), gradients
    are disabled, and the arg-max over dim 1 of each output is taken as the
    predicted class index.

    Args:
        net: model called as ``net(images)``.
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        Tuple ``(accuracy, weighted_f1, labels, predictions)`` where the last
        two are flat Python lists in iteration order.
    """
    net.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs, batch_labels = batch_imgs.to(DEVICE), batch_labels.to(DEVICE)
            logits = net(batch_imgs)
            y_pred += logits.argmax(dim=1).tolist()
            y_true += batch_labels.tolist()
    accuracy = accuracy_score(y_true, y_pred)
    weighted_f1 = f1_score(y_true, y_pred, average="weighted")
    return accuracy, weighted_f1, y_true, y_pred




In [4]:
# Cell 4 - training loop
# Trains for NUM_EPOCHS, validates after every epoch, and keeps the checkpoint
# with the highest validation accuracy at `best_path`.
best_val_acc = 0.0
best_path = MODEL_DIR / "best_model.pth"

for epoch in range(1, NUM_EPOCHS + 1):
    # evaluate_model() leaves the network in eval mode, so re-enable training
    # behaviour (dropout / batch-norm updates) at the start of every epoch.
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} - train")
    running_loss = 0.0   # sum of per-sample losses seen this epoch
    samples_seen = 0
    for imgs, labels in pbar:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        outs = model(imgs)
        loss = criterion(outs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Read the scalar once: .item() forces a device sync on every call.
        batch_loss = loss.item()
        batch_n = labels.size(0)
        # `criterion` yields the batch mean, so weight it by batch size; a short
        # final batch then counts proportionally instead of as a full batch.
        running_loss += batch_loss * batch_n
        samples_seen += batch_n
        pbar.set_postfix(loss=batch_loss)
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = running_loss / max(samples_seen, 1)

    # Only the two metrics are needed here; slicing avoids binding the label /
    # prediction lists to `_`, which IPython uses for the last cell output.
    val_acc, val_f1 = evaluate_model(model, val_loader)[:2]
    # The metric may come back as a numpy scalar. Store a plain float so the
    # checkpoint contains only built-in types and tensors and stays loadable
    # with torch.load(..., weights_only=True).
    val_acc = float(val_acc)
    print(f"Epoch {epoch} completed. Train loss: {avg_loss:.4f} | Val acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
    # save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model_state": model.state_dict(),
            "classes": classes,
            "epoch": epoch,
            "val_acc": val_acc
        }, best_path)
        print("Saved best model to:", best_path)
print("Training finished. Best val acc:", best_val_acc)


Epoch 1/8 - train:   0%|          | 0/587 [00:00<?, ?it/s]

KeyboardInterrupt: 