In [2]:
import matplotlib.pyplot as plt
from ipywidgets import interact
import numpy as np
import os
import nibabel as nib
import pandas as pd
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.patheffects as path_effects

In [2]:
from ipynb.fs.full.Feature_extraction_classes_functions import rescale, reshape_arr_and_axis
from ipynb.fs.full.project_helper_functions_classes import take_folder_list, take_special_file_list, extract_case_id
from ipynb.fs.full.project_helper_functions_classes import create_folder

In [3]:
def explore_3D_array(arr: np.ndarray, slice_ind = 2, cmap: str = 'gray',
                         fig_size = (4,4), axis = 'off', show_colorbar = False):
    """
    Create an interactive slider widget to browse the 2D slices of a volume.

    Each slider position SLICE draws ``arr`` indexed at SLICE along axis
    ``slice_ind``: 0 -> ``arr[SLICE, :, :]``, 1 -> ``arr[:, SLICE, :]``,
    2 -> ``arr[:, :, SLICE]``. The purpose is visual inspection of the slices.

    Args:
      arr : volume to explore (e.g. a 3D MRI image).
      slice_ind : axis along which slices are taken; the slider ranges over
        ``0 .. arr.shape[slice_ind] - 1``.
      cmap : matplotlib colormap used to draw the slices.
      fig_size : figure size of each drawn slice.
      axis : value passed to ``plt.axis`` (``'off'`` hides the axes).
      show_colorbar : if True, add a colorbar next to the slice.

    Raises:
      ValueError: if ``slice_ind`` is not a valid axis of ``arr`` (previously an
        invalid axis silently drew nothing).
    """
    if not -arr.ndim <= slice_ind < arr.ndim:
        raise ValueError(f'slice_ind must be a valid axis of arr (ndim={arr.ndim}), got {slice_ind}')

    def fn(SLICE):
        # The parameter name SLICE is also the widget label shown by `interact`.
        plt.figure(figsize=fig_size)
        plt.imshow(np.take(arr, SLICE, axis=slice_ind), cmap=cmap)
        plt.title('Explore Layers of Brain MRI', fontsize=10)
        plt.axis(axis)
        if show_colorbar:
            plt.colorbar()

    interact(fn, SLICE=(0, arr.shape[slice_ind] - 1))

In [4]:
def explore_3D_array_with_mask(arr: np.ndarray, mask: np.ndarray, slice_ind = 2, cmap: str = 'gray', 
                                   mask_cmap = 'viridis', alpha = 0.5, fig_size = (4,4),
                                  axis = 'off', show_colorbar = False):
    """
    Create an interactive slider widget to browse a volume with a mask overlaid.

    Each slider position SLICE draws ``arr`` and, on top of it, ``mask``, both
    indexed at SLICE along axis ``slice_ind``. The purpose is visual inspection
    of how the mask lines up with the image.

    Args:
      arr : volume to explore (e.g. a 3D MRI image).
      mask : array with the same shape as ``arr``, drawn semi-transparently on top.
      slice_ind : axis along which slices are taken; the slider ranges over
        ``0 .. arr.shape[slice_ind] - 1``.
      cmap : matplotlib colormap for the image.
      mask_cmap : matplotlib colormap for the mask overlay.
      alpha : opacity of the mask overlay.
      fig_size : figure size of each drawn slice.
      axis : value passed to ``plt.axis`` (``'off'`` hides the axes).
      show_colorbar : if True, add a colorbar (it describes the last image drawn, i.e. the mask).

    Raises:
      ValueError: if ``slice_ind`` is not a valid axis of ``arr`` or if
        ``mask`` and ``arr`` shapes differ (a mismatched mask would be drawn misaligned).
    """
    if not -arr.ndim <= slice_ind < arr.ndim:
        raise ValueError(f'slice_ind must be a valid axis of arr (ndim={arr.ndim}), got {slice_ind}')
    if mask.shape != arr.shape:
        raise ValueError(f'mask shape {mask.shape} must match arr shape {arr.shape}')

    def fn(SLICE):
        # The parameter name SLICE is also the widget label shown by `interact`.
        plt.figure(figsize=fig_size)
        plt.imshow(np.take(arr, SLICE, axis=slice_ind), cmap=cmap)
        plt.imshow(np.take(mask, SLICE, axis=slice_ind), cmap=mask_cmap, alpha=alpha)
        # Same title size for every axis (the axis-1 branch previously used 20).
        plt.title('Explore Layers of Brain MRI', fontsize=10)
        plt.axis(axis)
        if show_colorbar:
            plt.colorbar()

    interact(fn, SLICE=(0, arr.shape[slice_ind] - 1))


In [12]:
def figure_original_and_stripped_slice(slice1, slice2,
                                       title1 = 'original', title2 = 'stripped', 
                    figure_size = (10, 10), show_axis = 'off',
                    save_figure = False, figure_path = ''):
    """
    Draw two 2D slices side by side in grayscale, each with its own title.

    Args:
      slice1, slice2 : 2D arrays shown in the left and right panel.
      title1, title2 : titles of the left and right panel.
      figure_size : matplotlib figure size.
      show_axis : value passed to ``Axes.axis`` for both panels (``'off'`` hides axes).
      save_figure : if True, save the figure to ``figure_path`` before showing it.
      figure_path : destination file used when ``save_figure`` is True.
    """
    fig, panels = plt.subplots(1, 2, figsize=figure_size)

    for panel, image, title in zip(panels, (slice1, slice2), (title1, title2)):
        panel.imshow(image, interpolation='none', cmap='gray')
        panel.set_title(title)

    for panel in panels:
        panel.axis(show_axis)

    if save_figure:
        plt.savefig(figure_path)

    plt.show()
        

In [16]:
def figure_rescaled_data(slice1, slice2, original_array, 
                        title1 = 'original', title2 = 'rescaled', 
                        figure_size = (10, 10), show_axis = 'off', 
                         show_colorbar = False, set_brightness = False,
                        save_figure = False, figure_path = ''):
    """
    Compare an original slice with its rescaled version, side by side in grayscale.

    Display limits:
      * ``set_brightness=True``: each panel uses its own slice's min/max.
      * otherwise: the left panel uses the min/max of ``original_array`` and the
        right panel uses the fixed range [0, 1].

    Args:
      slice1, slice2 : 2D arrays for the left (original) and right (rescaled) panel.
      original_array : array whose min/max set the left panel's limits when
        ``set_brightness`` is False.
      title1, title2 : panel titles.
      figure_size : matplotlib figure size.
      show_axis : value passed to ``Axes.axis`` for both panels.
      show_colorbar : if True, add a horizontal colorbar under each panel.
      set_brightness : choose per-slice limits instead of the fixed ones above.
      save_figure : if True, save the figure to ``figure_path`` before showing it.
      figure_path : destination file used when ``save_figure`` is True.
    """
    if set_brightness:
        limits = [(np.min(slice1), np.max(slice1)), (np.min(slice2), np.max(slice2))]
    else:
        limits = [(np.min(original_array), np.max(original_array)), (0, 1)]

    fig, panels = plt.subplots(1, 2, figsize=figure_size)

    drawn = []
    for panel, image, title, (low, high) in zip(panels, (slice1, slice2), (title1, title2), limits):
        drawn.append(panel.imshow(image, interpolation='none', cmap='gray', vmin=low, vmax=high))
        panel.set_title(title)

    for panel in panels:
        panel.axis(show_axis)

    if show_colorbar:
        for mappable, panel in zip(drawn, panels):
            plt.colorbar(mappable, ax=panel, orientation='horizontal', pad=0.01, aspect=35)

    if save_figure:
        plt.savefig(figure_path)

    plt.show()


In [None]:
def figure_MRC_AUC_(csv_path, MRI_column_name, MRI_type, shape, save_figure = False, figure_path = ''):
    """
    Plot an 8 x 3 grid of AUC scatter plots read from a statistics CSV.

    Only rows where ``MRI_column_name == MRI_type`` and ``shape == shape`` are used.
    Column ``i`` of the grid holds ``primary_rate == i`` (1..3, labelled "PR<i>").
    Row ``j`` holds ``secondary_rate == j`` (1..8; rows 2..8 are labelled "AO<j-1>").
    Each panel scatters ``step`` (x) against ``power`` (y), coloured by ``AUC_value``.
    Colours are binned with the boundaries in ``bounds`` (extended on both ends).

    Args:
      csv_path : CSV with at least the columns primary_rate, secondary_rate,
        AUC_value, step, power, shape and ``MRI_column_name``.
      MRI_column_name : name of the column holding the MRI type.
      MRI_type : value of ``MRI_column_name`` to keep.
      shape : ROI shape value (column ``shape``) to keep.
      save_figure : if True, save the figure to ``figure_path`` before showing it.
      figure_path : destination file used when ``save_figure`` is True.
    """
    n_primary, n_secondary = 3, 8

    ## get stats, prepare data
    stat_df = pd.read_csv(csv_path)
    selected = stat_df[(stat_df[MRI_column_name] == MRI_type) & (stat_df['shape'] == shape)]

    ## create figure
    f, axes = plt.subplots(nrows = n_secondary, ncols = n_primary, sharex=True, sharey = True)

    cmap = mpl.colors.ListedColormap([ 'limegreen','greenyellow' , 'yellowgreen',
                                   'yellow','orange','pink', 'red' ,'lightskyblue', 'dodgerblue', 'navy' ])
    # 9 edges -> 8 inner bins + 2 extension bins = 10 colours, one per cmap entry.
    bounds = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90]
    norm = mpl.colors.BoundaryNorm(bounds,  cmap.N, extend='both')

    im = None
    for i in range(1, n_primary + 1):
        axes[0, i-1].set_xlabel("PR" + str(i))
        axes[0, i-1].xaxis.set_label_coords(0.5, 1.18)
        for j in range(1, n_secondary + 1):
            cell = selected[(selected['primary_rate'] == i) & (selected['secondary_rate'] == j)]
            # Pass the BoundaryNorm so the colorbar shows the intended bin edges
            # (it was previously built but never used).
            im = axes[j-1, i-1].scatter(x=cell['step'], y=cell['power'], c=cell['AUC_value'],
                                        cmap=cmap, norm=norm, marker=".", alpha=1, s = 100)

    # Row labels on the right-hand column; the first row is intentionally unlabelled.
    for j in range(2, n_secondary + 1):
        axes[j-1, -1].set_ylabel("AO" + str(j-1), rotation = 270)
        axes[j-1, -1].yaxis.set_label_coords(1.05, 0.5)

    f.set_figwidth(5)
    f.set_figheight(8)
    f.supxlabel(' n \n (Number of pixels from center pixel to the ROI edge)',
                x = 0.5, y = 0.035, va = 'bottom', multialignment = 'center')
    f.supylabel('p\n (Power of PR’s)', x = 0, y = 0.5, va = 'center', multialignment = 'center')

    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    cbar_ax = f.add_axes([0.13, 0, 0.75, 0.02])
    f.colorbar(im, cax=cbar_ax, location='bottom', extend = 'both', label='AUC')

    if save_figure:
        f.savefig(figure_path)

    plt.show()


In [2]:
class epileptic_focus_visualization():
    """
    Visualize the coordinates that one saved feature flagged as epileptic foci.

    The constructor:
      * picks row ``feature_ind`` of the feature statistics CSV;
      * loads the matching pickled feature values;
      * keeps the entries whose ``value`` exceeds that row's ``threshold``;
      * groups the stripped-image files by situation.

    ``show_figure`` then draws two figures for each flagged coordinate: the full
    slice with ROI overlays, and a zoom on the ROI with pixel values written on it.
    """

    def __init__(self,stripped_data_folder = 'data/stripped_data',
                original_data_folder = 'data/ordered_data',
                feature_output_folder = 'old_output/features',
                features_csv_path = 'old_output/EPE_stat_results/BA_results.csv',
                feature_ind = 0, 
                sequences = {'t1': ['t1_tra', 't1_sag'], 't2' : ['t2_cor', 't2_tra']}):
        """
        Args:
          stripped_data_folder : folder with one sub-folder per situation (stripped images).
          original_data_folder : folder with the original (non-stripped) data.
          feature_output_folder : folder holding the pickled feature values.
          features_csv_path : CSV whose first 6 columns identify a feature's pickle
            file and which also has the columns threshold, step and shape.
          feature_ind : row of ``features_csv_path`` to visualize.
          sequences : maps a collective data type to the MRI sequence names it
            combines; used when the feature's data type is not an MRI-type folder.
            It is only read, never mutated.
        """
        ## prepare dataframe
        features_df = pd.read_csv(features_csv_path)
        feature_vals = features_df.iloc[feature_ind, :6].values
        feature_pkl_path = take_pkl_file_path(feature_vals, feature_output_folder)
        # NOTE: pickle can execute code; only load trusted feature files.
        feature_df = pd.read_pickle(feature_pkl_path)
        threshold = features_df['threshold'].iloc[feature_ind]
        self.step = features_df['step'].iloc[feature_ind]
        self.ROI_shape = features_df['shape'].iloc[feature_ind]
        self.foci_df = feature_df[feature_df['value'] > threshold]

        # Defaults so display_epileptic_foci works even before show_figure is called.
        self.save_figure = False
        self.folder_path = ''

        ## get files paths
        self.stripped_path = os.path.abspath(stripped_data_folder)
        self.original_path = os.path.abspath(original_data_folder)
        self.feature_ind = feature_ind
        situations = take_folder_list(self.stripped_path)
        file_list, _ = take_special_file_list(situations, self.stripped_path, self.original_path)
        MRI_types = take_folder_list(os.path.join(self.stripped_path, situations[0]))
        MRI_type = feature_vals[0]
        # A data type that is not an MRI-type folder is a combination of sequences.
        self.collective_analysis = MRI_type not in MRI_types
        self.file_list_dict = {}
        if not self.collective_analysis:
            for situation in situations:
                self.file_list_dict[situation] = [fl for fl in file_list
                                                  if MRI_type in fl[0] and situation in fl[0]]
        else:
            sq_MRI = sequences[MRI_type]
            for situation in situations:
                self.file_list_dict[situation] = [fl for fl in file_list for MRI in sq_MRI
                                                  if MRI in fl[0] and situation in fl[0]]

    def show_figure(self, save_figure = False,  folder_path = ''):
        """
        Draw every flagged coordinate of every file, optionally saving the figures.

        Args:
          save_figure : if True, figures are written to
            ``<folder_path>/feature_<feature_ind + 1>`` (files that already exist are kept).
          folder_path : parent folder for saved figures.
        """
        self.save_figure = save_figure
        self.folder_path = folder_path

        if self.save_figure:
            self.folder_path = os.path.join(self.folder_path, 'feature_' + str(self.feature_ind + 1))
            create_folder(self.folder_path)

        for situation, situation_files in self.file_list_dict.items():
            self.situation = situation
            for files in situation_files:
                self.case_id = extract_case_id(files[0], collective_analysis = self.collective_analysis)
                case_mask = (self.foci_df['situation'] == self.situation) & (self.foci_df['case'] == self.case_id)
                ef_coordinates = self.foci_df['coordinates'][case_mask].values

                print(f'{len(ef_coordinates)} coordinates were predicted as epileptic foci at {files} in {self.situation}')

                if len(ef_coordinates) == 0:
                    continue

                # Load and reorient the image once per file, not once per coordinate.
                stripped_img = nib.load(files[0])
                axes_id = nib.aff2axcodes(stripped_img.affine)
                # reshape data for matching with coordinates obtained within the process
                self.img_reshaped, self.axes_labels = reshape_arr_and_axis(stripped_img.get_fdata(), axes_id)

                for coordinate in ef_coordinates:
                    self.coordinate = np.array(coordinate)
                    self.display_epileptic_foci()

    def take_patches(self):
        """
        Build the overlay patches for the current coordinate on its (row, column) slice.

        Returns [ROI outline (blue), ROI fill (translucent yellow), context region
        three times the ROI half-size (red)]. Matplotlib patches take (x, y) =
        (column, row), hence the swapped order below.

        Raises:
          ValueError: if ``self.ROI_shape`` is neither 'square' nor 'circle'.
        """
        ind1, ind2 = int(self.coordinate[1]), int(self.coordinate[2])
        step = self.step

        if self.ROI_shape == 'square':
            # Pixel edges sit at integer - 0.5, so the outline encloses whole pixels.
            roi_len = (step * 2) + 1
            xb, yb = int(ind1 - step) - 0.5, int(ind2 - step) - 0.5
            outline = patches.Rectangle((yb, xb), roi_len, roi_len, linewidth=1,
                                        edgecolor='blue', facecolor="none")
            fill_len = step * 2
            xb, yb = int(ind1 - step), int(ind2 - step)
            fill = patches.Rectangle((yb, xb), fill_len, fill_len, linewidth=0,
                                     facecolor="yellow", alpha = 0.3)
            context_len = (step * 6) + 1
            dif_ = step * 3
            xb2, yb2 = ind1 - dif_ - 0.5, ind2 - dif_ - 0.5
            context = patches.Rectangle((yb2, xb2), context_len, context_len, linewidth=3,
                                        edgecolor='r', facecolor="none")
            return [outline, fill, context]

        if self.ROI_shape == 'circle':
            center = (ind2, ind1)
            outline = patches.Circle(center, step, linewidth=1,
                                     edgecolor='blue', facecolor="none")
            fill = patches.Circle(center, step, linewidth=0,
                                  facecolor="yellow", alpha = 0.3)
            context = patches.Circle(center, step * 3, linewidth=3,
                                     edgecolor='r', facecolor="none")
            return [outline, fill, context]

        raise ValueError(f"Unsupported ROI shape: {self.ROI_shape!r} (expected 'square' or 'circle')")

    def _save_if_new(self, fig, coords, suffix):
        """Save ``fig`` into ``self.folder_path`` when saving is on and the target file does not exist yet."""
        if not self.save_figure:
            return
        ind0, ind1, ind2 = coords
        file_name = self.situation + self.case_id
        file_name += '_crd_' + str(ind0) + '_' + str(ind1) + '_' + str(ind2) + suffix
        file_path = os.path.join(self.folder_path, file_name)
        # Previously `if file_not_exist:` tested the function object (always True),
        # so existing files were overwritten.
        if not os.path.exists(file_path):
            fig.savefig(file_path)

    def display_epileptic_foci(self):
        """
        Draw the current coordinate (``self.coordinate`` = slice, row, column).

        Figure 1 shows the whole slice with orientation letters and ROI patches.
        Figure 2 zooms on the ROI, prints each pixel's integer value, and labels
        the ticks with the original row/column indices.
        """
        ind0, ind1, ind2 = int(self.coordinate[0]), int(self.coordinate[1]), int(self.coordinate[2])
        coords = (ind0, ind1, ind2)
        step = self.step
        img_slice = self.img_reshaped[ind0, :, :]

        ## full slice
        fig, ax = plt.subplots(1)
        ax.imshow(img_slice, interpolation='none', cmap='gray')
        for xdist, ydist, axis_label in find_coord_labels(self.axes_labels, img_slice.shape):
            ax.text(xdist, ydist, axis_label, color='white', fontsize=10)
        for patch in self.take_patches():
            ax.add_patch(patch)
        self._save_if_new(fig, coords, '.png')

        ## zoom on the ROI, clipped to the image so border foci do not wrap around
        n_rows, n_cols = img_slice.shape
        row_start, row_stop = max(int(ind1 - step), 0), min(int(ind1 + step + 1), n_rows)
        col_start, col_stop = max(int(ind2 - step), 0), min(int(ind2 + step + 1), n_cols)
        focused_image = img_slice[row_start:row_stop, col_start:col_stop]
        # Centre pixel position inside the zoomed image (equals (step, step) away from borders).
        center = (ind2 - col_start, ind1 - row_start)

        fig, ax = plt.subplots(1)
        for row in range(focused_image.shape[0]):
            for col in range(focused_image.shape[1]):
                # Text takes (x, y) = (column, row); the old code had them swapped.
                ax.text(col - 0.25, row + 0.05, str(int(focused_image[row, col])), color='red', fontsize=8)

        ### add patch for showing center pixel
        ax.add_patch(patches.Circle(center, 0.3, facecolor="yellow", alpha = 0.3))
        if self.ROI_shape == 'circle':
            ax.add_patch(patches.Circle(center, step, linewidth=1,
                                        edgecolor='blue', facecolor="none"))

        # Same grey scale as the full slice so intensities are comparable.
        ax.imshow(focused_image, interpolation='none', cmap='gray',
                  vmin = np.min(img_slice), vmax = np.max(img_slice))

        ax.set(yticks=np.arange(focused_image.shape[0]), yticklabels=np.arange(row_start, row_stop))
        ax.set(xticks=np.arange(focused_image.shape[1]), xticklabels=np.arange(col_start, col_stop))

        self._save_if_new(fig, coords, 'focus.png')

        plt.show()

In [None]:
def find_coord_labels(axes, img_shape):
    """
    Compute where to write orientation letters around a 2D slice.

    The slice is assumed to be axes 1 (rows) and 2 (columns) of a volume whose
    axis codes are ``axes``. These are presumably nibabel ``aff2axcodes``
    letters (L/R, A/P, S/I).

    Args:
      axes : sequence of at least 3 axis codes; ``axes[1]`` labels rows, ``axes[2]`` labels columns.
      img_shape : (n_rows, n_cols, ...) of the slice.

    Returns:
      Four ``[x, y, label]`` entries in this order:
        * ``axes[1]`` near the bottom edge, horizontally centred;
        * its opposite letter near the top edge;
        * ``axes[2]`` near the right edge, vertically centred;
        * its opposite letter near the left edge.

    Raises:
      KeyError: if a code is not one of L, R, A, P, S, I.
    """
    # Both directions are mapped; previously R/P/I codes raised KeyError.
    opposite = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'S': 'I', 'I': 'S'}
    row_label, col_label = axes[1], axes[2]
    n_rows, n_cols = img_shape[0], img_shape[1]
    mid_row = int(n_rows / 2)
    mid_col = int(n_cols / 2)
    margin = 5  # distance (in pixels) of the letters from the image edge
    return [
        [mid_col, n_rows - margin, row_label],
        [mid_col, margin, opposite[row_label]],
        [n_cols - margin, mid_row, col_label],
        [margin, mid_row, opposite[col_label]],
    ]

In [None]:
def file_not_exist(file_path):
    """Return True when nothing (file or directory) exists at ``file_path``, otherwise False."""
    return not os.path.exists(file_path)

In [8]:
def take_pkl_file_path(feature_vals, feature_output_folder):
    """
    Build the path of the pickled feature file for one parameter combination.

    The file name is each value of ``feature_vals`` converted with ``str`` and
    followed by ``-``, then ``.pkl`` (e.g. ``['t1', 2]`` -> ``t1-2-.pkl``).

    Args:
      feature_vals : iterable of parameter values identifying the feature.
      feature_output_folder : folder that holds the pickle files.

    Returns:
      ``feature_output_folder`` joined with that file name.
    """
    stem = "".join(str(value) + "-" for value in feature_vals)
    return os.path.join(feature_output_folder, stem + ".pkl")

In [None]:
def take_case_matching_list(case_info_excel_path, control_id, patient_id):
    """
    Build a table mapping every case to its case name in each MRI type.

    Each sheet of the workbook is assigned to the control or the patient group
    when its name contains ``control_id`` / ``patient_id``. The text after
    ``'<id>_'`` in the sheet name is taken as the MRI type. Each sheet must have
    the columns ``input_name`` and ``case_name``.

    Cases are the union of ``input_name`` values per group, sorted
    case-insensitively. They are numbered ``control1..`` / ``patient1..``.

    Args:
      case_info_excel_path : path of the Excel workbook.
      control_id : substring identifying control sheets.
      patient_id : substring identifying patient sheets.

    Returns:
      DataFrame with columns ``case_id``, ``situation`` (``control_id`` or
      ``patient_id``) and one column per MRI type present in both groups, holding
      ``case_name`` for that case (NaN when missing).
    """
    case_info = pd.read_excel(case_info_excel_path, sheet_name = None)

    def _group(group_id):
        """Return (MRI types, sorted unique input names) for the sheets of one group."""
        sheets = [sheet for sheet in case_info if group_id in sheet]
        mri_types = [sheet.split(group_id + '_')[-1] for sheet in sheets]
        names = set().union(*(case_info[sheet]['input_name'].values for sheet in sheets))
        # str(...) so non-string names (e.g. numeric IDs) do not crash str.lower.
        return mri_types, sorted(names, key=lambda name: str(name).lower())

    control_MRI_types, control_names = _group(control_id)
    patient_MRI_types, patient_names = _group(patient_id)
    mutual_MRI_type = [MRI_type for MRI_type in control_MRI_types if MRI_type in patient_MRI_types]

    case_info_df = pd.DataFrame({
        'input_name': control_names + patient_names,
        'case_id': [f'control{i + 1}' for i in range(len(control_names))]
                   + [f'patient{i + 1}' for i in range(len(patient_names))],
        'situation': [control_id] * len(control_names) + [patient_id] * len(patient_names),
    })

    for MRI_type in mutual_MRI_type:
        per_type = pd.concat(
            [case_info[control_id + '_' + MRI_type][['input_name', 'case_name']],
             case_info[patient_id + '_' + MRI_type][['input_name', 'case_name']]],
            ignore_index=True,
        ).rename(columns={'case_name': MRI_type})
        case_info_df = case_info_df.merge(per_type, how = 'outer', on = 'input_name')

    return case_info_df.drop(columns=['input_name'])



In [2]:
def _find_case_row(case_matching_df, case, situation, data_type, collective_analysis):
    """Return the index label of the row of ``case_matching_df`` matching one (case, situation) entry."""
    if collective_analysis:
        # Collective features store cases as '<MRI sequence>-<case name>'.
        data_type, case = case.split('-', 1)
    matches = case_matching_df[(case_matching_df[data_type] == case)
                               & (case_matching_df['situation'] == situation)]
    return matches.index[0]


def take_diagnosis_results(features_df, feature_output_folder, case_info_excel_path, control_id, patient_id,
                           sequences_dict = None, not_include_features_indexes = None,
                          save_data = False, saving_path = ''):
    """
    Build a case-by-feature diagnosis table.

    For each row ``k`` of ``features_df`` not listed in ``not_include_features_indexes``,
    a column ``parameter_<k+1>`` is added. Its value per case is:
      * -1 : the case does not appear in that feature's pickled values;
      *  0 : the case appears but no value exceeds the feature's ``threshold``;
      *  1 : at least one value exceeds the threshold.

    Args:
      features_df : table whose first 6 columns identify each feature's pickle
        file and which has a ``threshold`` column.
      feature_output_folder : folder holding the pickled feature values.
      case_info_excel_path, control_id, patient_id : passed to ``take_case_matching_list``.
      sequences_dict : data types (keys) that are collective analyses; their cases
        are stored as '<MRI sequence>-<case name>'. Defaults to none.
      not_include_features_indexes : row positions of ``features_df`` to skip.
      save_data : if True, write the result as CSV to ``saving_path``.
      saving_path : destination used when ``save_data`` is True.

    Returns:
      DataFrame with ``case_id`` and one ``parameter_<k>`` column per included feature.
    """
    if sequences_dict is None:
        sequences_dict = {}
    if not_include_features_indexes is None:
        not_include_features_indexes = []

    ## read features parameters
    case_MRI_type_matching_df = take_case_matching_list(case_info_excel_path, control_id, patient_id)
    study_result_df = case_MRI_type_matching_df[['case_id']].copy()

    ## take each param results
    for feat_num in range(len(features_df)):
        if feat_num in not_include_features_indexes:
            continue

        feature_params = features_df.iloc[feat_num, :6].values
        data_type = feature_params[0]
        threshold = features_df['threshold'].iloc[feat_num]
        collective_analysis = data_type in sequences_dict
        feature_id = 'parameter_' + str(feat_num + 1)
        study_result_df[feature_id] = -1

        # NOTE: pickle can execute code; only load trusted feature files.
        feature_values = pd.read_pickle(take_pkl_file_path(feature_params, feature_output_folder))
        all_cases = feature_values[['case', 'situation']].drop_duplicates()
        positives = feature_values.loc[feature_values['value'] > threshold, ['case', 'situation']].drop_duplicates()

        # Mark every case with data as 0 first, then overwrite the positives with 1.
        for diagnose, cases in ((0, all_cases), (1, positives)):
            for case, situation in cases.itertuples(index=False, name=None):
                row = _find_case_row(case_MRI_type_matching_df, case, situation, data_type, collective_analysis)
                # `row` is an index label, so use label-based .loc (not positional .iloc).
                study_result_df.loc[row, feature_id] = diagnose

    if save_data:
        study_result_df.to_csv(saving_path)

    return study_result_df

In [1]:
def figure_diagnosis(diagnosis_df, figure_size = (15, 8), marker_size = 50,
                     x_label_rotation = 90, colormap = 'viridis_r', show_grid = False,
                    save_figure = False, saving_path = '', bbox_to_anchor = (-0.08, -0.12)):
    """
    Scatter-plot a case-by-feature diagnosis table.

    x is the case, y is the feature column, and colour is the diagnosis value
    (-1 'no data', 0 'normal', 1 'epilepsy'; any other value is labelled by its
    number).

    Args:
      diagnosis_df : table with a ``case_id`` column plus one column per feature.
      figure_size : matplotlib figure size.
      marker_size : scatter marker size.
      x_label_rotation : rotation of the case labels.
      colormap : colormap for the diagnosis values.
      show_grid : if True, draw a faint vertical grid.
      save_figure : if True, save the figure to ``saving_path`` before showing it.
      saving_path : destination used when ``save_figure`` is True.
      bbox_to_anchor : legend anchor (lower-left corner, axes coordinates).
    """
    indexed = diagnosis_df.set_index(['case_id'])
    feature_list = list(indexed.columns)
    columns = ['case_id', 'feature_id', 'diagnose']
    # Long format, one row per (case, feature), built once (no quadratic concat loop).
    records = [
        (case, feature, value)
        for case, row_values in zip(indexed.index, indexed.to_numpy())
        for feature, value in zip(feature_list, row_values)
    ]
    df = pd.DataFrame.from_records(records, columns=columns)

    fig, ax = plt.subplots(figsize=figure_size)
    scatter = ax.scatter(df['case_id'], df['feature_id'], c=df['diagnose'],
                         cmap = colormap, marker=".", s = marker_size)
    ax.tick_params(axis='x', labelrotation=x_label_rotation)

    # One legend entry per value actually present (num=None -> all sorted unique
    # values), so labels stay aligned even when e.g. no -1 values exist.
    value_meanings = {-1: 'no data', 0: 'normal', 1: 'epilepsy'}
    handles, _ = scatter.legend_elements(num=None)
    present_values = np.unique(df['diagnose'].to_numpy())
    labels = [value_meanings.get(value, str(value)) for value in present_values]
    ax.legend(handles=handles, labels=labels,
              title="diagnose", loc='lower left', bbox_to_anchor=bbox_to_anchor)

    ax.invert_yaxis()
    if show_grid:
        ax.grid(axis = 'x', color='y', linestyle='-', linewidth=0.5, alpha = 0.2)
    if save_figure:
        # `plt.save_figure` does not exist; savefig is the matplotlib API.
        fig.savefig(saving_path)
    plt.show()