In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
%matplotlib inline
import SimpleITK as sitk
import sys
sys.path.insert(1, 'Z:/grodriguez/CardiacOCT/code/utils')
from counts_utils import create_image_png
import shap

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# (channels, height, width) of the eight feature maps dumped per fold; these
# numbers are also embedded in each saved file's name.
_FEATURE_SHAPES = [
    (32, 768, 768),
    (64, 384, 384),
    (128, 192, 192),
    (256, 96, 96),
    (480, 48, 48),
    (480, 24, 24),
    (480, 12, 12),
    (480, 6, 6),
]

# Frame names that were selected to plot, listed in the same order as the
# sorted feature-map folder names they belong to.
_DEFAULT_FRAMES = ['NLDAMPH0017_1_frame473_008', 'NLDAMPH0063_1_frame40_016',
                   'NLDISALA0084_1_frame59_041', 'NLDRADB0089_1_frame440_106']

# Indices of the three nnU-Net input channels shown as the RGB OCT frame, per task.
_TASK_CHANNELS = {601: (0, 1, 2), 604: (9, 10, 11)}


def _load_raw_frame(raw_data_path: str, task: int, frame: str, channels) -> np.ndarray:
    """Read the three selected input channels of a test image and stack them on the last axis."""
    images_dir = os.path.join(raw_data_path, 'Task{}_CardiacOCT'.format(task), 'imagesTs')
    planes = []
    for channel in channels:
        image_path = os.path.join(images_dir, '{}_{:04d}.nii.gz'.format(frame, channel))
        # [0] keeps the first slice along the leading axis (presumably a
        # singleton slice axis for this 2D data -- TODO confirm)
        planes.append(sitk.GetArrayFromImage(sitk.ReadImage(image_path))[0])
    return np.stack(planes, axis=-1)


def _average_fold_feature(folder_path: str, shape, n_folds: int) -> torch.Tensor:
    """Average one feature map over the folds, then take the per-pixel max over channels."""
    channels, height, width = shape
    # float64 accumulator: matches the precision of the previous numpy-based sum
    total = torch.zeros(shape, dtype=torch.float64)
    for fold in range(n_folds):
        file_name = 'conv_torch.Size([1, {}, {}, {}])_{}.pt'.format(channels, height, width, fold)
        # map_location lets maps saved during a GPU run load on a CPU-only machine
        fold_map = torch.load(os.path.join(folder_path, file_name), map_location='cpu')[0, :, :, :]
        total += fold_map.detach().double()
    mean_map = total / n_folds
    # Max over the channel axis collapses (C, H, W) to a 2D (H, W) map
    return mean_map.max(dim=0).values


def get_features(task: int,
                 features_path: str = r'Z:\grodriguez\CardiacOCT\feature_maps',
                 raw_data_path: str = r'Z:\grodriguez\CardiacOCT\data-2d\nnUNet_raw_data',
                 frames: list = None,
                 n_folds: int = 5) -> dict:
    """Obtain fold-averaged feature maps for the selected frames.

    Every sub-folder of ``features_path`` whose name ends with ``'<task>_v2'`` is
    processed in sorted name order and paired with the frame at the same
    position in ``frames``.

    Args:
        task (int): task corresponding to the model (601 or 604)
        features_path (str): directory with one sub-folder of dumped feature maps
            and predicted segmentations per pullback
        raw_data_path (str): nnUNet_raw_data directory holding the test images
        frames (list): frame names, one per matching folder in sorted order.
            Defaults to the four frames selected for plotting.
        n_folds (int): number of cross-validation folds to average over

    Raises:
        ValueError: task is not 601 or 604, or there are more matching
            folders than frames to pair them with

    Returns:
        dict: {folder name: [OCT frame, pred seg, feature map 1, ..., feature map 8]}.
        The OCT frame has the three input channels stacked on the last axis.
        Each feature map is a 2D float64 tensor: the fold mean, maxed over channels.
    """
    if task not in _TASK_CHANNELS:
        raise ValueError('task needs to be either 601 or 604, got {}'.format(task))
    channels = _TASK_CHANNELS[task]
    if frames is None:
        frames = _DEFAULT_FRAMES

    suffix = '{}_v2'.format(task)
    # os.listdir order is arbitrary; sort so folders line up with `frames`
    folders = sorted(name for name in os.listdir(features_path) if name.endswith(suffix))
    if len(folders) > len(frames):
        raise ValueError('found {} folders ending in {!r} but only {} frames to pair them with'
                         .format(len(folders), suffix, len(frames)))

    all_imgs = {}
    for folder, frame in zip(folders, frames):
        print(folder)
        folder_path = os.path.join(features_path, folder)

        pred_seg = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(folder_path, '{}.nii.gz'.format(frame))))
        frame_raw = _load_raw_frame(raw_data_path, task, frame, channels)
        feature_maps = [_average_fold_feature(folder_path, shape, n_folds) for shape in _FEATURE_SHAPES]

        # Create dict entry with the OCT frame, pred segmentation and features
        all_imgs[folder] = [frame_raw, pred_seg, *feature_maps]

    return all_imgs

In [4]:
def plot_features(features_dict: dict, output_dir: str = ''):
    """Save the raw-frame overlay and every feature map of each frame as an image.

    For each key two kinds of files are written: ``'<key>_raw'`` (OCT frame with
    the predicted segmentation overlaid) and ``'<key>_1'`` ... ``'<key>_8'`` (one
    per feature map). No extension is given, so matplotlib uses
    ``rcParams['savefig.format']`` (PNG by default).

    Args:
        features_dict (dict): dict returned by get_features, mapping a name to
            [OCT frame, pred seg, feature map 1, ..., feature map 8]
        output_dir (str): directory the images are written to; the default ''
            writes to the current working directory, as before
    """
    for key, (frame_raw, pred_seg, *feature_maps) in features_dict.items():

        fig, ax = plt.subplots()
        ax.imshow(frame_raw.astype(np.uint8), alpha=0.8)
        # [0] takes the first slice of the prediction; create_image_png (counts_utils)
        # presumably maps label values to an RGB overlay -- see that module
        ax.imshow(create_image_png(pred_seg[0].astype(np.uint8)), alpha=0.5)
        ax.set_title('Raw + pred seg')
        ax.axis('off')
        fig.savefig(os.path.join(output_dir, '{}_raw'.format(key)))
        # Close each figure explicitly so looping over many frames doesn't leak memory
        plt.close(fig)

        for layer, feature_map in enumerate(feature_maps, start=1):
            fig, ax = plt.subplots()
            ax.imshow(feature_map)
            ax.set_title('Conv {}'.format(layer))
            ax.axis('off')
            fig.savefig(os.path.join(output_dir, '{}_{}'.format(key, layer)))
            plt.close(fig)

In [5]:
features_dict_601 = get_features(task=601)  # fold-averaged feature maps of model 601

amph_0017_601_v2
amph_0063_601_v2
isala_0084_601_v2
radb_0089_601_v2


In [85]:
plot_features(features_dict=features_dict_601)  # write the model-601 images to disk

In [6]:
features_dict_604 = get_features(task=604)  # fold-averaged feature maps of model 604

amph_0017_604_v2
amph_0063_604_v2
isala_0084_604_v2
radb_0089_604_v2


In [92]:
plot_features(features_dict=features_dict_604)  # write the model-604 images to disk