In [1]:
import os
import numpy as np
import pandas
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader


from resnet import ResNet18, ResidualBlock
from data import ImgDataset

In [2]:
seed = 42

# Seed the RNGs used by this notebook. torch.manual_seed seeds the CPU
# generator and every CUDA device; numpy is seeded separately.
torch.manual_seed(seed)
np.random.seed(seed)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Dataset root: set the PLANTVILLAGE_ROOT environment variable to point at a
# local copy instead of editing this machine-specific absolute path.
DATASET_ROOT = os.environ.get(
    "PLANTVILLAGE_ROOT",
    "/home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset",
)
# The per-split label CSVs live in the "dataframes" subfolder of the root.
CSV_PATH = os.path.join(DATASET_ROOT, "dataframes")
train_csv_path = os.path.join(CSV_PATH, "train_labels.csv")
val_csv_path = os.path.join(CSV_PATH, "val_labels.csv")
test_csv_path = os.path.join(CSV_PATH, "test_labels.csv")

print(f"Path to training CSV: {train_csv_path}")
print(f"Path to validation CSV: {val_csv_path}")
print(f"Path to test CSV: {test_csv_path}")

# pandas raises FileNotFoundError with the offending path if a CSV is missing.
train_csv = pandas.read_csv(train_csv_path)
val_csv = pandas.read_csv(val_csv_path)
test_csv = pandas.read_csv(test_csv_path)

Path to training CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/train_labels.csv
Path to validation CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/val_labels.csv
Path to test CSV: /home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset/dataframes/test_labels.csv


In [None]:
# Image root: set the PLANTVILLAGE_ROOT environment variable to point at a
# local copy instead of editing this machine-specific absolute path.
IMG_PATH = os.environ.get(
    "PLANTVILLAGE_ROOT",
    "/home/banana9205/Desktop/Main/Uni/DATH/Dataset/plantvillage dataset",
)
TRAIN_IMG = os.path.join(IMG_PATH, "train")
VAL_IMG = os.path.join(IMG_PATH, "val")
TEST_IMG = os.path.join(IMG_PATH, "test")

# Training-time augmentation: random horizontal/vertical flips and mild colour
# jitter, then conversion to a tensor.
# NOTE(review): there is no Resize/Normalize step, and `img_size` (set in a later
# cell) is not referenced anywhere in this notebook — confirm that ImgDataset or
# the source images already guarantee a fixed size.
transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.ToTensor()
])

# Only the training split is augmented.
# NOTE(review): val/test pass no `transform`, so they rely on ImgDataset's default
# handling to produce tensors — confirm against data.ImgDataset.
train_ds = ImgDataset(train_csv, TRAIN_IMG, transform = transform)
val_ds = ImgDataset(val_csv, VAL_IMG)
test_ds = ImgDataset(test_csv, TEST_IMG)

# The class count comes from the training split only.
# TODO confirm val/test share the same label encoding.
num_classes = len(train_ds.classes)
print(f"Number of classes: {num_classes}")
print(train_ds.classes)

Number of classes: 15
         plant                             disease
0        Apple                          Apple_scab
1        Apple                           Black_rot
2        Apple                    Cedar_apple_rust
3        Apple                             healthy
4        Grape                           Black_rot
5        Grape                Esca_(Black_Measles)
6        Grape  Leaf_blight_(Isariopsis_Leaf_Spot)
7        Grape                             healthy
8        Peach                      Bacterial_spot
9        Peach                             healthy
10      Potato                        Early_blight
11      Potato                         Late_blight
12      Potato                             healthy
13  Strawberry                         Leaf_scorch
14  Strawberry                             healthy


In [5]:
# Hyperparameters for the training run below.
batch_size: int = 32          # samples per DataLoader batch / optimizer step
lr: float = 2e-4              # AdamW learning rate
weight_decay: float = 1e-3    # AdamW weight decay
num_epochs: int = 10          # passes over the training split
img_size: int = 256           # NOTE(review): not referenced elsewhere in this notebook

In [6]:
# Checkpoint directory: set PLANTVILLAGE_MODEL_DIR to override this
# machine-specific absolute path.
MODEL_DIR = os.environ.get(
    "PLANTVILLAGE_MODEL_DIR",
    "/home/banana9205/Desktop/Main/Uni/DATH/models",
)

# Worker processes used by each DataLoader to load/decode images in parallel.
num_workers = 2
# Page-locked host memory can speed up host->GPU copies; it brings nothing on CPU.
pin_memory = device.type == "cuda"

# Only the training split is shuffled; val/test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)

# ResNet-18 layout: four stages of two residual blocks, 3 input channels,
# one output per class.
model = ResNet18(block=ResidualBlock,
                 blocks_per_layer=[2, 2, 2, 2],
                 n_channels=3,
                 n_classes=num_classes).to(device)

# CrossEntropyLoss applies log-softmax itself; assumes the model returns raw
# logits — TODO confirm in resnet.ResNet18.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(),
                        lr=lr,
                        weight_decay=weight_decay)

In [7]:
# Training loop: train for `num_epochs`, evaluate on the validation split after
# every epoch, and checkpoint the weights with the best validation accuracy.


def run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over `loader` and return (mean per-sample loss, accuracy in %).

    If `optimizer` is given, the model is put in train mode and updated after every
    batch. Otherwise it is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            batch_size_ = labels.size(0)
            # The criterion averages over the batch; rescale to a sum so a short
            # final batch does not skew the epoch mean.
            running_loss += batch_loss.item() * batch_size_
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size_

    if total == 0:
        raise ValueError(f"{desc or 'loader'}: no samples to evaluate")
    return running_loss / total, correct / total * 100


# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs(MODEL_DIR, exist_ok=True)
best_val_acc = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

    train_loss, train_acc = run_epoch(model, train_loader, criterion, device,
                                      optimizer=optimizer, desc="Training")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

    val_loss, val_acc = run_epoch(model, val_loader, criterion, device,
                                  desc="Validation")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Keep only the weights with the highest validation accuracy seen so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(MODEL_DIR, "best_model.pth"))
        print("Best model saved.")


Epoch [1/10]


Training: 100%|██████████| 298/298 [00:46<00:00,  6.44it/s]


Train Loss: 0.8925, Train Acc: 70.60%


Validation: 100%|██████████| 64/64 [00:03<00:00, 18.60it/s]


Val Loss: 0.7948, Val Acc: 73.13%
Best model saved.

Epoch [2/10]


Training: 100%|██████████| 298/298 [00:50<00:00,  5.85it/s]


Train Loss: 0.4310, Train Acc: 85.32%


Validation: 100%|██████████| 64/64 [00:03<00:00, 18.15it/s]


Val Loss: 0.9139, Val Acc: 71.96%

Epoch [3/10]


Training: 100%|██████████| 298/298 [00:49<00:00,  6.06it/s]


Train Loss: 0.2663, Train Acc: 91.12%


Validation: 100%|██████████| 64/64 [00:03<00:00, 17.81it/s]


Val Loss: 0.3224, Val Acc: 88.57%
Best model saved.

Epoch [4/10]


Training: 100%|██████████| 298/298 [00:54<00:00,  5.47it/s]


Train Loss: 0.2087, Train Acc: 93.19%


Validation: 100%|██████████| 64/64 [00:03<00:00, 17.47it/s]


Val Loss: 0.2140, Val Acc: 92.92%
Best model saved.

Epoch [5/10]


Training: 100%|██████████| 298/298 [00:58<00:00,  5.11it/s]


Train Loss: 0.1474, Train Acc: 95.27%


Validation: 100%|██████████| 64/64 [00:03<00:00, 18.21it/s]


Val Loss: 0.1981, Val Acc: 93.21%
Best model saved.

Epoch [6/10]


Training: 100%|██████████| 298/298 [00:58<00:00,  5.07it/s]


Train Loss: 0.1216, Train Acc: 95.99%


Validation: 100%|██████████| 64/64 [00:03<00:00, 17.98it/s]


Val Loss: 0.2789, Val Acc: 91.01%

Epoch [7/10]


Training: 100%|██████████| 298/298 [00:58<00:00,  5.08it/s]


Train Loss: 0.0943, Train Acc: 96.90%


Validation: 100%|██████████| 64/64 [00:06<00:00,  9.63it/s]


Val Loss: 0.1096, Val Acc: 96.48%
Best model saved.

Epoch [8/10]


Training: 100%|██████████| 298/298 [00:59<00:00,  5.05it/s]


Train Loss: 0.0869, Train Acc: 97.31%


Validation: 100%|██████████| 64/64 [00:03<00:00, 17.81it/s]


Val Loss: 0.1069, Val Acc: 95.99%

Epoch [9/10]


Training: 100%|██████████| 298/298 [00:58<00:00,  5.11it/s]


Train Loss: 0.0763, Train Acc: 97.53%


Validation: 100%|██████████| 64/64 [00:06<00:00,  9.72it/s]


Val Loss: 0.1115, Val Acc: 96.34%

Epoch [10/10]


Training: 100%|██████████| 298/298 [01:00<00:00,  4.93it/s]


Train Loss: 0.0832, Train Acc: 97.31%


Validation: 100%|██████████| 64/64 [00:03<00:00, 17.95it/s]

Val Loss: 0.1367, Val Acc: 95.46%





In [8]:
# Test the model: reload the checkpoint with the best validation accuracy and
# evaluate it once on the held-out test split.
checkpoint_path = os.path.join(MODEL_DIR, "best_model.pth")
# map_location lets the load work even if the checkpoint was written on another
# device. weights_only restricts unpickling to tensors/plain containers.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

running_loss, correct, total = 0.0, 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        batch_loss = criterion(outputs, labels)

        n = labels.size(0)
        # Weight the batch-mean loss by batch size so the final short batch
        # does not skew the overall mean.
        running_loss += batch_loss.item() * n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += n

if total == 0:
    raise ValueError("test_loader yielded no samples")
test_loss = running_loss / total
test_acc = correct / total * 100
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

Testing: 100%|██████████| 64/64 [00:03<00:00, 17.72it/s]

Test Loss: 0.1082, Test Acc: 96.80%



