In [None]:
import copy, cv2, math, numpy as np, os, time, torch, torchvision as tv

In [None]:
from data_utils import *
from helpers import *
from model_utils import *
from training_utils import *

In [None]:
def set_model_grad_requirement(model, mode):
    """Set the parameters of a model to require gradients or not

        Every parameter of the model receives the same setting; there is no per-layer
        handling, so any mode other than 'full' freezes the whole model.

        Args:
            model: The model whose parameter mutability should be set
            mode: The finetuning mode (case-insensitive). 'full' makes every parameter
                trainable; any other value freezes every parameter
    """
    is_required = mode.lower() == 'full'
    for param in model.parameters():
        param.requires_grad = is_required

In [None]:
def model_reshape(model, target_size):
    """Swaps the model's `fc` layer for a new linear layer with a different output size

        The replacement keeps the input width of the existing `fc` layer and is a freshly
        constructed `torch.nn.Linear`, so the previous layer's weights are not carried over.

        Args:
            model: The model to modify in place; it must expose a linear `fc` attribute
            target_size: The number of outputs the new `fc` layer should produce
    """
    model.fc = torch.nn.Linear(model.fc.in_features, target_size)

In [None]:
import sklearn.metrics as skm

def train_model(model, epochs, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device, output_filename):
    """Runs the training loop for a model

        Each epoch runs a training pass followed by a validation pass. The weights that
        achieved the highest validation F1 score so far are remembered. Whenever stepping
        the scheduler changes the learning rate of the first parameter group, the model is
        rolled back to those remembered weights before training continues. After the last
        epoch the best weights are loaded into the model and handed to `save_model`.

        Args:
            model: The model to be trained
            epochs: The number of epochs to train for; a value of 0 or less skips training
            criterion: The loss function to be optimised
            optimizer: The optimizer to use
            scheduler: The learning rate scheduler to use; it is stepped once per epoch with
                the validation running loss divided by dataset_sizes['val']
            dataloaders: A dictionary containing the training ('train') and validation ('val') data loaders
            dataset_sizes: The number of samples in the training ('train') and validation ('val') datasets
            device: The device on which training should be run
            output_filename: The output model filename

        Returns:
            A tuple containing the best model and the metric history
    """
    statistics = {
        'train': { 'loss': [], 'accuracy': [], 'f1': [] },
        'val': { 'loss': [], 'accuracy': [], 'f1': [] },
        'test': { 'loss': [], 'accuracy': [], 'f1': [] }
    }
    best_stat = 0.0
    best_model = copy.deepcopy(model.state_dict())

    for epoch in range(1, epochs + 1):
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            # One entry per batch; the helpers receive the per-batch tensors as lists.
            predictions = []
            truths = []
            running_loss = 0.0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training pass.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                predictions.append(preds)
                truths.append(labels.data)

                # Weight the batch loss by the batch size so the epoch loss is per-sample.
                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / dataset_sizes[phase]
            update_statistics(statistics[phase], epoch_loss, predictions, truths)
            print_statistics(statistics, phase, epoch)

            if is_training:
                continue

            # Validation only: track the best weights by F1 score.
            if statistics['val']['f1'][-1] > best_stat:
                best_stat = statistics['val']['f1'][-1]
                best_model = copy.deepcopy(model.state_dict())

            # If the scheduler just changed the learning rate, restart from the best weights.
            last_lr = optimizer.param_groups[0]['lr']
            scheduler.step(epoch_loss)
            if optimizer.param_groups[0]['lr'] != last_lr:
                model.load_state_dict(best_model)

        print()

    print('Best validation F1 score: {:.4f}'.format(best_stat))

    model.load_state_dict(best_model)
    save_model(model, f"{output_filename}")

    return model, statistics

In [None]:
def train(baseline_filename, epochs, weights, output_filename, pretrained=True, finetune='full', in_classes=2, out_classes=2):
    """Trains a ResNet-18 classifier starting from a saved baseline checkpoint.

        The training and validation data are read from the module-level `dataloaders` and
        `dataset_sizes` variables, which must be defined before this function is called.

        Args:
            baseline_filename: The name of the file containing the baseline weights to load
            epochs: The number of epochs to train for
            weights: The class weights to use for loss function normalisation
            output_filename: The output model filename
            pretrained: Passed to torchvision when the network is built. If False, every layer
                is trained regardless of `finetune` and `reinit_conv_layers` is applied to the model
            finetune: The type of finetuning to apply if `pretrained` is True. 'full' trains every
                layer; any other value (e.g. 'classifier', 'partial') freezes the existing layers
                so that only the newly created output layer is trained
            in_classes: The number of outputs of the classifier stored in `baseline_filename`
            out_classes: The number of outputs the trained classifier should have

        Returns:
            The best trained model based on f1 score
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = tv.models.resnet18(pretrained=pretrained)
    mode = 'full' if not pretrained else finetune.lower()
    set_model_grad_requirement(model, mode)

    # The head must match the checkpoint's shape before its state dict can be loaded.
    # Layers created by model_reshape are new, so they stay trainable even when the
    # rest of the network was frozen above.
    model_reshape(model, in_classes)
    model.load_state_dict(torch.load(baseline_filename, map_location=device))

    # Only replace the head when the class count actually changes; replacing it
    # unconditionally would throw away the classifier just loaded from the checkpoint.
    if out_classes != in_classes:
        model_reshape(model, out_classes)
    print_parameters(model)
    model = model.to(device)

    if not pretrained:
        reinit_conv_layers(model)

    criterion = torch.nn.CrossEntropyLoss(torch.as_tensor(weights, device=device, dtype=torch.float))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=1e-4)

    t0 = time.perf_counter()
    # The metric history is not used here; only the best model is returned.
    model, _ = train_model(model, epochs, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
                           output_filename)
    t1 = time.perf_counter()
    print(f"Network {output_filename} trained in {t1 - t0} s")

    return model

In [None]:
# Experiment configuration
num_classes = 4
batch_size = 128

# Seed through the project's set_seed helper before any data is prepared
set_seed(42)

# Build the datasets from the "images" directory and wrap them in loaders.
# `dataloaders` and `dataset_sizes` are read as globals by train().
transform = get_resnet_transforms()
training_set, validation_set = make_datasets("images", transform, valid_size=10000, train_size=8000)
dataloaders, dataset_sizes, weights = prepare_dataloaders(
    training_set,
    validation_set,
    batch_size,
    num_classes,
)

In [None]:
# First stage: start from the 2-class baseline checkpoint and train with
# finetune='classifier' towards `num_classes` outputs.
model = train(
    baseline_filename="model_baseline.pt",
    epochs=10,
    weights=weights,
    output_filename="model_transfer",
    pretrained=True,
    finetune='classifier',
    in_classes=2,
    out_classes=num_classes,
)

In [None]:
# Second stage: load "model_transfer.pt" and train every layer (finetune='full').
# NOTE(review): presumably this is the file written by the previous stage, with
# save_model adding the ".pt" extension; confirm against save_model.
model = train(
    baseline_filename="model_transfer.pt",
    epochs=5,
    weights=weights,
    output_filename="model_final",
    pretrained=True,
    finetune='full',
    in_classes=num_classes,
    out_classes=num_classes,
)