## Training for ResNet18 - 128
with 128 x 128 input size and modified conv1 layer

*if using google drive, run following two cells*

In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')
project_dir = '/content/drive/MyDrive/master_courses/BIDH5001 Capstone/Project/\
deep-classification'
os.chdir(project_dir)

In [None]:
%pip install pydicom    

imports, initiating dataloaders

In [1]:
# imports
import datetime
import time
import warnings
import torch
import config
from dataio.dataloader import create_dataloader
from networks.resnet_classifier import resnet18_128_classifier
from training.evaluation import AccuracyEvaluator, LossEvaluator
from training.utility.progress_bar import ProgressBar
from training.utility.early_stopper import ValLoss as EarlyStopper

# whenever possible, use cuda instead of cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    warnings.warn("cuda is using CPU, this can be very slow")


In [2]:
# paths to save and load state dictionary
import os
import re

load_sd = False
save_sd = False

bin_path = "bin"
assert os.path.isdir(bin_path)
sd_files = [
    filename for filename in os.listdir(bin_path)
    if re.match(r'res18_128_state_[0-9]{2}-[0-9]{2}\.pkl', filename)]
sd_files.sort(reverse=True)
sd_load_filename = sd_files[0] if len(sd_files) > 0 else None
sd_save_filename = f"res18_128_state_{datetime.datetime.now().strftime('%m-%d')}.pkl"
sd_load_path = os.path.join(bin_path, sd_load_filename) \
    if sd_load_filename is not None else None
sd_save_path = os.path.join(bin_path, sd_save_filename)

preparing the training:
1. read dicoms and labels from configuration and tracking table, initiate dataloaders
2. set up criterions and optimizers, training parameters

In [3]:
# reading from configurations
dicoms = config.tracking_table['dicom_path'].to_list()
labels = config.tracking_table['label'].astype('int16').to_list()

# initiate the dataloaders
dataloader_dict = create_dataloader(
    dicoms, labels,
    dicom_dir = config.dicom_dir,
    batch_size = 16,
    validation_size = config.validation_size,
    test_size = config.test_size,
    img_size=(128, 128)
)
training_dataloader, validation_dataloader, test_dataloader = \
    (dataloader_dict.get(key) for key in ('training_dataloader', 'validation_dataloader', 'test_dataloader'))

In [6]:
# some settings, move to config in future
learning_rate = 0.0005 # 0.00005, originally without lr scheduler
n_epoches = 15
min_epoches = 7
starting_epoch = 0

# the classifier model
model = resnet18_128_classifier
# load state dict into 
if load_sd and sd_load_path:
    state_dict = torch.load(sd_load_path)
    model.load_state_dict(state_dict)

# optimizers
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr=learning_rate,
#     steps_per_epoch=len(training_dataloader), epochs=n_epoches,
#     )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.25, patience=1,
    # threshold=0.05, threshold_mode='rel',
    min_lr = learning_rate / 32,
    cooldown=7,
)
earlystopper = EarlyStopper(tolerance=3, target=0.25, delta=0.001)

# accuracy and loss evaluators
training_loss = LossEvaluator(loss_function)
validation_loss = LossEvaluator(loss_function)
validation_accuracy = AccuracyEvaluator()

In [5]:
# to continue training for n_epoches
learning_rate = 0.001
n_epoches = 3

training for resnet18 image quality classifier

In [7]:
# evaluators for training/validation loss and accuracy
# training_loss = LossEvaluator(loss_function)
# validation_loss = LossEvaluator(loss_function)
# validation_accuracy = AccuracyEvaluator()

# training epoches loop
for epoch in range(n_epoches):
    current_epoch = epoch + 1 + starting_epoch
    print(f'epoch # {current_epoch}')
    progress = ProgressBar(len(training_dataloader) + len(validation_dataloader))
    epoch_start_time = time.time()
    # reset pred and actual labels after each epoch
    validation_accuracy.reset()
    training_loss.reset()
    validation_loss.reset()
    # training network
    for images, labels in training_dataloader:
        # images = images.to(torch.float32)
        # labels = labels.to(torch.float32)
        labels_pred = model(images)
        # train loss
        loss_training = loss_function(labels_pred, labels)
        training_loss.append_loss(loss_training.item())
        # backpropagation
        optimizer.zero_grad()
        loss_training.backward()
        optimizer.step()
        progress.step()
    # validating network
    for images, labels in validation_dataloader:
        # images = images.to(torch.float32)
        # labels = labels.to(torch.float32)
        labels_pred = model(images)
        loss_validation = loss_function(labels_pred, labels)
        validation_loss.append_loss(loss_validation.item())
        validation_accuracy.append(labels_pred, labels)
        progress.step()
    epoch_end_time = time.time()
    # step scheduler
    scheduler.step(validation_loss.value())
    # step early stopper
    earlystopper.step(validation_loss.value())
    print(
        f'train loss: {training_loss}  \
        validation loss: {validation_loss}  \
        validation accuracy: {validation_accuracy}')
    print(
        f'precision: {round(validation_accuracy.precision(), 3)}  \
        recall: {round(validation_accuracy.recall(), 3)}  \
        f1: {round(validation_accuracy.f1(), 3)}')
    # the evaluators have __str__
    print(
        f'duration: {round(epoch_end_time - epoch_start_time, 2)} s  \
        learning rate: {round(scheduler.get_last_lr()[0], 9)}')
    print('-'*75)
    if earlystopper.stop() and current_epoch > min_epoches:
        print('early stopper triggered, break')
        break
    

# update the starting epoch number
starting_epoch += epoch

epoch # 1
train loss: 0.632          validation loss: 1.339          validation accuracy: 14.29 %
precision: -1          recall: -1          f1: -1.0
duration: 25.23 s          learning rate: 0.0008
---------------------------------------------------------------------------
epoch # 2
train loss: 0.376          validation loss: 1.436          validation accuracy: 14.29 %
precision: -1          recall: -1          f1: -1.0
duration: 24.63 s          learning rate: 0.0008
---------------------------------------------------------------------------
epoch # 3
train loss: 0.283          validation loss: 1.815          validation accuracy: 14.29 %
precision: -1          recall: -1          f1: -1.0
duration: 24.67 s          learning rate: 0.0002
---------------------------------------------------------------------------
epoch # 4
train loss: 0.116          validation loss: 1.586          validation accuracy: 14.29 %
precision: -1          recall: -1          f1: -1.0
duration: 24.68 s        

save state dict

In [None]:
# save model state_dicts to target path
if save_sd:
    torch.save(model.state_dict(), sd_save_path)
    print(f'state dict saved to path: {sd_save_path}')

(debugging the network)

In [None]:
[torch.argmax(lbl).item() for lbl in validation_accuracy.labels_actual_raw]

In [9]:
validation_accuracy.labels_actual

[1,
 1,
 1,
 1,
 2,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [None]:
torch.Tensor(1) 

In [None]:
sample_image = torch.randn(1, 1, )

In [None]:
images.shape
# model(images)

In [None]:
# resnet18_128_classifier.layer1

In [None]:
resnet18_128_classifier.conv1(sample_input)

In [None]:
resnet18_128_classifier.bn1

In [None]:
## debugging
sample_input = torch.randn(1,1,128,128)
sample_1 = resnet18_128_classifier.conv1(sample_input)
sample_2 = resnet18_128_classifier.bn1(sample_1)
sample_1.shape, sample_2.shape

In [None]:
sample_1.size()

In [None]:
sample_input = torch.randn(1,1,128,128)
resnet18_128_classifier(sample_input).shape
resnet18_128_classifier()

In [None]:
from networks.resnet_classifier import resnet18_classifier
sample_input = torch.randn(1,1,64,64)
sample_1 = resnet18_classifier.conv1(sample_input)
sample_2 = resnet18_classifier.bn1(sample_1)
sample_1.shape, sample_2.shape

In [None]:
sample_1.size()

In [None]:
sample_input = torch.randn(1,1,64,64)
resnet18_classifier(sample_input).shape

In [None]:
len(validation_loss.labels_actual_raw)

In [None]:
import datetime
state_dict_save_path = f"bin/res18_state_{datetime.datetime.now().strftime('%m-%d')}"
state_dict_save_path

In [None]:
import datetime
state_dict_save_path = f"bin/res18_state_{datetime.datetime.now().strftime('%m-%d')}.pkl"
torch.save(model.state_dict(), state_dict_save_path)

In [None]:
validation_accuracy._tfpn()

In [None]:
sum([len(training_dataloader), len(validation_dataloader), len(test_dataloader)])

In [None]:
validation_pred_labels = [torch.argmax(lbl).item() for lbl in validation_accuracy.labels_pred_raw]
validation_actual_labels = [torch.argmax(lbl).item() for lbl in validation_accuracy.labels_actual_raw]
sum(validation_pred_labels)

In [None]:
n_total = len(validation_actual_labels)
n_correct = 0
for pred_label, actual_label in zip(validation_pred_labels, validation_actual_labels):
    if pred_label == actual_label:
        n_correct += 1
n_correct / n_total

In [None]:
validation_accuracy.labels_pred_raw[:3]

In [None]:
validation_accuracy.accuracy()

calculate test set loss and accuracy

In [None]:
# evaluators for test loss and accuracy
test_loss = LossEvaluator(criterion=loss_function)
test_accuracy = AccuracyEvaluator()

# iterate through test dataloader
for images, labels in test_dataloader:
    images = images.to(torch.float32)
    labels = labels.to(torch.float32)
    labels_pred = model(images)
    loss = loss_function(labels_pred, labels)
    test_loss.append_loss(loss)
    test_accuracy.append(labels_pred, labels)
print(f'test loss: {test_loss}\
    test accuracy: {test_accuracy}')
print(
    f'precision: {round(test_accuracy.precision(), 3)}  \
    recall: {round(test_accuracy.recall(), 3)}  \
    f1: {round(test_accuracy.f1(), 3)}')


In [None]:
# function shows saliency map
def calculate_saliency_map(model, image, target_class):

    """
    Calculates the saliency map for a given image and target class.

    Args:
    model: A PyTorch model.
    image: A PyTorch tensor containing the image.
    target_class: The target class.

    Returns:
    numpy array containing the saliency map.
    """

    # Forward pass the image through the model.
    output = model(image)
    # Get the gradient of the output with respect to the input image.
    # gradient = torch.autograd.grad(output[0][target_class], image[0])
    gradient = torch.autograd.grad(
        output, image, grad_outputs=target_class.view(1,-1),
        allow_unused=True)[0][0]
    # Calculate the absolute value of the gradient.
    # saliency_map = torch.abs(gradient)
    saliency_map = torch.norm(gradient)
    # Normalize the saliency map.
    saliency_map = saliency_map / torch.max(saliency_map)
    saliency_map_np = saliency_map.detach().numpy()
    return saliency_map_np

In [None]:
gradient = torch.autograd.grad(output, image, grad_outputs)

In [None]:
import matplotlib.pyplot as plt
image = images[0][0]
plt.imshow(image, cmap='binary_r')

In [None]:
import tempfile
for images in test_dataloader.view()

In [None]:
len(validation_accuracy.labels_actual_raw)

## Appendix
additional details about the model

structure of the resnet_18 model

In [None]:
model

In [None]:
isinstance(model, torch.nn.Module)