In [None]:
from nimlab import datasets as nimds
import numpy as np
from nilearn import image, plotting, maskers
import nibabel as nib
import os
import pandas as pd
import glob
import platform

In [None]:
## Paths Input Here
# Both platforms must define conn_path1 / conn_path2 (the two maps to conjoin) and out_dir;
# the conjunction and save cells below read these names.
analysis = "r_map_and_memory_map_conjunction"
if platform.uname().system == 'Darwin': #------------------------------Mac OS X---------------------------------------------------------------
    conn_path1 = r'/Users/cu135/Dropbox (Partners HealthCare)/memory/functional_networks/published_composite_networks/0Fx-DBS-Network_N46.nii'
    conn_path2 = r'/Users/cu135/Dropbox (Partners HealthCare)/memory/functional_networks/published_composite_networks/0memory_network_T_map.nii'
    print('I have set pathnames in the Mac style')
else: #----------------------------------------------------------------Windows----------------------------------------------------------------
    # Same two maps under the Windows Dropbox root.
    # NOTE(review): assumes the Windows Dropbox mirrors the Mac folder layout — TODO confirm.
    conn_path1 = r'C:\Users\calvin.howard\Dropbox (Partners HealthCare)\memory\functional_networks\published_composite_networks\0Fx-DBS-Network_N46.nii'
    conn_path2 = r'C:\Users\calvin.howard\Dropbox (Partners HealthCare)\memory\functional_networks\published_composite_networks\0memory_network_T_map.nii'
    # Leftover inputs from the ROI-ROI correlation analysis; the conjunction cells in this
    # notebook do not read them. Kept only so any external use keeps working.
    conn_path = r'C:\Users\calvin.howard\Dropbox (Partners HealthCare)\memory\analyses\roi-roi_correl\matrix_corrMx_AvgR.csv'
    clin_path = r'C:\Users\calvin.howard\Dropbox (Partners HealthCare)\memory\patient_data\AD_Clinical_Data_CDR_ADAS_COG_13.xlsx'
    x_roi_names = r'C:\Users\calvin.howard\Dropbox (Partners HealthCare)\memory\analyses\roi-roi_correl\matrix_corrMx_names.csv'
    print('I have set pathnames in the Windows style')

# Results go in a folder named after the analysis, next to the second map (on both platforms).
out_dir = os.path.join(os.path.dirname(conn_path2), analysis)
#out_dir = r'path to out dir here'  # uncomment to override

In [None]:
## Perform Conjunction Analysis

import nibabel as nib
import numpy as np
from calvin_utils.fisher_z_transform import fisher_z_transform
from scipy.ndimage import measurements
def conjunction_analysis(file1, file2, f1='t', f2='t', absval=True, binarize=True, thresh_by_stdev=True, thresh_by_quantile=False, threshold_stdevs=2, threshold_quantile=0, threshold_voxels=10, conjunction_type='minimum', flip_matrix=False):
    """Voxelwise conjunction of two NIfTI maps.

    Both maps are loaded and NaN/inf-cleaned. Map 1 is optionally sign-flipped and
    either map is optionally Fisher-z transformed. Each map is then thresholded, the
    two are combined voxelwise, the result is optionally binarized, and finally voxels
    with too few non-zero neighbours are removed.

    Parameters
    ----------
    file1, file2 : str
        Paths to the two images. They are combined elementwise, so their data arrays
        must have the same shape (not checked here).
    f1, f2 : str
        Map type; ``'r'`` maps are passed through ``fisher_z_transform``, anything
        else is used unchanged.
    absval : bool
        Threshold absolute values (use when the direction of the relationship
        between the maps is unknown).
    binarize : bool
        Replace the conjunction with a 0/1 array (1 where the conjunction is > 0).
    thresh_by_stdev, threshold_stdevs : bool, float
        Zero voxels below ``threshold_stdevs * std(map)``.
    thresh_by_quantile, threshold_quantile : bool, float
        Zero voxels below the ``threshold_quantile`` quantile of each map.
    threshold_voxels : int or None
        A voxel is kept only if the 3x3x3 neighbourhood centred on it (zero-padded
        outside the volume, the voxel itself included) holds at least this many
        non-zero conjunction voxels. ``None`` disables this step.
    conjunction_type : {'minimum', 'product'}
        How the two thresholded maps are combined.
    flip_matrix : bool
        Multiply map 1 by -1 before the Fisher-z and thresholding steps.

    Returns
    -------
    numpy.ndarray
        The conjunction array, reshaped to the shape of map 1's data.

    Raises
    ------
    ValueError
        If ``conjunction_type`` is not recognised.
    """
    # Load both images; NaN -> 0 and +/-inf -> +/-10 so thresholds and counts stay finite.
    img1 = nib.load(file1)
    img2 = nib.load(file2)
    data1 = np.nan_to_num(img1.get_fdata(), nan=0, posinf=10, neginf=-10)
    data2 = np.nan_to_num(img2.get_fdata(), nan=0, posinf=10, neginf=-10)

    if flip_matrix:
        data1 = data1 * (-1)
        print('I will flip matrix 1, multiplying by -1')

    # Fisher Z transform R values
    if f1 == 'r':
        data1 = fisher_z_transform(data1)
        print('I will Fisher Z transform Data 1')
    if f2 == 'r':
        print('I will Fisher Z transform Data 2')
        data2 = fisher_z_transform(data2)

    if thresh_by_stdev:
        # The thresholds come from the signed data, before any absolute value is taken.
        threshold1 = threshold_stdevs * np.std(data1)
        print('I will threshold data1 by ', str(threshold1), f' which is {threshold_stdevs} standard deviations)')
        threshold2 = threshold_stdevs * np.std(data2)
        print('I will threshold data2 by ', str(threshold2), f' which is {threshold_stdevs} standard deviations)')
        if absval:
            print('I will take the absolute values of the data')
            data1 = np.abs(data1)
            data2 = np.abs(data2)
        data1[data1 < threshold1] = 0
        data2[data2 < threshold2] = 0

    if thresh_by_quantile:
        # Here the absolute value is taken first, so the quantile is of |data| when absval is set.
        if absval:
            print('I will take the absolute values of the data')
            data1 = np.abs(data1)
            data2 = np.abs(data2)
        threshold1 = np.quantile(data1, threshold_quantile)
        threshold2 = np.quantile(data2, threshold_quantile)
        data1[data1 < threshold1] = 0
        data2[data2 < threshold2] = 0

    # Perform the conjunction analysis using the specified type
    if conjunction_type == 'minimum':
        conjunction_data = np.minimum(data1, data2)
    elif conjunction_type == 'product':
        conjunction_data = data1 * data2
        print('I have conjoined by product')
    else:
        raise ValueError("Invalid conjunction type")

    if binarize:
        conjunction_data = np.where(conjunction_data > 0, 1, 0)
        print('I will binarize the data')

    # Eliminate voxels of insufficient volume: count the non-zero voxels in the window
    # centred on each voxel and zero the voxels whose neighbourhood is too sparse.
    # Surviving voxels keep their original values.
    if threshold_voxels is not None:
        from scipy import ndimage
        occupied = (conjunction_data != 0).astype(float)
        kernel = np.ones((3,) * occupied.ndim)
        neighbour_counts = ndimage.correlate(occupied, kernel, mode='constant', cval=0.0)
        conjunction_data = np.where(neighbour_counts < threshold_voxels, 0, conjunction_data)
        print('Data threshold to voxel volume of : ', threshold_voxels)

    n_d1 = np.count_nonzero(data1)
    n_d2 = np.count_nonzero(data2)
    n_conj = np.count_nonzero(conjunction_data)
    # Report NaN instead of raising ZeroDivisionError when a thresholded map is empty.
    print('num voxels d1: ', n_d1, 'percent survivors: ', n_conj / n_d1 if n_d1 else float('nan'))
    print('num voxels d2: ', n_d2, 'percent survivors: ', n_conj / n_d2 if n_d2 else float('nan'))
    print('Surviving voxels: ', n_conj)
    # Dice = 2|C| / (|A| + |B|), with the surviving conjunction voxels C as the overlap.
    print('Dice coefficient is: ', 2 * n_conj / (n_d1 + n_d2) if (n_d1 + n_d2) else float('nan'))

    # Return the conjunction array
    return np.reshape(conjunction_data, data1.shape)


In [None]:
#-----------------USER INPUT----------------
dtype1 = 'r'; dtype2 = 't'                              # map types; 'r' maps are Fisher-z transformed
threshold_quantile = 0.96; thresh_by_quantile=True      # zero voxels below this quantile of each map
stdev_thresh = 3; thresh_by_stdev=False                 # or zero voxels below stdev_thresh * std(map)
vox_thresh = None                                       # passed as threshold_voxels; None disables voxel-volume thresholding
absval = False                                          # True: threshold |values|; False: threshold signed values
flip_matrix=False                                       # multiply map 1 by -1 first
#-------------------------------------------

conjunction_matrix = conjunction_analysis(conn_path1, conn_path2, f1=dtype1, f2=dtype2, absval=absval,
                                         threshold_stdevs=stdev_thresh, thresh_by_stdev=thresh_by_stdev,
                                         threshold_quantile=threshold_quantile, thresh_by_quantile=thresh_by_quantile,
                                         threshold_voxels=vox_thresh,
                                         conjunction_type='product', binarize=False, flip_matrix=flip_matrix)

# Filename-safe label for the stdev threshold: 2.5 -> '2-point-5'. isinstance also covers
# numpy floats, and replace() cannot fail on reprs without a '.', such as '1e-05'.
if isinstance(stdev_thresh, float):
    stdev_name = str(stdev_thresh).replace('.', '-point-')
else:
    stdev_name = stdev_thresh

In [None]:
# Wrap the conjunction array in an image that takes its grid/affine from the mni_icbm152
# template, then render an interactive viewer.
# NOTE(review): assumes conjunction_matrix has the template's shape — TODO confirm.
mask = nimds.get_img("mni_icbm152")
print(type(conjunction_matrix), type(conjunction_matrix[1][1]))
ovr_img1 = image.new_img_like(mask, conjunction_matrix)
ovr_html1 = plotting.view_img(
    ovr_img1,
    cut_coords=(0, 0, 0),
    title='conjunction',
    black_bg=False,
    opacity=.75,
    cmap='ocean_hot',
)
ovr_html1

In [None]:
#Identify the clusters using convolution
import scipy.ndimage as ndimage 
import warnings
warnings.filterwarnings('ignore')
#----------------------------------------------------------------user input----------------------------------------------------------------
def convolve_extract_clusters(nifti_image, save_clusters=False):
    """Label clusters in an image after a 3x3x3 neighbourhood dilation.

    A voxel is marked when the 3x3x3 neighbourhood centred on it (zero-padded outside
    the volume) contains any non-zero input voxel. The marked voxels are then split into
    connected components with ``scipy.ndimage.label``, using its default face
    connectivity. Each cluster therefore extends one voxel beyond the non-zero input
    voxels, and input regions separated by small gaps may merge into one cluster.

    Parameters
    ----------
    nifti_image : object with ``get_fdata()``
        Image to cluster, e.g. the image built from the conjunction array.
    save_clusters : bool
        When True, each cluster is written as a ``.nii`` mask and an interactive
        ``.html`` view under ``<out_dir>/clustered_rois_<stdev_name>_standard_deviations``.
        This reads the notebook globals ``out_dir``, ``stdev_name``, ``analysis``,
        ``vox_thresh`` and ``absval``, and loads the ``mni_icbm152`` template.

    Returns
    -------
    dict[int, numpy.ndarray]
        Cluster label (1..N) -> 0/1 array with the image's shape. The dict is empty
        when the image has no non-zero voxels.
    """
    from scipy import ndimage

    conjunction_data = nifti_image.get_fdata()

    # Count non-zero voxels in the window centred on each voxel. Centring the window
    # keeps clusters in place; a window anchored at the corner shifts them by one voxel.
    occupied = (conjunction_data != 0).astype(float)
    kernel = np.ones((3,) * occupied.ndim)
    neighbour_counts = ndimage.correlate(occupied, kernel, mode='constant', cval=0.0)
    labels, numc = ndimage.label(neighbour_counts)
    print('Number clusters: ', numc)

    cluster_dict = {i: np.where(labels == i, 1, 0) for i in range(1, numc + 1)}

    # Saving setup happens once, and only when there is something to save.
    if save_clusters and cluster_dict:
        mask = nimds.get_img("mni_icbm152")
        cluster_out = os.path.join(out_dir, f'clustered_rois_{stdev_name}_standard_deviations')
        analysis_name = f'{analysis}_stdevthresh_{stdev_name}_voxthresh_{vox_thresh}_absval_{absval}'
        os.makedirs(cluster_out, exist_ok=True)

    for i, cluster in cluster_dict.items():
        print(f'Cluster {i} is size: {np.sum(cluster)}')
        if save_clusters:
            savename = analysis_name + '_cluster_' + str(i)
            cluster_img = image.new_img_like(mask, cluster)
            cluster_html = plotting.view_img(cluster_img, cut_coords=(0,0,0), title=(f'cluster_{i}'), black_bg=False, opacity=.75, cmap='ocean_hot')
            # Explicit .nii extension, matching the main conjunction save.
            cluster_img.to_filename(os.path.join(cluster_out, f'{savename}.nii'))
            cluster_html.save_as_html(os.path.join(cluster_out, f'{savename}.html'))
            print('File: ' + savename)
            print('saved to: ', cluster_out)

    return cluster_dict

In [None]:
##Save your nifti and html
# Filename stem that records the thresholding settings used above.
if thresh_by_stdev:
    analysis_name = f'conjunction_stdevthresh_{stdev_name}_voxthresh_{vox_thresh}_absval_{absval}'
elif thresh_by_quantile:
    # round(), not int(): floating-point error makes e.g. int(0.57 * 100) == 56.
    analysis_name = f'conjunction_percthresh_{round(threshold_quantile * 100)}_voxthresh_{vox_thresh}_absval_{absval}'
else:
    analysis_name = 'arb'
print(analysis_name)

In [None]:
#----------------------------------------------------------------Save if desired----------------------------------------------------------------
# Make sure the output folder exists, then write the conjunction image and its HTML viewer.
os.makedirs(out_dir, exist_ok=True)
savename = analysis_name
ovr_img1.to_filename(os.path.join(out_dir, analysis_name + '.nii'))
ovr_html1.save_as_html(os.path.join(out_dir, analysis_name + '.html'))

print(f'Title: {savename}')
print('saved to: ', out_dir)

In [None]:
# Extract the clusters. The result is assigned so the dict of full-volume arrays is not
# dumped as cell output.
cluster_masks = convolve_extract_clusters(ovr_img1, save_clusters=True)