In [None]:
# Mount Google Drive inside Colab and work from the project folder, so that
# relative paths (model.py, loss.py, util.py, data/...) resolve correctly.
from google.colab import drive
import os

drive.mount("/content/gdrive")

PROJECT_DIR = "/content/gdrive/My Drive/CS 444/proj"
os.chdir(PROJECT_DIR)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

In [None]:
from model import BaseNetwork
from loss import quadratic_weighted_kappa
from util import load_data, plot

In [None]:
# Create the BaseNet and define a couple of parameters.
# Seed torch's global RNG first: it drives the model's weight initialisation
# and the DataLoader shuffle order, so runs become repeatable (GPU kernels may
# still introduce some non-determinism).
seed = 0
torch.manual_seed(seed)

num_classes = 6  # ISUP scores range from 0 to 5
basenet = BaseNetwork(num_classes)

learning_rate = 0.001
num_epochs = 20
batch_size = 24

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the model now, before the optimizer cell is built over its parameters;
# batches are moved to the same device inside the training loop.
basenet = basenet.to(device)

In [None]:
# Load training and validation data
train_data_path = 'data/train/train_data_cleaned/'
val_data_path = 'data/validation/val_data_cleaned/'

train_dataset = load_data(train_data_path)
val_dataset = load_data(val_data_path)

# Only the training set benefits from shuffling. Validation metrics do not
# depend on order, and a fixed order keeps evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Select the optimizer and loss function for this experiment.

# Optimizer: SGD with momentum and L2 weight decay.
# To compare, swap in torch.optim.Adam(basenet.parameters(), lr=learning_rate).
optimizer = torch.optim.SGD(
    basenet.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# Loss: categorical cross entropy. Alternatives to try are mean squared error
# and the quadratic_weighted_kappa imported from the project's `loss` module.
loss = nn.CrossEntropyLoss()


In [None]:
# Training Loop
# History lists hold one entry per epoch: the mean loss and the accuracy over
# every sample seen in that epoch (NOTE(review): confirm `plot` expects
# per-epoch values).
train_loss = []
val_loss = []
train_acc = []
val_acc = []

# Cut the learning rate by 10x at 60% and 80% of training. The milestones are
# derived from num_epochs so they are actually reached whatever its value.
lr_decay_epochs = (max(1, int(0.6 * num_epochs)), max(1, int(0.8 * num_epochs)))
lr_decay_factor = 0.1

# Make sure the model lives on the same device as the batches. Module.to()
# updates parameters in place, so the optimizer created earlier still tracks them.
basenet.to(device)


def _run_epoch(loader, train):
    """Run one full pass of `basenet` over `loader`.

    Args:
        loader: DataLoader yielding (image, target) batches.
        train: if True, compute gradients and step `optimizer`; otherwise
            evaluate in eval mode with gradients disabled.

    Returns:
        (mean loss per sample, accuracy), or (nan, nan) if `loader` is empty.
    """
    basenet.train(train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(train):
        for image, target in loader:
            # Assumes load_data yields (image, target) pairs with integer
            # class-index targets, as nn.CrossEntropyLoss expects — TODO confirm.
            image = image.to(device)
            target = target.to(device)

            pred = basenet(image)
            loss_value = loss(pred, target)

            if train:
                # Zero out the gradient, backprop the loss tensor, take a step
                optimizer.zero_grad()
                loss_value.backward()
                optimizer.step()

            # .item() detaches to a Python float so the autograd graph is freed.
            n = target.size(0)
            total_loss += loss_value.item() * n
            total_correct += (pred.argmax(dim=1) == target).sum().item()
            total_seen += n

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(num_epochs):
    # Recompute the rate from the base value every epoch and write it into the
    # optimizer. Mutating a Python float would have no effect on training, and
    # re-running this cell would compound the decay.
    n_decays = sum(epoch >= milestone for milestone in lr_decay_epochs)
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate * lr_decay_factor ** n_decays

    epoch_train_loss, epoch_train_acc = _run_epoch(train_loader, train=True)
    epoch_val_loss, epoch_val_acc = _run_epoch(val_loader, train=False)

    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    val_loss.append(epoch_val_loss)
    val_acc.append(epoch_val_acc)

    print(f"epoch {epoch + 1}/{num_epochs}: "
          f"train loss {epoch_train_loss:.4f} acc {epoch_train_acc:.3f} | "
          f"val loss {epoch_val_loss:.4f} acc {epoch_val_acc:.3f}")

In [None]:
# Display the results, labelled with the optimizer and loss actually used
# above. The labels are derived from the objects, so they stay correct when
# the optimizer or loss is swapped.
optimizer_name = type(optimizer).__name__
# nn.Module losses have no __name__ (use the class name); plain functions do.
loss_name = getattr(loss, '__name__', type(loss).__name__)
plot(optimizer_name, loss_name, train_loss=train_loss, val_loss=val_loss, train_acc=train_acc, val_acc=val_acc)