In [None]:
#imports
import sys
import pandas as pd
import numpy as np
import os
import random
import logging
from livelossplot import PlotLosses
import pickle
import torch
import monai
import time
from monai.data import DataLoader
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    Resized,
    RandSpatialCropd,
    ScaleIntensityd,
    ToTensord,
    LoadImaged,
    Identityd,
)
from sklearn.linear_model import LogisticRegression
from monai.utils import InterpolateMode
import nibabel as nib
import lime.lime_tabular
from skimage.segmentation import slic
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [None]:
#hyperparameters which were selected during hyperparameter tuning
#(they are only used below to rebuild the file name of the saved model weights)
lr = 0.001          #learning rate
opt = "none"        #"opt" tag of the selected run
strategy = "adam"   #"strategy" tag of the selected run
epoch = 95          #epoch tag of the selected run

In [None]:
#definitions of paths

#--- model directory and files stored next to the model ---
MODEL_DIR = "./EfficientNet_pretrained/"
filename_predictions_for_platt_scaling = "./EfficientNet_pretrained/predictions_for_platt_scaling.csv"

#--- input data ---
path_train_data = "../../data/trainValid_DL.csv"
mapping_ML_DL = "../../additional_data/Mapping_DKT_Regions_Deep_ML_new.csv"
aspects_filename = "../../additional_data/aspects05_new.pkl"
freesurfer_mapping_filename = "../../additional_data/freesurferMappingReduced.csv"
mean_values_for_brain_regions_file = "../../additional_data/meanRegionValueCNAD_new.npy"

#--- LIME outputs ---
LIME_image_directory = "./EfficientNet_pretrained/LIME/"
LIME_save_individual_results_path = "./EfficientNet_pretrained/LIME/LIME_individual.csv"
LIME_save_global_results_path = "./EfficientNet_pretrained/LIME/LIME_global.csv"

In [None]:
#create the LIME output directory (including missing parents) if it does not exist yet;
#exist_ok=True makes the cell safe to re-run and avoids the check-then-create race
os.makedirs(LIME_image_directory, exist_ok=True)

In [None]:
#one subject per batch: the explanation loop below reads element [0] of every batch
#(PTID[0], label[0], prediction[0,1]) and explains exactly one scan per iteration
BATCH_SIZE = 1

In [None]:
#load training dataset; the PTID column becomes the row index
trainValidMerged = pd.read_csv(
    path_train_data,
    index_col="PTID",
)

In [None]:
#preprocessing pipeline applied to every subject (no random augmentation is used here)
#both volumes are resized to RESIZE_SHAPE and then centre-cropped to CROP_SHAPE
RESIZE_SHAPE = (256,256,256)
CROP_SHAPE = (224,224,224)
validation_transforms = Compose([
    #read image and segmentation from the file names stored in the data dictionary
    LoadImaged(keys=["img","segmentation"]),
    AddChanneld(keys=["img","segmentation"]),
    #intensity scaling is applied to the image only, the segmentation holds region IDs
    ScaleIntensityd(keys=["img"]),
    Resized(keys=["img"], spatial_size=RESIZE_SHAPE),
    #the segmentation is resized in nearest-neighbour mode
    Resized(keys=["segmentation"], spatial_size=RESIZE_SHAPE, mode=InterpolateMode.NEAREST),
    CenterSpatialCropd(keys=["img","segmentation"], roi_size=CROP_SHAPE),
    ToTensord(keys=["img","segmentation"]),
])


In [None]:
#define function to set seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices) with ``seed``, and sets the cuDNN flags
    ``deterministic=True`` / ``benchmark=False``.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    #random number generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    #cuDNN flags
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
#train Logistic Regression model for Platt's scaling
#the CSV provides one column "predictions" (model outputs) and one column "labels"
pred = pd.read_csv(filename_predictions_for_platt_scaling)
#scikit-learn expects a 2D feature matrix, so the prediction vector becomes a single column
predictions = pred["predictions"].to_numpy()[:, np.newaxis]
clf = LogisticRegression(random_state=0)
clf.fit(predictions, pred["labels"])

In [None]:
#reformat training dataset to pytorch
#labels: dummy-encode the DX column, drop the first category and convert to a plain list
dx_dummies = pd.get_dummies(trainValidMerged.DX, drop_first=True)
Y_train = dx_dummies.to_numpy().squeeze().tolist()
#one dictionary per subject with the keys expected by the transforms and by the LIME loop
trainDSNew = [
    {"PTID": subject_id, "img": image_file, "label": subject_label, "segmentation": segmentation_file}
    for subject_id, image_file, subject_label, segmentation_file in zip(
        trainValidMerged.index,
        trainValidMerged.filename,
        Y_train,
        trainValidMerged.filenameSeg,
    )
]
#seed before the dataset/loader are created
set_seed(123)
train_ds = monai.data.Dataset(data=trainDSNew, transform=validation_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
#choose cuda as the device if it is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#load DL model using monai (3D EfficientNet-B0, one input channel, two output classes)
model = monai.networks.nets.EfficientNetBN("efficientnet-b0",spatial_dims=3, in_channels=1, num_classes=2)
#load final model weights; the file name is rebuilt from the selected hyperparameters
PATH=MODEL_DIR+"model_"+str(opt)+"_"+str(lr)+"_"+str(strategy)+"_"+str(epoch)+"_final_model_polyak_averaged.pth"
#map_location remaps the stored tensors onto the selected device, so a checkpoint that was
#saved from a GPU can also be loaded on a machine without CUDA
model.load_state_dict(torch.load(PATH, map_location=device))
model=model.to(device)

In [None]:
#load mean intensities for brain regions
#(indexed below in the same order as the FreeSurfer IDs in colorsF)
meanReg = np.load(mean_values_for_brain_regions_file)

In [None]:
#define function how to mask the image during LIME computation (masking with mean intensity image for aspects)
def mask_image(zs, segmentation, image,colors):
    """Return a copy of ``image`` in which all switched-off segments are replaced.

    Args:
        zs: binary LIME sample; entry ``k`` equal to 0 switches off the segment
            whose label is ``colors[k]``.
        segmentation: tensor of segment labels, indexed like ``image``.
        image: input tensor; it is cloned and never modified in place.
        colors: tensor with the segment labels, one per entry of ``zs``.

    Returns:
        Clone of ``image`` whose voxels inside switched-off segments carry the
        values of the global ``meanImg`` at the same positions.

    Note:
        ``meanImg`` is read from the notebook's global scope; it must be set
        for the current subject before this function is called.
    """
    switched_off = np.where(zs==0)[0]
    labels_to_mask = colors[switched_off]
    replace_here = torch.isin(segmentation, labels_to_mask)
    masked = image.clone()
    masked[replace_here] = meanImg[replace_here]
    return masked
#define function to iterate over all LIME entries which should be considered, apply masking and calculate predictions
def funcLime(z):
    """Prediction function handed to LIME.

    For every binary sample in ``z`` the current subject's image is masked with
    ``mask_image``, passed through the model, and the softmax output of class 1
    is calibrated with the Platt-scaling model ``clf``.

    Reads the globals ``model``, ``clf``, ``img`` and ``segmentation``; they
    must describe the subject that is currently being explained.

    Args:
        z: iterable of binary vectors, one entry per segment of ``segmentation``.

    Returns:
        NumPy array with one row per sample, containing the output of
        ``clf.predict_proba`` for that sample.
    """
    #the segment labels do not change between samples, so they are computed only once
    segment_labels=torch.unique(segmentation)
    preds=[]
    #iterate over LIME entries
    for znew in z:
        set_seed(123)
        #calculate prediction of masked image; no gradients are needed for inference, so
        #no autograd graph is built for the forward passes
        with torch.no_grad():
            pred=model(mask_image(znew,segmentation,img,segment_labels))
            pred=torch.nn.functional.softmax(pred,dim=1)
        pred=pred.cpu().detach().numpy()[:,1]
        pred=np.expand_dims(pred, axis=1)
        #calculate calibrated results
        pred=clf.predict_proba(pred)
        preds.append(pred.tolist())
    #every list entry holds a single row (one image per forward pass); select that row
    preds=np.array(preds)
    preds=preds[:,0]
    return preds

In [None]:
#load aspects (dictionary: aspect name -> list of ML feature names, see the mapping cell below)
#NOTE(review): pickle.load executes arbitrary code contained in the file - only load trusted files
with open(aspects_filename, 'rb') as aspects_file:
    aspects = pickle.load(aspects_file)

#load ML/DL feature mapping
mapping = pd.read_csv(mapping_ML_DL)

In [None]:
#load freesurfer segmentation mapping (comma separated table)
freesurferMapping = pd.read_table(
    freesurfer_mapping_filename,
    sep=",",
)

In [None]:
#identify all FreeSurfer segmented regions (values of the "ID" column as numpy array)
colorsF = freesurferMapping["ID"].to_numpy()

In [None]:
#mapping between brain regions and freesurfer segmentation regions
#(outer join: rows without a partner on the other side are kept)
mergedDF = mapping.merge(
    freesurferMapping,
    how="outer",
    left_on="feature_Deep",
    right_on="brain region",
)
#map aspects to freesurfer segmentations: replace the ML feature names of every aspect
#by the FreeSurfer IDs of the rows whose feature_ML belongs to that aspect
for key in aspects:
    belongs_to_aspect = mergedDF["feature_ML"].isin(aspects[key])
    aspects[key] = mergedDF.loc[belongs_to_aspect, "ID"].tolist()

#identify brain regions not available in ML models
has_ml_feature = freesurferMapping["brain region"].isin(mapping["feature_Deep"])
notInML = freesurferMapping[~has_ml_feature]

In [None]:
#add brain regions without ML volumes to aspects
#every such region becomes its own aspect holding exactly one FreeSurfer ID
aspects.update({
    row["brain region"]: [row["ID"]]
    for _, row in notInML.iterrows()
})

In [None]:
#initialize empty dataframe collecting the per-subject results:
#subject ID, label, model prediction and one mean LIME score per aspect
column_names = ["PTID", "label", "pred", *aspects.keys()]
df = pd.DataFrame(columns=column_names)

In [None]:
#LIME explanation loop: every iteration explains one subject (one scan per batch)
#NOTE: img, segmentation and meanImg are deliberately kept as global names, because
#funcLime / mask_image read them from the global scope while LIME is running
index_i=0
#column layout of one result row (identical for every subject)
column_names = ["PTID","label","pred"]+list(aspects.keys())
#change model to evaluation mode
model.eval()
#iterate over training dataset
for data in train_loader:
    #store starting time of subject LIME calculation
    start = time.time()
    index_i+=1
    print(index_i)
    #load input scans, segmentation, PTID and label
    img=data["img"]
    PTID=data["PTID"]
    segmentationImg=data["segmentation"]
    label=data["label"].cpu().detach().numpy()[0]
    #convert input to numpy format to apply slic for segmenting similar sized structures
    imgSeg=np.squeeze(img.numpy())
    #apply slic algorithm for segmentation
    segmentation = slic(imgSeg, n_segments=100, compactness=1,channel_axis=None,start_label=0)
    #add two leading axes for consistency with the shape of the original pytorch MRI scans
    segmentation=segmentation[np.newaxis,np.newaxis]
    #move the scan to the device the model lives on (works with and without CUDA)
    img=img.to(device)
    #calculate model prediction (inference only, so no autograd graph is needed)
    with torch.no_grad():
        pred=torch.nn.functional.softmax(model(img),dim=1)
    pred=pred.cpu().detach().numpy()[0,1]
    #compute image with mean value for all brain regions for LIME masking:
    #every voxel of FreeSurfer region colorsF[k] receives meanReg[k]
    meanImg=np.zeros_like(segmentationImg)
    for region_index, color in enumerate(colorsF):
        meanImg[np.isin(segmentationImg, color)]=meanReg[region_index]
    #convert mean image to pytorch
    meanImg=torch.from_numpy(meanImg).type(torch.float32).to(device)
    #convert slic segmentation to pytorch
    segmentation=torch.from_numpy(segmentation).type(torch.float32).to(device)
    #labels of the slic segments, computed once; LIME feature k corresponds to slic_labels[k]
    #(funcLime hands torch.unique(segmentation) to mask_image in the same order)
    slic_labels=torch.unique(segmentation)
    n_segments=slic_labels.shape[0]
    #generate training dataset, with zeros and ones for all slic segmentations to initialize range of values
    train=np.array([np.zeros(n_segments),np.ones(n_segments)])
    set_seed(123)
    #calculate LIME values for slic segmentation
    explainer = lime.lime_tabular.LimeTabularExplainer(train, categorical_features=range(0,n_segments),random_state=423)
    exp = explainer.explain_instance(np.ones(n_segments), funcLime,num_samples=1000,num_features=n_segments)
    #sort explanations by feature index and keep only the LIME weight of every slic segment
    sorted_by_feature = sorted(exp.as_map()[1], key=lambda tup: tup[0])
    lime_weights = [weight for _, weight in sorted_by_feature]
    #compute image containing LIME values for each voxel in dependence of the slic segmentations
    out=np.zeros(img.shape)
    for feature_index, weight in enumerate(lime_weights):
        #select the segment by its label instead of assuming that label == feature index
        region=(segmentation==slic_labels[feature_index])
        out[region.cpu().detach().numpy()]=weight
    #save image containing LIME values as nifti file
    result_test=np.squeeze(out)
    result_image = nib.Nifti1Image(result_test, affine=np.eye(4))
    nib.save(result_image,LIME_image_directory+"/LIME_PTID_"+PTID[0]+".nii.gz")
    #flatten FreeSurfer segmentation
    seg_flattened=segmentationImg.flatten()
    #generate new segmentation based on aspects: aspect number k (1-based, dictionary order)
    #is written to all voxels of its FreeSurfer IDs; 0 marks voxels without any aspect
    seg_Flattened_new=np.zeros(seg_flattened.shape)
    for aspect_label, aspect in enumerate(aspects, start=1):
        for value in aspects[aspect]:
            seg_Flattened_new[seg_flattened==value]=aspect_label
    #flatten LIME image
    out_Flattened=out.flatten()
    #calculate mean LIME score per aspect; the voxel count is taken per aspect label, so
    #unassigned voxels (label 0) cannot shift the counts against the sums; aspects without
    #voxels keep the value 0
    meanValues=np.zeros(len(aspects))
    for aspect_label in range(1,len(aspects)+1):
        in_aspect=(seg_Flattened_new==aspect_label)
        voxel_count=in_aspect.sum()
        if voxel_count>0:
            meanValues[aspect_label-1]=out_Flattened[in_aspect].sum()/voxel_count
    #append the result row of this subject; appending per iteration keeps the finished
    #subjects in df if the long-running loop is interrupted
    values=[PTID[0],label,pred]+meanValues.tolist()
    subject_row = pd.DataFrame([values],columns = column_names)
    df=pd.concat([df,subject_row],ignore_index=True)
    end = time.time()
    #print the duration of the LIME calculation for this subject in seconds
    print(format(end-start))

In [None]:
#save results at subject level (one row per subject: PTID, label, prediction, mean LIME score per aspect)
df.to_csv(path_or_buf=LIME_save_individual_results_path)

In [None]:
#save global results
#remove the subject identifier in a new frame: df itself stays untouched, so this cell can
#be re-run and the subject-level results remain available
df_scores=df.drop(columns=["PTID"])
#global importance = sum of absolute values per column over all subjects
#(the remaining "label" and "pred" columns are aggregated in the same way as the aspects)
abs_sum=df_scores.abs().sum()
d = {'feature': abs_sum.index.tolist(), 'LIMEImportance': abs_sum.tolist()}
resultsSum=pd.DataFrame(data=d)
resultsSum.to_csv(LIME_save_global_results_path)