In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Automatically reload imported modules (e.g. the project's neural_network package)
# when their source files change, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
from neural_network.trainer import Agent
from skimage.transform import resize
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib

In [4]:
# Build the project Agent for the 'resnet50' configuration.
# NOTE(review): on the machine that produced the saved output this raised
# "Torch not compiled with CUDA enabled" -- Agent presumably assumes a GPU; confirm.
agent = Agent('resnet50')

Setup configurations...
Dataset sizes - Training: 421 Validation: 47 Test: 0
Loading architecture from logs/tb_logs/lightning/resnet50/version_2/checkpoints/epoch=34-step=2424.ckpt (checkpoint)..
Model [LightningModel] was created
logs/tb_logs/


AssertionError: Torch not compiled with CUDA enabled

In [None]:
# Set to True to retrain the model before computing CAMs; by default no training is run
# and the notebook relies on the previously trained weights loaded below.
RETRAIN_MODEL = False
if RETRAIN_MODEL:
    agent.fit()

# Set up input: load the trained model

In [None]:
# Load the trained model weights into the agent
# (presumably from the checkpoint reported during Agent setup -- confirm in neural_network.trainer).
agent.load_model()

# Example: plot class activation maps (CAMs) for a single scan

In [None]:
from neural_network.utils.cam import get_cam
from neural_network.utils.cam_plots import plot_CAM_grid
from neural_network.utils import move_to_device, to_cpu_numpy
from neural_network.utils import interactive_slices, interactive_slices_masked
def cam_example(agent, extractor_name: str = 'SmoothGradCAMpp', img_size: tuple = (79, 224, 224),
                plot_type: str = 'grid', cmap: str = 'jet', alpha: float = 0.3,
                observed_class: str = None, load_image: str = 'CN',
                target_layer: str = 'model.layer4', device: str = None):
    """Compute a class activation map (CAM) for one fixed example scan and visualise it.

    Args:
        agent: trained Agent; its ``.model`` is moved to ``device`` and passed to ``get_cam``.
        extractor_name: CAM extractor name forwarded to ``get_cam``
            (e.g. 'CAM', 'GradCAM', 'SmoothGradCAMpp').
        img_size: shape the loaded volume is resized to before inference.
        plot_type: 'grid' (CAM grid figure via ``plot_CAM_grid``), 'slice' (slice viewer of the
            scan) or 'slice_masked' (slice viewer of the scan together with the CAM mask).
        cmap: colormap for the 'grid' plot.
        alpha: overlay opacity for the 'grid' plot.
        observed_class: 'CN', 'MCI' or 'AD' to request the CAM for that class, or None to pass
            no class to ``get_cam`` (presumably it then explains the predicted class --
            confirm in neural_network.utils.cam).
        load_image: which example scan to load: 'CN', 'MCI' or 'AD'.
        target_layer: name of the layer the CAM is extracted from.
        device: device for model and input; defaults to 'cuda' when available, else 'cpu'.

    Returns:
        ``(fig, mask, predicted_label)``. ``fig`` is the grid figure for plot_type='grid'
        and ``None`` for the slice viewers.

    Raises:
        ValueError: if ``plot_type`` is not one of the supported values.
        KeyError: if ``load_image`` or ``observed_class`` is not 'CN', 'MCI' or 'AD'.
    """
    # Validate before doing any (slow) loading or inference.
    if plot_type not in ('grid', 'slice', 'slice_masked'):
        raise ValueError(f"plot_type must be 'grid', 'slice' or 'slice_masked', got {plot_type!r}")
    if device is None:
        # Fall back to CPU so the notebook also runs with a torch build without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    image_paths = {
        'CN': 'data/SPM_categorised/AIH/CN/CN_ADNI_998.nii',
        'MCI': 'data/SPM_categorised/AIH/MCI/MCI_ADNI_1586.nii',
        'AD': 'data/SPM_categorised/AIH/AD/AD_ADNI_2975.nii',
    }
    label_to_class = {0: 'CN', 1: 'MCI', 2: 'AD'}
    class_to_label = {v: k for k, v in label_to_class.items()}

    # Open only the requested scan rather than every example file.
    volume = nib.load(image_paths[load_image]).get_fdata()
    image = torch.from_numpy(resize(volume, img_size)).float()

    model, image = move_to_device(agent.model, image, device)

    # Forward observed_class=None as None instead of looking it up (which raised KeyError).
    observed_label = None if observed_class is None else class_to_label[observed_class]
    mask, predicted_label = get_cam(model, image, extractor_name=extractor_name,
                                    target_layer=target_layer, observed_class=observed_label)

    fig = None  # only the 'grid' plot produces a figure object
    if plot_type == 'grid':
        fig = plot_CAM_grid(to_cpu_numpy(image), mask, layer=target_layer,
                            predicted_label=label_to_class[predicted_label],
                            expected_label=load_image, extractor=extractor_name,
                            cmap=cmap, alpha=alpha,
                            # True when a specific class was requested.
                            predicted_override=bool(observed_class))
    elif plot_type == 'slice':
        viewer = interactive_slices()
        viewer.multi_slice_viewer(to_cpu_numpy(image))
        viewer.close()
    else:  # 'slice_masked'
        viewer = interactive_slices_masked()
        viewer.multi_slice_viewer(to_cpu_numpy(image), mask)
        viewer.close()
    return fig, mask, predicted_label
        

In [None]:
# Iterate all cams: every extractor x example scan x observed class, logged to TensorBoard.
import time
from torch.utils.tensorboard import SummaryWriter
tb_writer = SummaryWriter(log_dir=f'logs/visualisation/version_{round(time.time())}', filename_suffix='.CAM')

# Extractor name -> extra keyword arguments for cam_example (plot styling overrides).
plot_cams = {
    'CAM': {},
    'GradCAM': {},
    'GradCAMpp': {},
    'SmoothGradCAMpp': {},
    'Saliency': {'cmap': 'hot', 'alpha': 1},
    'ScoreCAM': {},
    'SSCAM': {},
    'ISCAM': {},
}
images = ['CN', 'MCI', 'AD']
# None: no explicit class is requested from cam_example.
observed_classes = ['CN', 'MCI', 'AD', None]
errors = []

try:
    for extractor_name, params in plot_cams.items():
        for image in images:
            for observed_class in observed_classes:
                try:
                    fig, _, _ = cam_example(agent, extractor_name=extractor_name, **params,
                                            load_image=image, observed_class=observed_class)
                    # add_figure closes the figure by default, so open figures do not pile up.
                    tb_writer.add_figure(f"{extractor_name}/{image}/{observed_class}", fig)
                except RuntimeError as e:
                    # Record the failure and continue with the next combination.
                    errors.append(f'Model: {extractor_name} {e}')
finally:
    # SummaryWriter buffers events; flush so every logged figure reaches disk,
    # even if the sweep is interrupted part-way.
    tb_writer.flush()

errors

In [None]:
# Single example: SmoothGradCAM++ on the AD scan, requesting the AD class.
# Use that extractor's own settings from plot_cams rather than whatever `params`
# was left over from the last iteration of the sweep loop.
fig, mask, _ = cam_example(agent, extractor_name='SmoothGradCAMpp', **plot_cams['SmoothGradCAMpp'],
                           load_image='AD', observed_class='AD')