## Training for ResNet18

*if using google drive, run following two cells*

In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')
project_dir = '/content/drive/MyDrive/master_courses/BIDH5001 Capstone/Project/\
deep-classificataon'
os.chdir(project_dir)

In [None]:
%pip install pydicom

imports, initiating dataloaders

In [1]:
# imports
import datetime
import time
import tempfile
import warnings
import torch
import config
from dataio.dataloader import create_dataloader
from networks.resnet_classifier import resnet18_classifier
from training.evaluation import AccuracyEvaluator, LossEvaluator
from training.utility.progress_bar import ProgressBar
from training.utility.early_stopper import ValLoss as EarlyStopper
from training.utility.signal_control import SignalFileControl
import os
import re

# whenever possible, use cuda instead of cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    warnings.warn("torch is using CPU, as cuda is unavailable. this is inefficient.")
torch.set_default_device(device)

# paths to save and load state dictionary
load_sd = False
save_sd = False

bin_path = os.path.abspath("bin")
assert os.path.isdir(bin_path)

# if load_sd set True, load from state dict
sd_load_path = None
starting_epoch = 0
sd_file_pattern = r'resnet18_state_[0-9]{2}-[0-9]{2}-ep_[0-9]{2}\.pkl'
sd_epoch_pattern = r'resnet18_state_[0-9]{2}-[0-9]{2}-ep_([0-9]{2})\.pkl'
if load_sd:
    sd_files = [
        filename for filename in os.listdir(bin_path)
        if re.match(sd_file_pattern, filename)]
    sd_files.sort(reverse=True)
    if len(sd_files) > 0:
        sd_filename = sd_files[0]
        sd_load_path = os.path.join(bin_path, sd_filename) \
            if sd_filename is not None else None
        match = re.search(sd_epoch_pattern, sd_filename)
        starting_epoch = int(match.group(1))
        print(f'- loaded state from: {sd_filename}')
print(f"- starting from epoch: {starting_epoch + 1}")

# tempfile to save the model every time it reaches the best validation loss
state_dict_backup_path = tempfile.mktemp(prefix='state-dict_', suffix='.pth')

- starting from epoch: 1


preparing the training:
1. read dicoms and labels from configuration and tracking table, initiate dataloaders
2. set up criterions and optimizers, training parameters

*the data*

In [2]:
# reading from configurations
dicoms = config.dicom_paths # tracking_table['dicom_path'].to_list()
labels = config.labels

# initiate the dataloaders
dataloader_dict = create_dataloader(
    dicoms, labels,
    batch_size = 8,
    validation_size = config.validation_size,
    test_size = config.test_size,
    img_size = (224, 224),
    use_3_channels = True,
)
training_dataloader, validation_dataloader, test_dataloader = \
    (dataloader_dict.get(key) for key in ('training_dataloader', 'validation_dataloader', 'test_dataloader'))

*the model*

In [3]:
# some settings, move to config in future
learning_rate = 0.0001 # 0.00005, originally without lr scheduler
n_epoches = 7
min_epoches = 3 # minimum epoches before early stopping allowed

# the classifier model
model = resnet18_classifier.to(device=device)
# load state dict into 
if load_sd and sd_load_path:
    state_dict = torch.load(sd_load_path)
    model.load_state_dict(state_dict)

# optimizers
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr=learning_rate,
#     steps_per_epoch=len(training_dataloader), epochs=n_epoches,
#     )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.25, patience=1,
    # threshold=0.05, threshold_mode='rel',
    min_lr = learning_rate / 32,
    cooldown=7,
)

# accuracy and loss evaluators
training_loss_evaluator = LossEvaluator(loss_function)
validation_loss_evaluator = LossEvaluator(loss_function)
validation_accuracy_evaluator = AccuracyEvaluator()

best_epoch, best_loss = None, None

In [None]:
# to continue training for n_epoches
n_epoches = 10

training for resnet18 image quality classifier

In [4]:
# evaluators for training/validation loss and accuracy
# training_loss = LossEvaluator(loss_function)
# validation_loss = LossEvaluator(loss_function)
# validation_accuracy = AccuracyEvaluator()
earlystopper = EarlyStopper(tolerance=4, target=0.25)
signalstopper = SignalFileControl()

# training epoches loop
for epoch in range(n_epoches):
    current_epoch = epoch + 1 + starting_epoch
    print(f'epoch # {current_epoch}')
    progress = ProgressBar(len(training_dataloader) + len(validation_dataloader))
    epoch_start_time = time.time()
    # reset pred and actual labels after each epoch
    validation_accuracy_evaluator.reset()
    training_loss_evaluator.reset()
    validation_loss_evaluator.reset()
    
    # training network
    for images, labels in training_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        labels_pred = model(images)
        # train loss
        training_loss = loss_function(labels_pred, labels)
        # backpropagation
        optimizer.zero_grad()
        training_loss.backward()
        optimizer.step()
        progress.step()
        training_loss_evaluator.append_loss(training_loss.item())
    
    # validating network
    for images, labels in validation_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        labels_pred = model(images)
        validation_loss = loss_function(labels_pred, labels)
        progress.step()
        validation_loss_evaluator.append_loss(validation_loss.item())
        validation_accuracy_evaluator.append(labels_pred, labels)
    epoch_end_time = time.time()
    # step scheduler
    scheduler.step(validation_loss_evaluator.value())
    print(
        f'train loss: {training_loss_evaluator}  \
        validation loss: {validation_loss_evaluator}  \
        validation accuracy: {validation_accuracy_evaluator}')
    print(
        f'precision: {round(validation_accuracy_evaluator.precision(), 3)}  \
        recall: {round(validation_accuracy_evaluator.recall(), 3)}  \
        f1: {round(validation_accuracy_evaluator.f1(), 3)}')
    # the evaluators have __str__
    print(
        f'duration: {round(epoch_end_time - epoch_start_time, 2)} s  \
        learning rate: {round(scheduler.get_last_lr()[0], 9)}')
    
    # if reaching best validation loss, back up the state
    if validation_loss_evaluator.is_best():
        best_epoch = current_epoch
        best_loss = validation_loss_evaluator.value()
        if save_sd and current_epoch > 1:
            # remove existing tempfile if exists
            os.remove(state_dict_backup_path) if os.path.isfile(state_dict_backup_path) else None
            torch.save(model.state_dict(), state_dict_backup_path)
    
    # step early stopper
    earlystopper.step(validation_loss_evaluator.value())
    print('-'*75)
    if earlystopper.stop() and epoch >= min_epoches:
        print('early stopper triggered, break')
        break
    if signalstopper.stop():
        print('signalled stop')
        break

signalstopper.reset()
# update the starting epoch number
starting_epoch = current_epoch
print('-' * 75)
if best_epoch is not None:
    print(f'best validation loss: {best_loss}, epoch {best_epoch}')
    print(f'saved to temp file: {state_dict_backup_path}')
else:
    print(f'validation loss does not improve')

signal file path: C:\Users\user\AppData\Local\Temp\signal-05-14-o80lrdi0
epoch # 1
train loss: 0.186          validation loss: 0.69          validation accuracy: 9.24 %
precision: 0.118          recall: 1.0          f1: 0.211
duration: 30.44 s          learning rate: 0.0001
---------------------------------------------------------------------------
epoch # 2
train loss: 0.049          validation loss: 0.649          validation accuracy: 9.24 %
precision: 0.118          recall: 1.0          f1: 0.211
duration: 29.38 s          learning rate: 0.0001
---------------------------------------------------------------------------
epoch # 3
train loss: 0.024          validation loss: 0.503          validation accuracy: 54.62 %
precision: 0.1          recall: 0.357          f1: 0.156
duration: 29.2 s          learning rate: 0.0001
---------------------------------------------------------------------------
epoch # 4
train loss: 0.016          validation loss: 0.456          validation accuracy: 8

In [5]:
print(validation_accuracy_evaluator.labels_pred)
print(validation_accuracy_evaluator.labels_actual)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [5]:
images.shape

torch.Size([8, 3, 224, 224])

In [6]:
images.dtype, labels.dtype

(torch.float64, torch.float32)

In [None]:
from dataio.transforms import grayscale_to_rgb
images_rgb = grayscale_to_rgb(images[0])
images_rgb.shape

In [None]:
training_dataset = dataloader_dict.get('training_dataset')
training_dataset._use_3_channels

save state dict

In [None]:
# save model state_dicts to target path
sd_save_filename = f"resnet18_state_{datetime.datetime.now().strftime('%m-%d')}-ep_{current_epoch}.pkl"
sd_save_path = os.path.join(bin_path, sd_save_filename)
torch.save(model.state_dict(), sd_save_path)
print(f'state dict saved to path: {sd_save_path}')

*(debugging the network)*

In [None]:
model.drop_rate

In [None]:
len(validation_loss_evaluator.labels_actual_raw)

In [None]:
import datetime
state_dict_save_path = f"bin/res18_state_{datetime.datetime.now().strftime('%m-%d')}.pkl"
torch.save(model.state_dict(), state_dict_save_path)

In [None]:
validation_accuracy_evaluator._tfpn()

In [None]:
sum([len(training_dataloader), len(validation_dataloader), len(test_dataloader)])

In [None]:
validation_pred_labels = [torch.argmax(lbl).item() for lbl in validation_accuracy_evaluator.labels_pred_raw]
validation_actual_labels = [torch.argmax(lbl).item() for lbl in validation_accuracy_evaluator.labels_actual_raw]
sum(validation_pred_labels)

In [None]:
n_total = len(validation_actual_labels)
n_correct = 0
for pred_label, actual_label in zip(validation_pred_labels, validation_actual_labels):
    if pred_label == actual_label:
        n_correct += 1
n_correct / n_total

In [None]:
validation_accuracy_evaluator.labels_pred_raw[:3]

In [None]:
validation_accuracy_evaluator.accuracy()

calculate test set loss and accuracy

In [None]:
# evaluators for test loss and accuracy
test_loss_evaluator = LossEvaluator(criterion=loss_function)
test_accuracy_evaluator = AccuracyEvaluator()

# iterate through test dataloader
for images, labels in test_dataloader:
    images = images.to(device=device)
    labels = labels.to(device=device)
    labels_pred = model(images)
    loss = loss_function(labels_pred, labels)
    test_loss_evaluator.append_loss(loss)
    test_accuracy_evaluator.append(labels_pred, labels)
print(f'test loss: {test_loss_evaluator}\
    test accuracy: {test_accuracy_evaluator}')
print(
    f'precision: {round(test_accuracy_evaluator.precision(), 2)}  \
    recall: {round(test_accuracy_evaluator.recall(), 2)}  \
    f1: {round(test_accuracy_evaluator.f1(), 2)}')

In [None]:
# function shows saliency map
def calculate_saliency_map(model, image, target_class):

    """
    Calculates the saliency map for a given image and target class.

    Args:
    model: A PyTorch model.
    image: A PyTorch tensor containing the image.
    target_class: The target class.

    Returns:
    numpy array containing the saliency map.
    """

    # Forward pass the image through the model.
    output = model(image)
    # Get the gradient of the output with respect to the input image.
    # gradient = torch.autograd.grad(output[0][target_class], image[0])
    gradient = torch.autograd.grad(
        output, image, grad_outputs=target_class.view(1,-1),
        allow_unused=True)[0][0]
    # Calculate the absolute value of the gradient.
    # saliency_map = torch.abs(gradient)
    saliency_map = torch.norm(gradient)
    # Normalize the saliency map.
    saliency_map = saliency_map / torch.max(saliency_map)
    saliency_map_np = saliency_map.detach().numpy()
    return saliency_map_np

In [None]:
gradient = torch.autograd.grad(output, image, grad_outputs)

In [None]:
import matplotlib.pyplot as plt
image = images[0][0]
plt.imshow(image, cmap='binary_r')

In [None]:
import tempfile
for images in test_dataloader.view()

In [None]:
len(validation_accuracy_evaluator.labels_actual_raw)

## Appendix
additional details about the model

structure of the resnet_18 model

In [None]:
model

In [None]:
isinstance(model, torch.nn.Module)