IMPORTS

In [8]:
#standard Python libraries
import os
from collections import OrderedDict

#PyTorch and torchvision
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

#Weights and Biases - included for experiment tracking
#import wandb

#custom libraries included with project
from utils.RunBuilder import RunBuilder
from utils.RunManager import RunManager
from utils.ModelFactory import ModelFactory
from utils.DatasetFactory import DatasetFactory

SETTING PARAMETERS

In [9]:
#Dataset Parameters
#split size handed to DatasetFactory.get_dataset when the datasets are built
test_dataset_split_size = 0.2

#image input dimensions - different architectures expect different input sizes 
#generally ok to keep this at 224 x 224 unless using Inception (299 x 299) or AlexNet (227 x 227)
image_input_width = 224
image_input_height = 224

#version label of the training data - also used as the prefix of every run name
dataset_version = "V1"

#where to look for the training data
#derived from dataset_version so the folder and the label in the run names cannot drift apart
data_training_directory = f"training_data/{dataset_version}"

HYPER-PARAMETER DICTIONARY

In [10]:
#Every key maps to a LIST of candidate values; RunBuilder.get_runs(params) turns
#this dictionary into the individual runs executed by the training cell.
params = OrderedDict([
    #LEARNING RATE - passed to the optimizer as `lr` (size of each weight update)
    ("lr", [0.001, 0.0001]),
    #BATCH SIZE - how many images to pass at once to the model
    #WARNING: setting batch size too large may cause out of memory issues
    ("batch_size", [64]),
    #SHUFFLE - whether the training DataLoader shuffles the images
    ("shuffle", [True]),
    #NUM WORKERS - how many processes to use for data loading - can speed up learning if you have resources
    ("num_workers", [0]),
    #EPOCHS - how many epochs to run the learning for
    ("epochs", [5]),

    #TRAINSET - which entries of the `trainsets` dictionary to train on
    ("trainset", ['dataset_squarepad', 'dataset_fullpad']),
    #MODEL - architecture names handed to ModelFactory.get_network
    # For a full list of possible models check out https://pytorch.org/vision/0.21/models.html
    ("model", ['resnet50']),
])

CHECK FOR GPU SUPPORT

In [11]:
#check for GPU support 
#WARNING - need to check correct version of PyTorch is installed for GPU support
#NOTE: is_available has to be *called* - the bare function object is always truthy,
#which would select 'cuda' even on a machine without a usable GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


DEFINE DATASET & CLASSLIST

In [12]:
#index the training directory as an ImageFolder dataset
dataset = datasets.ImageFolder(data_training_directory)

#Class information from the dataset: the class list and how many classes there are
class_names = dataset.classes
number_of_classes = len(class_names)

print("Number of classes: " ,number_of_classes)
print(class_names)


Number of classes:  6
['Asterionellopsis', 'Chaetoceros', 'Dinophysis', 'Octactis', 'Pseudo-nitzschia', 'Tripos']


LOAD DATASET

In [13]:

#define active datasets
#every variant is requested from DatasetFactory with the same directory, split size and image size
#image only datasets
dataset_variants = [
    'dataset',
    'dataset_normalised',
    'dataset_squarepad',
    'dataset_fullpad',
    'dataset_reflectpad',
    'dataset_squarepad_normalised',
    'dataset_fullpad_normalised',
    #NOTE(review): 'dataset_reflectpad' used to be listed a second time here, so it was built
    #twice and the first copy thrown away. 'dataset_reflectpad_normalised' was probably
    #intended - add it once DatasetFactory is confirmed to support that name.
]

#maps variant name -> whatever DatasetFactory.get_dataset returns for it
#(the training cell indexes each entry with 'train' and 'test')
trainsets = {
    variant: DatasetFactory.get_dataset(
        variant,
        data_training_directory,
        test_dataset_split_size,
        image_input_width,
        image_input_height,
    )
    for variant in dataset_variants
}

MODEL TRAINING

In [14]:
rm = RunManager(number_of_classes)

#make sure the checkpoint folder exists before any training starts, so torch.save
#cannot fail on a missing directory after an epoch has already been spent
model_directory = 'models'
os.makedirs(model_directory, exist_ok=True)

#one full training + evaluation cycle per hyper-parameter combination
for run in RunBuilder.get_runs(params):
 
    #name identifying this run - used for the checkpoints and passed to the RunManager
    run_name = f'{dataset_version}-{run.model}_{run.trainset}_b{run.batch_size}_lr{run.lr}_e{run.epochs}'

    #setup all the WandB variables
    #NOTE: project_name is not defined in this notebook - define it before re-enabling
    # wandb.init(
    #     project=project_name,
    #     config={
    #         "learning_rate": run.lr,
    #         "batch_size": run.batch_size,
    #         "architecture": run.model,
    #         "dataset": run.trainset,
    #         "epochs": run.epochs,
    #         "run_name" : run_name
    #     }) 
    # wandb.run.name = run_name

    #passing batch of images to model
    loader = DataLoader(
        trainsets[run.trainset]['train'], 
        batch_size=run.batch_size, 
        shuffle=run.shuffle, 
        num_workers=run.num_workers,
    )

    #this is used at the end of every epoch to give a score each epoch
    validation_loader = DataLoader(
        trainsets[run.trainset]['test'],
        batch_size=run.batch_size, 
        num_workers=run.num_workers,
    )

    #dataset / loader sizes do not change between epochs, so read them once per run
    size = len(loader.dataset)                      #number of training images
    test_size = len(validation_loader.dataset)      #number of held-out images
    num_batches = len(validation_loader)            #number of held-out batches

    #load a model from the model factory
    model = ModelFactory.get_network(run.model, number_of_classes).to(device)

    #define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=run.lr)

    #variables to keep track of the best performing epoch (based on F1)
    best_epoch_F1 = 0
    best_epoch = 1

    #initialise the run in the RunManager
    rm.begin_run(run, model, loader, device, run_name)

    for epoch in range(run.epochs):
        
        #initialise the current epoch in the RunManager
        rm.begin_epoch()
        
        print(f'\nEPOCH {epoch+1}')
        print("\nTRAINING")

        #variables to track loss and correct predictions over the epoch
        running_loss = 0
        correct = 0

        #set the model to train mode
        model.train()

        for batch, (images,labels) in enumerate(loader):
            images = images.to(device)
            labels = labels.to(device)
            
            if run.model == 'inception_v3':
                #the training-mode call yields two values for inception_v3; only the first is used
                preds, aux_preds = model(images)
            else:
                preds = model(images)

            loss = loss_fn(preds, labels) # calculate loss

            correct += preds.argmax(dim=1).eq(labels).sum().item()
            
            loss.backward() # calculate gradients
            optimizer.step() # update weights
            optimizer.zero_grad() # zero out gradients

            #progress line for the first batch and every 100th batch after it
            if batch % 100 == 0:
                f_loss, current = loss.item(), (batch + 1) * len(images)
                print(f"loss: {f_loss:>7f}  [{current:>5d}/{size:>5d}]")

            running_loss += loss.item()

        #average loss per batch and fraction of training images predicted correctly
        running_loss /= len(loader)
        accuracy = correct / size

        print(f"Correct: {correct} / {size} ")        
        print(f"Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {running_loss:>8f} \n")

        rm.epoch_loss = running_loss
        rm.epoch_accuracy = accuracy

        print("TESTING")

        test_loss = 0
        test_correct = 0

        #evaluation mode, no gradient tracking
        model.eval()
        with torch.no_grad():
            for test_images, test_labels in validation_loader:

                test_images = test_images.to(device)
                test_labels = test_labels.to(device)
                  
                test_preds = model(test_images)
                test_loss += loss_fn(test_preds, test_labels).item() # calculate loss
                #kept as an integer count so it prints like the training count (no trailing ".0")
                test_correct += (test_preds.argmax(1) == test_labels).sum().item()
                
                #hand the raw predictions to the RunManager for its own metrics
                rm.update_metrics(test_preds, test_labels)
        
        test_loss /= num_batches
        test_accuracy = test_correct / test_size
        
        print(f"Correct: {test_correct} / {test_size} ")      
        print(f"Accuracy: {(100*test_accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
                    
        #pass test loss and accuracy to the run manager 
        rm.epoch_test_loss = test_loss
        rm.epoch_test_accuracy =  test_accuracy
        rm.calculate_metrics()

        print("FINISHING EPOCH")
        print("Current Epoch F1: ", rm.epoch_weighted_F1)
        #printed before the update below, i.e. the best score of the *previous* epochs
        print("Best Epoch F1: ", best_epoch_F1)
        
        #only save the best models according to F1 score
        if rm.epoch_weighted_F1 > best_epoch_F1:
            print("Saving new best model...")
            torch.save(model.state_dict(), f'{model_directory}/BEST_{run_name}_{epoch+1}')
            best_epoch_F1 = rm.epoch_weighted_F1 
            best_epoch = epoch+1

        #log key metrics to weights and biases
        # wandb.log({"acc": accuracy, "loss": running_loss, "val_precision": rm.epoch_precision, "val_recall": rm.epoch_recall,
        #             "val_acc": rm.epoch_test_accuracy,"val_loss": test_loss, "F1 Score": rm.epoch_F1,
        #             "F1 weighted Score": rm.epoch_weighted_F1, "auprc": rm.epoch_auprc, "auroc": rm.epoch_auroc })
        
        rm.end_epoch()

    print("\nBest Epoch: " , best_epoch , " \n Best F1 Score: " , best_epoch_F1)
    torch.cuda.empty_cache()
    rm.save(run_name)
    #wandb.finish()

    #close the run opened by begin_run above; this used to sit outside the loop, so only
    #the final run was ever ended
    #NOTE(review): assumes RunManager.end_run is meant to be called once per run - confirm
    rm.end_run()


EPOCH 1

TRAINING
loss: 1.791747  [   64/ 4800]
Correct: 4572 / 4800 
Accuracy: 95.2%, Avg loss: 0.149822 

TESTING
Correct: 1076.0 / 1200 
Accuracy: 89.7%, Avg loss: 0.333117 

FINISHING EPOCH
Current Epoch F1:  0.8945530652999878
Best Epoch F1:  0
Saving new best model...

EPOCH 2

TRAINING
loss: 0.142402  [   64/ 4800]
Correct: 4677 / 4800 
Accuracy: 97.4%, Avg loss: 0.080727 

TESTING
Correct: 1173.0 / 1200 
Accuracy: 97.8%, Avg loss: 0.065745 

FINISHING EPOCH
Current Epoch F1:  0.9773824214935303
Best Epoch F1:  0.8945530652999878
Saving new best model...

EPOCH 3

TRAINING
loss: 0.045774  [   64/ 4800]
Correct: 4782 / 4800 
Accuracy: 99.6%, Avg loss: 0.013503 

TESTING
Correct: 1186.0 / 1200 
Accuracy: 98.8%, Avg loss: 0.045454 

FINISHING EPOCH
Current Epoch F1:  0.9882955551147461
Best Epoch F1:  0.9773824214935303
Saving new best model...

EPOCH 4

TRAINING
loss: 0.000489  [   64/ 4800]
Correct: 4736 / 4800 
Accuracy: 98.7%, Avg loss: 0.042640 

TESTING
Correct: 194.0 / 1200