In [1]:
import os
from glob import glob

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from efficientnet_pytorch import EfficientNet
from natsort import natsorted
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
plt.style.use('ggplot')

# Per-channel ImageNet statistics, shared by both pipelines so they normalize identically.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random-resized 224px crop and a coin-flip horizontal mirror
# for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic resize to 256 then a 224px center crop,
# followed by the same tensor conversion and normalization.
val_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# Load the label spreadsheet for the blank/non-blank classifier.
# Labels: 1 = white, 0 = non-white. Classes are unbalanced (79271 zeros vs 195376 ones),
# so cross-validation should use StratifiedGroupKFold.
# NOTE(review): hardcoded UNC share path; consider moving it to a config cell.
train_df_src = r'\\fatherserverdw\Kevin\unstained_blank_classifier\train_df.xlsx'
train_df = pd.read_excel(io=train_df_src)

In [None]:
# Build ImageFolder datasets (one sub-folder per class) and wrap them in DataLoaders.
# TODO: these roots are placeholders; point them at the real train/val image folders.
TRAIN_DIR = 'path/to/train/folder'
VAL_DIR = 'path/to/val/folder'
BATCH_SIZE = 32

# Training data is reshuffled every epoch.
train_dataset = ImageFolder(root=TRAIN_DIR, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Validation data keeps a fixed order.
val_dataset = ImageFolder(root=VAL_DIR, transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Backbone: torchvision's EfficientNetV2-S with ImageNet weights.
# The previous call EfficientNet.from_pretrained('efficientnetv2-s') cannot work:
# efficientnet_pytorch only knows the 'efficientnet-b*' family, and its head is
# named `_fc`, not `classifier`.
# NOTE(review): per the torchvision docs, these weights were trained at 384px.
# The 224px transforms above still run but may be suboptimal; consider retuning.
model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# torchvision's EfficientNet head is Sequential(Dropout, Linear). Replace only the Linear,
# keeping the existing dropout, so the network emits one logit for binary classification.
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)

# Move the model to its device *before* building the optimizer, as the PyTorch docs recommend.
# Assigning (instead of a bare `model.to(device)` expression) avoids dumping the model repr.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Single logit + BCEWithLogitsLoss: targets must be float and shaped like the (N, 1) logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Training loop

In [None]:
def save_model(epoch, model, optimizer, criterion, path='outputs/model.pth'):
    """Save a training checkpoint dict to ``path``.

    Args:
        epoch: epoch index to record (the training cell passes its 0-based loop index).
        model: module whose ``state_dict`` is saved under ``'model_state_dict'``.
        optimizer: optimizer whose ``state_dict`` is saved under ``'optimizer_state_dict'``.
        criterion: stored as-is under the ``'loss'`` key. This is the loss *module*,
            not a loss value.
        path: destination file. Missing parent directories are created.
    """
    directory = os.path.dirname(path)
    if directory:
        # torch.save does not create missing directories; without this the first save fails.
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # NOTE(review): this pickles the criterion object, so reading the checkpoint
        # back needs torch.load(..., weights_only=False).
        'loss': criterion,
    }, path)

In [None]:
def _save_train_val_plot(train_values, val_values, metric, train_color, val_color, path):
    """Plot one train/validation metric pair against epoch number and save it to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 7))
    # Epochs are numbered from 1 to match the "Epoch [k/N]" log lines.
    ax.plot(range(1, len(train_values) + 1), train_values,
            color=train_color, linestyle='-', label=f'train {metric}')
    ax.plot(range(1, len(val_values) + 1), val_values,
            color=val_color, linestyle='-', label=f'validation {metric}')
    ax.set(title=f'{metric.capitalize()} vs. epochs', xlabel='Epochs', ylabel=metric.capitalize())
    ax.legend()
    fig.savefig(path)


def save_plots(train_accuracy_list, val_accuracy_list, train_loss_list, val_loss_list,
               output_dir='outputs'):
    """Save per-epoch accuracy and loss curves as PNG files.

    Args:
        train_accuracy_list: training accuracy, one value per epoch.
        val_accuracy_list: validation accuracy, one value per epoch.
        train_loss_list: training loss, one value per epoch.
        val_loss_list: validation loss, one value per epoch.
        output_dir: directory that receives ``accuracy.png`` and ``loss_vs_epochs.png``.
            It is created if missing.
    """
    # savefig does not create directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    _save_train_val_plot(train_accuracy_list, val_accuracy_list, 'accuracy',
                         'green', 'blue', os.path.join(output_dir, 'accuracy.png'))
    _save_train_val_plot(train_loss_list, val_loss_list, 'loss',
                         'orange', 'red', os.path.join(output_dir, 'loss_vs_epochs.png'))

In [None]:
# Training loop
num_epochs = 10
train_loss_list, val_loss_list = [], []
train_accuracy_list, val_accuracy_list = [], []


def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader``: train when ``optimizer`` is given, otherwise evaluate.

    Returns ``(mean_loss_per_sample, accuracy)``. Both are NaN for an empty loader.
    Assumes ``criterion`` uses the default ``reduction='mean'``.
    """
    training = optimizer is not None
    # Switch dropout/batch-norm behaviour to match the phase.
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            # BCEWithLogitsLoss needs float targets with the same (N, 1) shape as the logits.
            targets = labels.to(device).float().unsqueeze(1)
            logits = model(inputs)
            loss = criterion(logits, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            # loss is a batch mean; re-weight by batch size to get an exact per-sample mean.
            total_loss += loss.item() * batch_size
            # A logit > 0 is equivalent to sigmoid(logit) > 0.5.
            predicted = (logits > 0).float()
            correct += (predicted == targets).sum().item()
            total += batch_size
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total


for epoch in range(num_epochs):
    train_loss, train_accuracy = _run_epoch(model, train_dataloader, criterion, device, optimizer)
    val_loss, val_accuracy = _run_epoch(model, val_dataloader, criterion, device)

    train_loss_list.append(train_loss)
    train_accuracy_list.append(train_accuracy)
    val_loss_list.append(val_loss)
    val_accuracy_list.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

save_model(epoch, model, optimizer, criterion)
# Pass the per-epoch histories, not the last epoch's scalars.
save_plots(train_accuracy_list, val_accuracy_list, train_loss_list, val_loss_list)