# SakuraScan – Modelling (PyTorch)

## Objectives
- Train a binary image classifier to distinguish healthy cherry leaves from leaves with powdery mildew.
- Use transfer learning with a pretrained CNN (ResNet18) for robust performance.
- Save the trained model for use in the SakuraScan Streamlit dashboard.

## Inputs
- Image dataset stored in `Data/source_images/healthy` and `Data/source_images/powdery_mildew`.

## Outputs
- Trained PyTorch model weights saved to `app_pages/src/models/sakuramodel_resnet18.pth`.
- Printed training and validation accuracy and loss.


In [5]:
"""
Model training script for SakuraScan using PyTorch and transfer learning.
"""

from pathlib import Path
import os
from typing import Tuple, Dict, List

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

In [15]:
"""
Set up paths, constants, and device configuration.
"""

from pathlib import Path

# The notebook lives one level below the project root, so ".." (resolved
# against the current working directory) is treated as the root.
PROJECT_ROOT = Path("..").resolve()

# Input images and the output folder for the trained weights.
DATA_DIR = PROJECT_ROOT.joinpath("Data", "source_images")
MODEL_DIR = PROJECT_ROOT.joinpath("app_pages", "src", "models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)  # created up front so saving cannot fail later

MODEL_PATH = MODEL_DIR.joinpath("sakuramodel_resnet18.pth")

# Training hyperparameters.
BATCH_SIZE = 32       # Images per optimisation step.
NUM_EPOCHS = 8        # Full passes over the training split.
LEARNING_RATE = 1e-4  # Step size handed to the optimiser.
VAL_SPLIT = 0.2       # Fraction of the dataset held out for validation.
IMAGE_SIZE = 224      # Every image is resized to IMAGE_SIZE x IMAGE_SIZE pixels.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cpu')

In [12]:
"""
Define the training and validation image transforms and load the full dataset using ImageFolder.
"""

# Per-channel normalisation statistics (the ImageNet values that the
# pretrained torchvision ResNet weights were trained with).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation, then tensor + normalise.
train_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: deterministic, no augmentation.
val_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
val_transform = transforms.Compose(val_steps)

# Load every image under DATA_DIR; one sub-folder per class.
# NOTE: every sample read through this dataset object (including any subset
# built on top of it) goes through train_transform, i.e. is augmented.
full_dataset = datasets.ImageFolder(str(DATA_DIR), transform=train_transform)

# Class labels in index order, as discovered by ImageFolder.
class_names: List[str] = full_dataset.classes
class_names

['healthy', 'powdery_mildew']

In [13]:
"""
Functions for training and evaluating the model.
"""

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``dataloader`` and report its metrics.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser whose gradients are zeroed before, and whose
            ``step`` is called after, every backward pass.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(loss, accuracy)`` tuple: the batch losses weighted by batch size
        and divided by the number of samples seen, and the fraction of
        samples whose arg-max prediction (over dim 1) equals the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch does
        # not skew the per-sample average.
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        predicted = logits.argmax(1)
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    epoch_loss = loss_sum / n_seen
    epoch_acc = n_correct / n_seen
    return epoch_loss, epoch_acc


def evaluate(model, dataloader, criterion, device):
    """
    Score the model on ``dataloader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(loss, accuracy)`` tuple: the batch losses weighted by batch size
        and divided by the number of samples seen, and the fraction of
        samples whose arg-max prediction (over dim 1) equals the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    # Gradients are not needed for scoring, so autograd tracking is disabled.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Weight by batch size so the result is a per-sample average.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            predicted = logits.argmax(1)
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

In [14]:
"""
Create a ResNet18 model for binary classification.
"""

def create_model(num_classes):
    """
    Build a pretrained ResNet18 whose only trainable part is a new head.

    Args:
        num_classes: Number of output units of the replacement final layer.

    Returns:
        The ResNet18 model with all pretrained parameters frozen and a fresh
        ``nn.Linear`` classifier assigned to ``fc``.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    # Freeze every pretrained parameter; only the head added below will
    # receive gradients.
    net.requires_grad_(False)

    # Swap the original classifier for one sized to our classes. A newly
    # constructed layer is trainable by default.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)

    return net


# One output unit per class; build the network and place it on the chosen device.
model = create_model(len(class_names)).to(device)
model

0.3%

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\NeosT/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100.0%


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [16]:
"""
Create training and validation dataloaders from the full dataset.
"""

dataset_size = len(full_dataset)
val_size = int(dataset_size * VAL_SPLIT)
train_size = dataset_size - val_size

# Shuffle the sample indices with a fixed seed so the train/validation split
# is identical on every run and results stay comparable.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(dataset_size, generator=split_generator).tolist()
train_indices = shuffled_indices[:train_size]
val_indices = shuffled_indices[train_size:]

# full_dataset applies train_transform (random flip/rotation) to every sample
# it serves. Validation must not be augmented, so the validation subset reads
# from a second ImageFolder over the same folder that uses val_transform.
# ImageFolder lists files in sorted order, so index i refers to the same image
# in both dataset objects and the two subsets stay disjoint.
val_source_dataset = datasets.ImageFolder(root=str(DATA_DIR), transform=val_transform)

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_source_dataset, val_indices)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation order does not matter, so it is kept fixed.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

len(train_loader.dataset), len(val_loader.dataset)

(3367, 841)

In [17]:
"""
Train the model for several epochs and print loss and accuracy.
"""

# Cross-entropy over the class logits; only the new head (model.fc) is optimised.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

# One list per metric, appended to once per epoch.
history = {
    metric_name: []
    for metric_name in ("train_loss", "train_acc", "val_loss", "val_acc")
}

for epoch in range(NUM_EPOCHS):
    # Train first, then score the updated weights on the held-out split.
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    epoch_metrics = {
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    }
    for metric_name, metric_value in epoch_metrics.items():
        history[metric_name].append(metric_value)

    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} "
        f"| Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} "
        f"| Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}"
    )

Epoch 1/8 | Train Loss: 0.4909 Acc: 0.8402 | Val Loss: 0.3452 Acc: 0.9631
Epoch 2/8 | Train Loss: 0.2773 Acc: 0.9849 | Val Loss: 0.2116 Acc: 0.9941
Epoch 3/8 | Train Loss: 0.1980 Acc: 0.9878 | Val Loss: 0.1440 Acc: 0.9941
Epoch 4/8 | Train Loss: 0.1455 Acc: 0.9941 | Val Loss: 0.1131 Acc: 0.9964
Epoch 5/8 | Train Loss: 0.1201 Acc: 0.9938 | Val Loss: 0.0874 Acc: 1.0000
Epoch 6/8 | Train Loss: 0.1019 Acc: 0.9926 | Val Loss: 0.0760 Acc: 0.9964
Epoch 7/8 | Train Loss: 0.0821 Acc: 0.9964 | Val Loss: 0.0595 Acc: 1.0000
Epoch 8/8 | Train Loss: 0.0704 Acc: 0.9964 | Val Loss: 0.0529 Acc: 0.9988


In [18]:
"""
Save the trained model weights and metadata for later use.
"""

def save_model(model, path, class_names, image_size):
    """
    Save the model state_dict and basic metadata to a .pth file.

    The checkpoint is a dict with the keys ``model_state_dict``,
    ``class_names``, ``image_size`` and ``arch`` (always ``"resnet18"``).

    Args:
        model: Model whose ``state_dict()`` is stored.
        path: Destination handed to ``torch.save``. When it is a string or
            path-like object, missing parent directories are created first.
        class_names: Class labels stored alongside the weights.
        image_size: Input image size stored alongside the weights.
    """
    # torch.save does not create missing folders, so create them here.
    # Anything that is not a filesystem path (e.g. an open buffer) is passed
    # through untouched.
    if isinstance(path, (str, os.PathLike)):
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "class_names": class_names,
        "image_size": image_size,
        "arch": "resnet18",
    }
    torch.save(checkpoint, path)
    print(f"Model saved to: {path}")


# Persist the trained weights together with the labels and input size.
save_model(
    model=model,
    path=MODEL_PATH,
    class_names=class_names,
    image_size=IMAGE_SIZE,
)

Model saved to: C:\Users\NeosT\OneDrive\Skrivbord\VsCode-Projects\SakuraScan\SakuraScan\app_pages\src\models\sakuramodel_resnet18.pth
