## Install required libraries

In [None]:
!pip install opendatasets timm lightning albumentations --upgrade --quiet

## Download the dataset from Kaggle

In [None]:
import opendatasets as od

# Kaggle page of the chest X-ray pneumonia dataset to download
dataset = ('https://www.kaggle.com/datasets/'
           'paultimothymooney/chest-xray-pneumonia/data')
od.download(dataset)

## Import the necessary libraries

In [None]:
# custom modules
%load_ext autoreload
%autoreload 2
import utilities
#import model_functions
import model_factory

#lightning modules and callbacks
import lightning_data
import lightning_model
import train_info
import learning_curves
import confusion_matrix

import os

# timm models
import timm

# torch modules (temporarily)
import torch.nn as nn
import torch
# pytorch lightning (for checkpointing callbacks)
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger

## Mount Google Drive to store checkpoints

In [None]:
# Checkpoints are written to Google Drive (so they survive a runtime disconnect),
# hence the Drive has to be mounted before training starts.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## Set up the model configuration

In [None]:
# Default training configuration; the run's checkpoint directory is derived from it below.
config = {
    'model_name': 'densenet121', # name of the pretrained model
    'classifier_name': 'linear', # name of the classifier head (e.g. linear/nonlinear)
    'classifier_type': None, # leave as None: it is filled in from 'classifier_name' below
    'layers': 'all', # layers to train (e.g. first or second, counting from the last layer, or all)
    'augmentation': 'strong', # augmentation type (e.g. normal or strong)
    'classes_weight': None, # weights for each class
    'batch_size': 64,
    'val_split': 0.1,
    'n_epochs': 20,
    'optimizer': 'SGD',
    'scheduler': 'CosineAnnealingLR10', # leave empty to not use any scheduling
    'ensemble': False,
    'image_size': None,
    'mean': None,
    'std': None
    }

# 'linear' selects the linear head; any other classifier_name uses the simple non-linear head
config['classifier_type'] = model_factory.get_linear_classifer if config['classifier_name'] == 'linear' else model_factory.get_simple_non_linear_classifier


# Path to save the checkpoints of this model: <models root>/<model_name>/<run description>.
# The run description joins the run settings with single spaces. str() keeps this from
# raising a TypeError when one of them (e.g. 'scheduler') is set to None instead of ''.
checkpoint_path = os.path.join(
    '/content/drive/MyDrive/models/',
    config['model_name'],
    " ".join(str(config[key]) for key in ('classifier_name', 'layers', 'augmentation', 'optimizer', 'scheduler')),
)
checkpoint_path

## Set up the PyTorch Lightning modules and callbacks

In [None]:
# Seed Python, NumPy and torch RNGs (including dataloader workers) before building the
# data module and model, so the random parts of data preparation and training are repeatable.
SEED = 42
pl.seed_everything(SEED, workers=True)

# data module (dataloaders), configured from `config`
pneumonia_data = lightning_data.PneumoniaDataModule(config)
# lightning module (model), configured from `config`
pneumonia_model = lightning_model.PneumoniaModel(config)

# callback to print training info
training_info_callback = train_info.PrintTrainingInfoCallback()
# callback to show learning curves after training is done
learning_curves_callback = learning_curves.PlotLearningCurvesCallback()
# callback to show confusion matrix after test is done
conf_matrix_callback = confusion_matrix.PlotConfusionMatrixCallback()
# callback to track training times
timer = pl.callbacks.Timer()
# store the logged metrics in CSV format under checkpoint_path/logs
csv_logger = CSVLogger(save_dir = checkpoint_path, name="logs")

# callback to save the best model (highest val_acc) found during training
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_path, # checkpoints are saved to Google Drive, in case the runtime disconnects
    monitor="val_acc",
    mode='max',
    save_top_k=1,
    verbose=True,
    save_last = True, # also save the last model (to resume in case the runtime disconnects)
    )

callbacks = [training_info_callback,learning_curves_callback,conf_matrix_callback,checkpoint_callback,timer]

# create the trainer
trainer = pl.Trainer(
        max_epochs=config['n_epochs'],
        log_every_n_steps=1,
        callbacks = callbacks,
        logger = csv_logger
    )

## Training

In [None]:
# Set RESUME = True to continue from the last checkpoint saved on Google Drive
# (e.g. after something went wrong); with False, training starts from scratch.
RESUME = False
resume_ckpt = os.path.join(checkpoint_path, 'last.ckpt') if RESUME else None
trainer.fit(pneumonia_model, datamodule=pneumonia_data, ckpt_path=resume_ckpt)

In [None]:
# Plot the learning curves up to the last concluded epoch.
# Disabled by default; uncomment to use (e.g. after an interrupted run, when the
# learning-curves callback never got to plot at the end of training).
# NOTE(review): presumably reads the logs/checkpoints stored under checkpoint_path — confirm in utilities.
#utilities.plot_results(checkpoint_path)

## Test

In [None]:
# Set to a checkpoint file name inside checkpoint_path (e.g. "epoch=6-step=259.ckpt") to test
# that checkpoint instead of the best one tracked by checkpoint_callback in this session.
BEST_CKPT_OVERRIDE = None

best_model_path = (os.path.join(checkpoint_path, BEST_CKPT_OVERRIDE)
                   if BEST_CKPT_OVERRIDE else checkpoint_callback.best_model_path)
# ModelCheckpoint.best_model_path stays "" until the callback saves a checkpoint in the current
# kernel session (e.g. after a runtime restart); fail clearly instead of with an obscure load error.
if not best_model_path:
    raise RuntimeError(
        "No best checkpoint recorded in this session; set BEST_CKPT_OVERRIDE to a checkpoint "
        f"file in {checkpoint_path!r}."
    )

# Restore the model from the checkpoint; extra kwargs to load_from_checkpoint are passed to
# the model's __init__, so config is supplied as its 'h' argument.
best_model = lightning_model.PneumoniaModel.load_from_checkpoint(best_model_path, h=config)
pneumonia_model = best_model  # the next cell reads the test metrics from pneumonia_model

# test the restored model (trainer.test returns the logged test metrics)
trainer.test(pneumonia_model, datamodule=pneumonia_data)

In [None]:
# Report the test metrics read from pneumonia_model (presumably set by its test hooks) as percentages.
for label, attr in (("Test accuracy", "test_acc"),
                    ("Precision", "test_precision"),
                    ("Recall", "test_recall"),
                    ("F1-score", "test_f1"),
                    ("AUC", "test_auc")):
    print(f"{label}: {getattr(pneumonia_model, attr) * 100:.2f}%")

## Retrieve training info (e.g. training time) from the last checkpoint

In [None]:
# Load the last checkpoint saved on Google Drive to inspect the training info stored in it
# (callback states, such as the Timer's, are found under ckpt['callbacks']).
# map_location='cpu' lets this run on a CPU-only runtime even if the checkpoint holds CUDA tensors.
# weights_only=False: besides tensors, the checkpoint holds pickled state (callback states,
# hyperparameters) that torch>=2.6's default weights_only=True may refuse to load.
# Only do this for trusted files, such as this run's own checkpoint.
ckpt = torch.load(os.path.join(checkpoint_path, 'last.ckpt'), map_location='cpu', weights_only=False)
# e.g. ckpt['callbacks']['PlotLearningCurvesCallback'] may hold that callback's saved state

In [None]:
# Elapsed seconds recorded by the Timer callback for the 'train' and 'validate' stages
timer_state = ckpt['callbacks']['Timer']
timer_state['time_elapsed']