In [None]:
import os

# Work from the repository root so `utils` and the relative `data/...` paths resolve.
# Guarded so that re-running this cell does not keep moving further up the directory tree.
if not os.path.isdir('utils'):
    os.chdir('..')

from utils.model import rebuild_kneenet
from utils.preprocess_image import preprocess
import torch
import torchvision
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from utils.googledrive_requests import download_from_googledrive

In [None]:
# Google Drive file id of the dataset archive, read from the environment (None when unset).
file_id = os.environ.get("mendeley_dataset_file_id")

In [None]:
# Report only whether the id is configured. Echoing the raw Google Drive file id would
# persist it in the saved notebook output, defeating the point of reading it from the env.
bool(file_id)

In [None]:
target_filepath = "data/mendeley/KneeKL299.zip"

# Fail fast with a clear message instead of handing None to the downloader.
if not file_id:
    raise RuntimeError(
        "Environment variable 'mendeley_dataset_file_id' is not set; "
        "it must contain the Google Drive file id of the dataset archive."
    )

# Make sure the destination directory exists before downloading into it.
os.makedirs(os.path.dirname(target_filepath), exist_ok=True)
download_from_googledrive(file_id, target_filepath)

# Load Data and verify output

In [None]:
# Rebuild the KneeNet model and switch it to inference mode, so any BatchNorm/Dropout layers
# behave deterministically during prediction and attribution. This is a no-op if
# rebuild_kneenet already returns the model in eval mode. Module.eval() returns the module itself.
kneenet = rebuild_kneenet().eval()

In [None]:
# One folder per KL grade (0-4) in the test split; take the first few files from each.
# NOTE(review): the archive downloaded above is "KneeKL299.zip" while this path uses
# "kneeKL299" -- confirm the extracted folder name matches on case-sensitive filesystems.
img_folders = ['data/mendeley/kneeKL299/test/' + str(x) for x in range(5)]
n_images_per_score = 1
images = []
for img_folder in img_folders:
    # os.listdir order is arbitrary: sort so every run picks the same images, and skip
    # hidden files (e.g. .DS_Store) that cv2 cannot decode.
    candidate_files = sorted(f for f in os.listdir(img_folder) if not f.startswith('.'))
    for file in candidate_files[:n_images_per_score]:
        img_path = os.path.join(img_folder, file)
        raw_img = cv2.imread(img_path, 0)  # flag 0 -> load as single-channel grayscale
        if raw_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"cv2 could not read image file: {img_path}")
        processed_image = preprocess(raw_img.astype("float"))
        # Prepend a batch dimension of 1 so the per-image tensors can be concatenated later.
        input_image = processed_image.reshape((1,) + processed_image.shape)
        images.append(torch.from_numpy(input_image).float())
        

In [None]:
# Sanity check: show channel 0 of the first loaded image.
# reshape(-1, H, W)[0] works both before and after the concatenation cell turns `images`
# from a list of (1, C, H, W) tensors into a single (N, C, H, W) tensor.
first_image = images[0]
plt.imshow(first_image.reshape(-1, *first_image.shape[-2:])[0], cmap='gray')
plt.title("First loaded image (channel 0)");

In [None]:
# Concatenate the per-image (1, ...) tensors into a single batch along dim 0.
# Guarded so re-running this cell after `images` is already a tensor does not fail.
if isinstance(images, list):
    images = torch.cat(images, dim=0)

In [None]:
# Inference-only forward pass: no_grad avoids building an autograd graph for the batch.
with torch.no_grad():
    logits = kneenet(images)

In [None]:
# Softmax over dimension 1 (the class axis of the logits), giving per-image probabilities.
softmax = torch.nn.Softmax(1)

In [None]:
# Class probabilities, detached from autograd and converted to a NumPy array (one row per image).
class_probabilities = softmax(logits)
proba = class_probabilities.detach().numpy()

In [None]:
# Per-image class probabilities (rows: images, columns: KL grades), rounded for readability.
# Left as the last expression so Jupyter renders it rather than printing plain text.
np.round(proba, 3)

# Test Captum Integrated Gradients (disabled: running it crashes the kernel)

In [None]:
# Integrated Gradients is disabled: running it caused the kernel to die.
# Kept as comments (not a bare string, which Jupyter would echo as cell output).
# NOTE(review): possibly memory exhaustion from the default number of integration steps on the
# whole batch -- passing a smaller `internal_batch_size` to attribute() may help; unverified.
# from captum.attr import IntegratedGradients
# integrated_gradients = IntegratedGradients(kneenet, multiply_by_inputs=True)
# integrated_gradients.attribute(images)

# Test Captum DeepLift

In [None]:
from captum.attr import DeepLift

# Explain each image's *predicted* KL grade (one target per example), so the attribution matches
# the "prediction KL = ..." titles of the plots. Previously every image was explained for KL = 4.
# Captum gathers targets with an index tensor, so force int64.
deeplift_targets = torch.as_tensor(np.argmax(proba, axis=1), dtype=torch.long)

# Use a lowercase instance name so the imported DeepLift class is not shadowed.
deep_lift = DeepLift(kneenet, multiply_by_inputs=True)
attr = deep_lift.attribute(images, target=deeplift_targets)

In [None]:
# Split the DeepLift attribution into its positive part and the magnitude of its negative part,
# then move the channel axis (axis 1) to the end for plotting.
# A new name is used instead of overwriting `attr`, so re-running this cell still works.
deeplift_attr = attr.detach().numpy()
positive_attr = np.moveaxis(np.maximum(deeplift_attr, 0), 1, -1)
negative_attr = np.moveaxis(np.maximum(-deeplift_attr, 0), 1, -1)

In [None]:
def normalize(img):
    """Linearly rescale an array to the range [0, 1].

    Args:
        img: NumPy array of any shape. It is not modified.

    Returns:
        A new array with minimum 0 and maximum 1. A constant input (for example an all-zero
        attribution map) yields an array of zeros instead of the NaNs a 0/0 division would give.
    """
    shifted = img - img.min()
    value_range = shifted.max()
    if value_range == 0:
        return np.zeros_like(shifted, dtype=float)
    return shifted / value_range

In [None]:
# Predicted KL grade for every image (loop-invariant, so computed once).
prediction = np.argmax(proba, axis=1)

# Iterate over the images actually loaded rather than assuming exactly 5 * n_images_per_score.
for image_num in range(len(images)):
    # (C, H, W) -> (H, W, C) so matplotlib treats the channels as RGB.
    image_n = normalize(np.moveaxis(images[image_num].detach().numpy(), 0, -1))
    pos_attr_n = normalize(positive_attr[image_num, :, :, 0])
    neg_attr_n = normalize(negative_attr[image_num, :, :, 0])

    # Overlay: positive attribution in the red channel, negative attribution in the green channel.
    image_n[:, :, 0] += pos_attr_n
    image_n[:, :, 1] += neg_attr_n

    # Rescale the overlay back into [0, 1] for matplotlib.
    image_n = normalize(image_n)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title("DeepLift explanation for prediction KL = " + str(prediction[image_num]))
    ax.imshow(image_n)
    ax.axis("off")

# Test Captum Guided GradCam

In [None]:
from captum.attr import GuidedGradCam

# Guided Grad-CAM anchored on features.denseblock4.denselayer32.conv2
# (presumably the final convolution of the last dense block -- confirm in the model printout).
target_layer = kneenet.features.denseblock4.denselayer32.conv2
ggc = GuidedGradCam(kneenet, target_layer)

# One attribution tensor per KL grade 0-4, each computed over the whole batch of images.
attr_0 = ggc.attribute(images, target=0)
attr_1 = ggc.attribute(images, target=1)
attr_2 = ggc.attribute(images, target=2)
attr_3 = ggc.attribute(images, target=3)
attr_4 = ggc.attribute(images, target=4)

In [None]:
# Predicted KL grade for every image (loop-invariant, so computed once).
predictions = np.argmax(proba, axis=1)
# Explicit lookup table instead of eval() on a constructed variable name.
ggc_attr_by_kl = {0: attr_0, 1: attr_1, 2: attr_2, 3: attr_3, 4: attr_4}

for image_num in range(len(images)):
    predicted_kl = int(predictions[image_num])
    # Guided Grad-CAM attribution of the whole batch for this image's predicted class.
    ggc_attr = ggc_attr_by_kl[predicted_kl]

    # Normalize the image and the attribution, moving channels last for matplotlib.
    # Index the batch by image position, not by KL grade: each attr_k has one entry per image.
    image_n = normalize(images[image_num].detach().numpy()).transpose((1, 2, 0))
    attr_class = normalize(ggc_attr[image_num].detach().numpy()).transpose((1, 2, 0))

    # Overlay the attribution on the image and rescale into [0, 1].
    combined_image = normalize(image_n + attr_class)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title("Guided GradCAM explanation for prediction KL = " + str(predicted_kl))
    ax.imshow(combined_image)
    ax.axis("off")

In [None]:
# Display the model architecture; useful for locating layer names such as
# features.denseblock4.denselayer32.conv2, the layer passed to GuidedGradCam above.
kneenet