# Task 3: Advanced CNN and Transfer Learning


We first train a CNN on a larger [dataset](https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection?select=no) of 3k brain MRI images. We then fine-tune that model on our original dataset.

In [None]:
from pathlib import Path
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, plot_confusion_matrix, confusion_matrix, ConfusionMatrixDisplay

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD, RMSprop, lr_scheduler
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset

import shap

from data import get_img_dataset
from project3Lib.transforms import EnhanceContrast
import project3Lib.CNN as cnn
from project3Lib.CNN import train_model, test, predict
import project3Lib.utils as utils

from masked_dataset import MaskedDataset

In [None]:
# Load the transfer-learning dataset (train / validation / test splits)
# and wrap each split in a DataLoader.
transferlearning_path = "data/tl_dataset"
transform = [EnhanceContrast(reduce_dim=False), transforms.Grayscale()]
tl_train_dataset, tl_val_dataset, tl_test_dataset = get_img_dataset(
    transform,
    data_path=transferlearning_path,
    use_same_transforms=True,
)

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print('Device state:', device)

batch_size = 64
# Train and validation batches are shuffled; the test order stays fixed.
tl_trainloader = DataLoader(tl_train_dataset, shuffle=True, batch_size=batch_size)
tl_validloader = DataLoader(tl_val_dataset, shuffle=True, batch_size=batch_size)
tl_testloader = DataLoader(tl_test_dataset, shuffle=False, batch_size=batch_size)

# Phase-keyed lookups that are handed to train_model.
dataloaders = dict(train=tl_trainloader, validation=tl_validloader)
image_datasets = dict(train=tl_train_dataset, validation=tl_val_dataset)

# Model implementation
## 1. Train model on large dataset

In [None]:
# Train a fresh CNN on the large transfer-learning dataset.
epochs = 30

model = cnn.CNN()
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

model = train_model(
    model,
    criterion,
    optimizer,
    dataloaders,
    image_datasets,
    num_epochs=epochs,
)

In [None]:
# Persist the base model's weights; later cells reload them from this path.
base_weights = model.state_dict()
torch.save(base_weights, "trained_weights/TL_basemodel.pt")

**Evaluate base model and SHAP** 

In [None]:
# Re-create the architecture and restore the saved base-model weights.
model = cnn.CNN()
base_state_dict = torch.load("trained_weights/TL_basemodel.pt")
model.load_state_dict(base_state_dict)

In [None]:
# Explain the base model with SHAP's DeepExplainer.
def _channels_last(arr):
    """Apply two axis swaps; for a 4-D array this maps axes (0, 1, 2, 3) -> (0, 2, 3, 1)."""
    return np.swapaxes(np.swapaxes(arr, 1, -1), 1, 2)


# Background set: 100 training images drawn at random under a fixed seed.
np.random.seed(123)
indices = np.random.randint(0, high=len(tl_train_dataset), size=100)
bg = torch.utils.data.Subset(tl_train_dataset, indices)
bg = torch.stack([img for img, _ in bg])

e = shap.DeepExplainer(model, bg)

# Report the mean of each of the two model outputs over the background set.
outs = []
for bg_image in bg:
    pred, out = predict(model, bg_image)
    outs.append((out[0][0].item(), out[0][1].item()))
print(f"Mean values {np.mean([a for a, b in outs])}, {np.mean([b for a, b in outs])}")


# Draw 10 random test samples and plot the SHAP values for each of them.
indices = np.random.randint(0, high=len(tl_test_dataset), size=10)
sub_test = torch.utils.data.Subset(tl_test_dataset, indices)

test_images = [img for img, _ in sub_test]
y_test = [label for _, label in sub_test]

for i, (image, true_label) in enumerate(zip(test_images, y_test)):
    # Reshape to (1, 1, 128, 128) before predicting and explaining.
    image = image.reshape((1, 1, 128, 128))
    pred, out = predict(model, image)
    shap_values = e.shap_values(image)
    shap_numpy = [_channels_last(s) for s in shap_values]
    test_numpy = _channels_last(image.cpu().numpy())
    print(f"Image #{i}: True Class {true_label}, Prediction {pred}, Probabilities {out}")
    shap.image_plot(shap_numpy, test_numpy, labels=["no", "yes"])

## 2. Transfer: finetune model on original dataset

In [None]:
# Load the original dataset that the base model is fine-tuned on.
# Answers are stripped before comparison so that surrounding whitespace
# (e.g. "yes ") is not silently interpreted as "no".
unique = input("Use unique images?[yes/no]").strip().lower() == "yes"
input_path = "data/unique_images" if unique else "data/images"

transform = [EnhanceContrast(reduce_dim=False), transforms.Grayscale()]
train_dataset, val_dataset, test_dataset = get_img_dataset(
    transform,
    data_path=input_path,
    use_same_transforms=True,
)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Device state:', device)

batch_size = 16
# Train and validation batches are shuffled; the test order stays fixed.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
validloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Phase-keyed lookups that are handed to train_model.
dataloaders = {
    'train': trainloader,
    'validation': validloader,
}

image_datasets = {
    'train': train_dataset,
    'validation': val_dataset,
}

# True: fine-tune every layer. False: only the nn.Linear layers are trained.
full_retrain = input("Retrain all layers? [yes/no]").strip().lower() == "yes"

In [None]:
# Build the model to fine-tune, starting from the saved base-model weights.
transfer_model = cnn.CNN()
transfer_model.load_state_dict(torch.load("trained_weights/TL_basemodel.pt"))

if not full_retrain:
    # Freeze every parameter, then re-enable training for the nn.Linear layers.
    for param in transfer_model.parameters():
        param.requires_grad = False
    for layer in transfer_model.modules():
        if isinstance(layer, nn.Linear):
            # Unfreeze weight AND bias. Previously only `weight` was re-enabled,
            # leaving the biases of the trainable layers frozen.
            # recurse=False yields just this layer's own parameters and
            # naturally skips a bias that is None.
            for param in layer.parameters(recurse=False):
                param.requires_grad = True

# Count the parameters that will actually receive gradient updates.
total_trainable_params = sum(
    p.numel() for p in transfer_model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

In [None]:
# Fine-tune the transfer model on the original dataset.
epochs = 20

criterion = CrossEntropyLoss()
optimizer = Adam(transfer_model.parameters(), lr=5e-4)

transfer_model = train_model(
    transfer_model,
    criterion,
    optimizer,
    dataloaders,
    image_datasets,
    num_epochs=epochs,
)

In [None]:
# Persist the fine-tuned weights; the evaluation cell reloads them from this path.
finetuned_weights = transfer_model.state_dict()
torch.save(finetuned_weights, "trained_weights/TL_model.pt")

# Evaluate final Classifier

In [None]:
# Re-create the architecture and restore the fine-tuned weights.
transfer_model = cnn.CNN()
finetuned_state_dict = torch.load("trained_weights/TL_model.pt")
transfer_model.load_state_dict(finetuned_state_dict)

In [None]:
# Evaluate the fine-tuned classifier on the test split.
# Collect images and labels in a single pass so each sample is loaded once.
x_test, y_test = [], []
for image, label in test_dataset:
    x_test.append(image)
    y_test.append(label)

# predict returns (prediction, outputs); only the prediction is needed here.
preds = []
for image in x_test:
    pred, out = predict(transfer_model, image)
    preds.append(pred)

# scikit-learn metrics take (y_true, y_pred) in that order; this also matches
# the confusion_matrix call below.
print(f"Accuracy: {accuracy_score(y_test, preds)}")
print(f"F1 score: {f1_score(y_test, preds)}")

# normalize="true" normalises the confusion matrix over the true labels (rows).
cm = confusion_matrix(y_test, preds, normalize="true")
cmd = ConfusionMatrixDisplay(cm)
cmd.plot()

```
Accuracy: 0.95
F1 score: 0.9565217391304348
```

![](Plots/CM_CNN_TL.png)

# Interpretability

In [None]:
# Keep a handle on the mask-free test split, then reload the test split
# through MaskedDataset with masks read from data/masks. The interpretability
# loops unpack each sample of the reloaded split as (image, mask, target).
test_dataset_nomasks = test_dataset

transform = [transforms.Grayscale()]
common_transform = [EnhanceContrast(reduce_dim=False)]

_, _, test_dataset = get_img_dataset(
    transform=transform,
    use_same_transforms=True,
    common_transforms=common_transform,
    data_path=input_path,
    folder_type=MaskedDataset,
    mask_folder=Path("data/masks"),
)

## SHAP

In [None]:
# Deep Explainer
bg = [i for i,j in train_dataset]
bg = torch.stack(bg)
e = shap.DeepExplainer(transfer_model, bg)
outs = []
for i in bg:
    pred, out = predict(transfer_model,i)
    outs.append((out[0][0].item(), out[0][1].item()))
print(f"Mean values {np.mean([i for i,j in outs])}, {np.mean([j for i,j in outs])}")

In [None]:
# Compute SHAP values on the masked test split and score them against
# the provided masks with utils.evaluate_interpretability.
ious = []
for i, (image, mask, target) in enumerate(test_dataset):
    # Reshape to (1, 1, 128, 128) before predicting and explaining.
    image = image.reshape((1,1,128,128))
    pred, out = predict(transfer_model,image)

    shap_values = e.shap_values(image)
    shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
    test_numpy = np.swapaxes(np.swapaxes(image.cpu().numpy(), 1, -1), 1, 2)
    print(f"Image #{i}: True Class {target}, Prediction {pred}, Probabilities {out}")
    shap.image_plot(shap_numpy, test_numpy)

    # The SHAP values at index 1 serve as the predicted mask.
    predicted_mask = np.copy(shap_values[1].reshape(128,128))
    mask = mask.numpy().reshape((128,128))
    pixels = int(np.sum(mask.flatten()))  # sum of the mask values
    iou = utils.evaluate_interpretability(predicted_mask, mask,pixels)
    print(iou)
    # Only samples with target == 1 contribute to the mean IoU.
    if target == 1:
        ious.append(iou)
    # Save the maps of the first two samples. The original saved sample 0 twice
    # and never sample 1; the Integrated Gradients and GradCam cells save
    # samples 0 and 1 to *_0 and *_1.
    if i in (0, 1):
        np.save(f"Plots/CNN_TL_SHAP_{i}", predicted_mask)

print(f"Mean IOU: {np.mean(ious)}")

```
Mean IOU: 0.22387248291302617
```

# Integrated Gradients with Captum

In [None]:
# Integrated-gradients maps on the masked test split, scored against
# the provided masks with utils.evaluate_interpretability.
ious = []
for i, (image, mask, target) in enumerate(test_dataset):
    data = (image, target)
    is_positive = target == 1
    first, second = utils.plot_grads(
        data, transfer_model, layer_idx=-1, plot=False, grad_type="integ_grads"
    )
    # The first returned map is bound to the sample's own class, the second to the other.
    if is_positive:
        class_1, class_0 = first, second
    else:
        class_0, class_1 = first, second

    predicted_mask = np.copy(class_1.reshape(128, 128))
    mask_np = mask.numpy().reshape((128, 128))
    pixels = int(np.sum(mask_np.flatten()))  # sum of the mask values
    iou = utils.evaluate_interpretability(predicted_mask, mask_np, pixels)
    # Only samples with target == 1 contribute to the mean IoU.
    if is_positive:
        ious.append(iou)
    # Keep the maps of the first two samples on disk.
    if i in (0, 1):
        np.save(f"Plots/CNN_TL_IntGrad_{i}", predicted_mask)
print(f"The mean iou is {np.mean(ious)}")

```
The mean iou is 0.15528650486299014
```

In [None]:
# Plot integrated-gradients attributions for the mask-free test split.
utils.plot_grads_dataloader(
    test_dataset_nomasks,
    transfer_model,
    grad_type="integ_grads",
    plot=True,
    save_name="tl_cnn",
)

# GradCam

In [None]:
# Grad-CAM maps (layer_idx=4) on the masked test split, scored against
# the provided masks with utils.evaluate_interpretability.
ious = []
for i, (image, mask, target) in enumerate(test_dataset):
    data = (image, target)
    is_positive = target == 1
    first, second = utils.plot_grads(
        data, transfer_model, layer_idx=4, plot=False, grad_type="grad_cam"
    )
    # The first returned map is bound to the sample's own class, the second to the other.
    if is_positive:
        class_1, class_0 = first, second
    else:
        class_0, class_1 = first, second

    predicted_mask = np.copy(class_1.detach().numpy().reshape(128, 128))
    mask_np = mask.numpy().reshape((128, 128))
    pixels = int(np.sum(mask_np.flatten()))  # sum of the mask values
    iou = utils.evaluate_interpretability(predicted_mask, mask_np, pixels)
    # Only samples with target == 1 contribute to the mean IoU.
    if is_positive:
        ious.append(iou)
    # Keep the maps of the first two samples on disk.
    if i in (0, 1):
        np.save(f"Plots/CNN_TL_GradCam_{i}", predicted_mask)
print(f"The mean iou is {np.mean(ious)}")

In [None]:
# Plot Grad-CAM attributions (layer_idx=4) for the mask-free test split.
utils.plot_grads_dataloader(
    test_dataset_nomasks,
    transfer_model,
    grad_type="grad_cam",
    plot=True,
    layer_idx=4,
    save_name="tl_cnn",
)