In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision.models import resnet18
from torchvision.datasets import PCAM
import torchvision.transforms as transforms

from torcheval.metrics import MulticlassAUROC, MulticlassAccuracy

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm as _tqdm
import os


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
## Dataset and data loaders

transform = transforms.Compose([
    transforms.PILToTensor()
])

train_dataset = PCAM(root='data', split='test', download=True, transform=transform)
val_dataset = PCAM(root='data', split='val', download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)


In [4]:
## Model

model = resnet18(pretrained=True)
model.to(device)

# Freeze all layers except last
for param in model.parameters():
    param.requires_grad = False

# Create classification layer    
num_classes = 2
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

## Optimizer
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)

## Loss Function
loss_fun = torch.nn.CrossEntropyLoss()




In [13]:
def uniquify(path):
    """
    Creates unique path name by appending number if given path already exists
    """
    
    filename, extension = os.path.splitext(path)
    counter = 1

    while os.path.exists(path):
        path = filename + "_" + str(counter) + extension
        counter += 1

    return path


def tqdm(*args, **kwargs):
    """
    Wrapper for loop progress bar
    """
    
    return _tqdm(*args, **kwargs, mininterval=1)  # Safety, do not overflow buffer


def train(model, train_loader, val_loader, loss_fun, optimizer, num_epochs, num_classes, device, save_ckpt_path=None):
    """
    Trains model
    """
    
    model.to(device)
    
    # Create metric monitors
    auc = MulticlassAUROC(num_classes=num_classes)
    accuracy = MulticlassAccuracy()
    
    for epoch in range(num_epochs):
        
        # Set the model to train mode
        model.train()

        # Initialize the running loss and metrics
        curr_loss = 0.0
        auc.reset()
        accuracy.reset()
        
        ## Train
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}, Training'):
            
            # Move the inputs and labels to the device
            inputs = inputs.float().to(device)
            labels = labels.to(device)

            # Zero the optimizer gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = loss_fun(outputs, labels)

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()

            # Update the running loss and metrics
            curr_loss += loss.item() * inputs.size(0)
            auc.update(outputs, labels)
            accuracy.update(outputs, labels)

        # Calculate the train loss and metrics
        train_loss = curr_loss / len(train_dataset)
        train_acc = accuracy.compute()
        train_auc = auc.compute()

        # Set the model to evaluation mode
        model.eval()

        # Initialize the running loss
        curr_loss = 0.0
        
        # Initialize the metrics
        auc.reset()
        accuracy.reset()

        ## Validate
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}, Validation'):
                
                # Move the inputs and labels to the device
                inputs = inputs.float().to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = loss_fun(outputs, labels)

                # Update the running loss and metrics
                curr_loss += loss.item() * inputs.size(0)
                auc.update(outputs, labels)
                accuracy.update(outputs, labels)

        # Calculate the validation loss, accuracy and AUC
        val_loss = curr_loss / len(val_dataset)
        val_acc = accuracy.compute()
        val_auc = auc.compute()

        # Print the epoch results
        print('Train Loss: {:.4f}, Train Acc: {:.4f}, Train AUC: {:.4f}, \n Val Loss: {:.4f}, Val Acc: {:.4f}, Val AUC: {:.4f}\n'
              .format(train_loss, train_acc, train_auc, val_loss, val_acc, val_auc))
        
        ## Save model checkpoint
        if save_ckpt_path is None:
            save_ckpt_path = os.path.join('models',f'{model.__class__.__name__}.pt')
            if not os.path.exists('models'):  # If folder 'models' doesn't exist, create it
                os.makedirs('models')
        save_ckpt_path = uniquify(save_ckpt_path)  # Create unique path name by appending number if given path already exists
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'train_auc': train_auc,
            'val_loss': val_loss,
            'val_acc':val_acc,
            'val_auc': val_auc,
            }, save_ckpt_path)
        

In [14]:
train(model, train_loader, val_loader, loss_fun, optimizer, num_epochs=1, num_classes = 2, device=device)

Epoch 1/1, Training: 100%|██████████| 1024/1024 [00:29<00:00, 35.07it/s]
Epoch 1/1, Validation: 100%|██████████| 1024/1024 [00:29<00:00, 35.16it/s]

Train Loss: 0.8931, Train Acc: 0.7650, Train AUC: 0.8324, 
 Val Loss: 1.8996, Val Acc: 0.6412, Val AUC: 0.8051



