In [3]:
# notebooks/train_model.ipynb

# --- Install required libraries if needed ---
# !pip install torch torchvision

# --- Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.optim import lr_scheduler
import time
import copy
import os
import sys
sys.path.append("../scripts")

from train_utils import EarlyStopping
from data_loader import load_tile_dataset

# --- Data ---
# Everything a reader might want to tune for loading lives in these three names.
processed_tiles_folder = "../data/processed_tiles/"
batch_size, image_size = 32, 224

# load_tile_dataset comes from ../scripts/data_loader.py; it hands back the two
# loaders plus the class names that size the classifier head below.
# NOTE(review): val_split=0.2 presumably holds out 20% of tiles for validation —
# confirm against data_loader.py.
train_loader, val_loader, class_names = load_tile_dataset(
    data_dir=processed_tiles_folder,
    batch_size=batch_size,
    shuffle=True,
    image_size=image_size,
    val_split=0.2,
)

# --- Device ---
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- Model ---
# Start from a ResNet-50 with the default pretrained weights and replace the final
# fully connected layer so it emits one logit per entry in class_names.
model = models.resnet50(weights="DEFAULT")
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=len(class_names))
model = model.to(device)

# --- Loss, Optimizer, LR Schedule ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# StepLR multiplies the learning rate by gamma every step_size scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# --- Early Stopping ---
# EarlyStopping is defined in ../scripts/train_utils.py.
early_stopping = EarlyStopping(patience=5, verbose=True)

# --- Define train_model() function ---
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return ``(mean_loss, accuracy)``.

    When `optimizer` is given, the model is switched to training mode and its
    parameters are updated after every batch. Otherwise the model is switched
    to eval mode and gradient tracking is disabled.

    Both metrics are averaged over the number of samples the loader actually
    yielded, so they stay correct when the loader uses a subset sampler or
    drops the last batch.

    Raises:
        ValueError: if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # The batch loss is weighted by batch size so the final division
            # gives a per-sample mean (assumes criterion averages over the batch).
            n_batch = inputs.size(0)
            total_loss += loss.item() * n_batch
            total_correct += (preds == labels).sum().item()
            total_seen += n_batch

    if total_seen == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute epoch metrics.")

    return total_loss / total_seen, total_correct / total_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs=10):
    """Train `model`, validate after every epoch, and return it with the best weights loaded.

    Args:
        model: network to train; it is updated in place and also returned.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: loss callable applied as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the training pass.
        early_stopping: callable invoked with the epoch's validation loss; training
            stops as soon as its ``early_stop`` attribute is truthy.
        num_epochs: maximum number of epochs to run.

    Returns:
        The same `model` object, with the best checkpoint's weights restored.

    Raises:
        ValueError: if either loader yields no samples.
    """
    since = time.time()

    # Batches are moved to wherever the model already lives, so the function
    # does not depend on a notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print('-' * 20)

        # --- Training Phase ---
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        scheduler.step()

        # --- Validation Phase ---
        val_epoch_loss, val_epoch_acc = _run_epoch(model, val_loader, criterion, device)

        print(f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        print(f"Val   Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}")

        # Keep the checkpoint with the highest validation accuracy and break ties
        # on lower validation loss. Without the tie-break, an early epoch that
        # happens to reach the top accuracy would shadow every later epoch that
        # matches it with a much lower loss.
        is_better = val_epoch_acc > best_acc or (
            val_epoch_acc == best_acc and val_epoch_loss < best_loss
        )
        if is_better:
            best_acc = val_epoch_acc
            best_loss = val_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        # --- Early Stopping Check ---
        early_stopping(val_epoch_loss)
        if early_stopping.early_stop:
            print("\nEarly stopping triggered. Stopping training early!")
            break

    time_elapsed = time.time() - since
    print(f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Val Acc: {best_acc:.4f}")

    model.load_state_dict(best_model_wts)
    return model

# --- Actually Train the Model ---
# Runs for at most 30 epochs; early_stopping may end the run sooner.
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    early_stopping,
    num_epochs=30
)

# --- Save the Trained Model ---
# torch.save does not create missing parent folders, so make sure the target
# directory exists first; otherwise a long training run could end in an error
# with nothing written to disk.
model_save_path = "../data/models/trained_resnet50.pth"
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(trained_model.state_dict(), model_save_path)
print("✅ Model saved!")


Using device: cpu
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\frank/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:08<00:00, 11.5MB/s]



Epoch 1/30
--------------------
Train Loss: 0.9653 Acc: 0.6345
Val   Loss: 0.8519 Acc: 1.0000

Epoch 2/30
--------------------
Train Loss: 0.5942 Acc: 0.9848
Val   Loss: 0.5498 Acc: 0.9796

Epoch 3/30
--------------------
Train Loss: 0.2664 Acc: 0.9949
Val   Loss: 0.3079 Acc: 1.0000

Epoch 4/30
--------------------
Train Loss: 0.1578 Acc: 0.9949
Val   Loss: 0.1324 Acc: 1.0000

Epoch 5/30
--------------------
Train Loss: 0.0873 Acc: 0.9949
Val   Loss: 0.0485 Acc: 1.0000

Epoch 6/30
--------------------
Train Loss: 0.0509 Acc: 0.9949
Val   Loss: 0.0188 Acc: 1.0000

Epoch 7/30
--------------------
Train Loss: 0.0210 Acc: 1.0000
Val   Loss: 0.0149 Acc: 1.0000

Epoch 8/30
--------------------
Train Loss: 0.0256 Acc: 1.0000
Val   Loss: 0.0171 Acc: 1.0000
EarlyStopping counter: 1 out of 5

Epoch 9/30
--------------------
Train Loss: 0.0203 Acc: 1.0000
Val   Loss: 0.0189 Acc: 1.0000
EarlyStopping counter: 2 out of 5

Epoch 10/30
--------------------
Train Loss: 0.0398 Acc: 1.0000
Val   Loss: 