# 00 Import libraries

In [None]:
import numpy as np
import json
import torch
from scipy.stats import ttest_ind, levene
import os
import seaborn as sns
import matplotlib.pyplot as plt
import nibabel as nib
from nilearn import plotting
import pandas as pd
from statsmodels.stats.multitest import fdrcorrection

# 01 Load atlas coordinates

In [None]:
# Load the HCP-MMP1 parcellation image ("original" per the author's note).
# The path is relative to the notebook, like the metadata CSV loaded later
# ('../data/...'); the previous '..data/...' was missing the slash.
data = nib.load('../data/atlas/HCPMMP1_for_ABIDE.nii.gz')  # original

In [None]:
# One (x, y, z) cut coordinate per parcel of the atlas image, used as node
# positions for the connectome views below.
coordinates = plotting.find_parcellation_cut_coords(data)

In [None]:
# Echo the coordinate array's shape; per nilearn this is (n_labels, 3) -- should
# match the 360 nodes used by the connectome plots (TODO confirm).
coordinates.shape

# 02 Load interpretability maps

In [None]:
# Directory with the per-subject interpretability JSON files of the healthy
# control group (denoted as target 0). Must be set before running the notebook;
# the bare `healthy_path = ` placeholder was a SyntaxError.
healthy_path = None  # YOUR PATH HERE (denoted as target 0)
if healthy_path is None:
    raise ValueError('Set healthy_path to the HC (target 0) interpretability directory.')

# Subject IDs are the filename prefix before the first '_'. Only activation
# files are scanned so unrelated entries in the folder are not taken as
# subjects, and sorting makes the subject order reproducible (iteration order
# of a set of strings changes between interpreter runs).
sub_list = sorted({
    fname.split('_')[0]
    for fname in os.listdir(healthy_path)
    if fname.endswith('_att_mat_activation.json')
})

In [None]:
# Per-subject contribution maps of the HC group for the three frequency bands.
# For each band, the activation is averaged over dim 0 and the gradient
# (leading singleton dim squeezed) is averaged over dim 0. Their matrix
# product is stored as that subject's 180 x 180 map.
# Assumes the JSON nested lists produce a product that fits a 180 x 180 slot -- confirm.
n_subjects = len(sub_list)
high_contribution = np.zeros((n_subjects, 180, 180))
low_contribution = np.zeros((n_subjects, 180, 180))
ultralow_contribution = np.zeros((n_subjects, 180, 180))

# (JSON key, destination array) for every band, processed in the same order
# as before: high, low, ultralow.
band_targets = (
    ('high_spatial_attention', high_contribution),
    ('low_spatial_attention', low_contribution),
    ('ultralow_spatial_attention', ultralow_contribution),
)

for subject_idx, subject in enumerate(sub_list):
    with open(healthy_path + f'/{subject}_att_mat_activation.json', 'r') as fh:
        activations = json.load(fh)
    with open(healthy_path + f'/{subject}_att_mat_gradient.json', 'r') as fh:
        gradients = json.load(fh)

    for key, target in band_targets:
        act_mean = torch.tensor(activations[key]).mean(dim=0)
        grad_mean = torch.tensor(gradients[key]).squeeze(dim=0).mean(dim=0)
        target[subject_idx] = torch.matmul(act_mean, grad_mean)

In [None]:
# Directory with the per-subject interpretability JSON files of the ASD group
# (denoted as target 1). Must be set before running the notebook; the bare
# `ASD_path = ` placeholder was a SyntaxError.
ASD_path = None  # YOUR PATH HERE (denoted as target 1)
if ASD_path is None:
    raise ValueError('Set ASD_path to the ASD (target 1) interpretability directory.')

# Subject IDs are the filename prefix before the first '_'. Only activation
# files are scanned so unrelated entries in the folder are not taken as
# subjects, and sorting makes the subject order reproducible (iteration order
# of a set of strings changes between interpreter runs).
sub_list = sorted({
    fname.split('_')[0]
    for fname in os.listdir(ASD_path)
    if fname.endswith('_att_mat_activation.json')
})

In [None]:
# Per-subject contribution maps of the ASD group for the three frequency bands.
# For each band, the activation is averaged over dim 0 and the gradient
# (leading singleton dim squeezed) is averaged over dim 0. Their matrix
# product is stored as that subject's 180 x 180 map.
# Assumes the JSON nested lists produce a product that fits a 180 x 180 slot -- confirm.
n_subjects = len(sub_list)
high_contribution_ASD = np.zeros((n_subjects, 180, 180))
low_contribution_ASD = np.zeros((n_subjects, 180, 180))
ultralow_contribution_ASD = np.zeros((n_subjects, 180, 180))

# (JSON key, destination array) for every band, processed in the same order
# as before: high, low, ultralow.
band_targets = (
    ('high_spatial_attention', high_contribution_ASD),
    ('low_spatial_attention', low_contribution_ASD),
    ('ultralow_spatial_attention', ultralow_contribution_ASD),
)

for subject_idx, subject in enumerate(sub_list):
    with open(ASD_path + f'/{subject}_att_mat_activation.json', 'r') as fh:
        activations = json.load(fh)
    with open(ASD_path + f'/{subject}_att_mat_gradient.json', 'r') as fh:
        gradients = json.load(fh)

    for key, target in band_targets:
        act_mean = torch.tensor(activations[key]).mean(dim=0)
        grad_mean = torch.tensor(gradients[key]).squeeze(dim=0).mean(dim=0)
        target[subject_idx] = torch.matmul(act_mean, grad_mean)

# 03 t-test between ASD and HC

## 3-1 High frequency

In [None]:
# Per-connection comparison of HC (high_contribution) vs ASD
# (high_contribution_ASD) in the high-frequency band, for all 180 x 180 cells:
#   * Levene's test picks Student's t (p > 0.05, equal variances) or Welch's t;
#   * t_stats_high > 0 means the HC mean exceeds the ASD mean;
#   * cohens_d_high stores |Cohen's d| based on the pooled standard deviation.
p_values_high = np.zeros((180, 180))
t_stats_high = np.zeros((180, 180))
cohens_d_high = np.zeros((180, 180))
for i in range(180):
    for j in range(180):
        hc = high_contribution[:, i, j]
        asd = high_contribution_ASD[:, i, j]
        _, p_levene = levene(hc, asd)
        equal_var = bool(p_levene > 0.05)
        t_stat, p_value = ttest_ind(hc, asd, equal_var=equal_var)
        # The pooled SD weights each group's *sample* variance (ddof=1) by n - 1.
        # np.var defaults to ddof=0, which would underestimate the SD and inflate d.
        n_hc, n_asd = len(hc), len(asd)
        pooled_sd = np.sqrt(
            ((n_hc - 1) * np.var(hc, ddof=1) + (n_asd - 1) * np.var(asd, ddof=1))
            / (n_hc + n_asd - 2)
        )
        cohens_d = (np.mean(hc) - np.mean(asd)) / pooled_sd
        p_values_high[i, j] = p_value
        t_stats_high[i, j] = t_stat
        cohens_d_high[i, j] = np.abs(cohens_d)

# FDR (Benjamini-Hochberg) correction across all connections. A single NaN
# p-value (e.g. a connection constant in both groups) would propagate through
# fdrcorrection's cumulative minimum and turn every corrected value into NaN,
# so only finite p-values are corrected; untestable cells stay NaN.
p_values_flat = p_values_high.flatten()
p_values_corrected_flat = np.full_like(p_values_flat, np.nan)
finite = np.isfinite(p_values_flat)
_, p_values_corrected_flat[finite] = fdrcorrection(p_values_flat[finite], alpha=0.05)
p_values_high_corrected = p_values_corrected_flat.reshape(180, 180)

In [None]:
# Binary mask (1/0) of high-band connections whose FDR-corrected p-value is
# <= 0.05, i.e. connections that differ significantly between the two groups.
filtered_matrix_high = (p_values_high_corrected <= 0.05).astype(int)
sns.heatmap(filtered_matrix_high)

In [None]:
# Unthresholded |Cohen's d| of every high-band connection.
sns.heatmap(data=cohens_d_high)

In [None]:
# |Cohen's d| kept only where the connection is FDR-significant (0 elsewhere).
sns.heatmap(filtered_matrix_high * cohens_d_high)

In [None]:
# Direction of each difference: -1 where t < 0 (ASD mean > HC mean), +1 otherwise.
t_sign_high = np.where(t_stats_high < 0, -1, 1)
# Signed effect size, restricted to FDR-significant connections.
sns.heatmap(t_sign_high * filtered_matrix_high * cohens_d_high)

## 3-2 Low frequency

In [None]:
# Per-connection comparison of HC (low_contribution) vs ASD
# (low_contribution_ASD) in the low-frequency band, for all 180 x 180 cells:
#   * Levene's test picks Student's t (p > 0.05, equal variances) or Welch's t;
#   * t_stats_low > 0 means the HC mean exceeds the ASD mean;
#   * cohens_d_low stores |Cohen's d| based on the pooled standard deviation.
p_values_low = np.zeros((180, 180))
t_stats_low = np.zeros((180, 180))
cohens_d_low = np.zeros((180, 180))
for i in range(180):
    for j in range(180):
        hc = low_contribution[:, i, j]
        asd = low_contribution_ASD[:, i, j]
        _, p_levene = levene(hc, asd)
        equal_var = bool(p_levene > 0.05)
        t_stat, p_value = ttest_ind(hc, asd, equal_var=equal_var)
        # The pooled SD weights each group's *sample* variance (ddof=1) by n - 1.
        # np.var defaults to ddof=0, which would underestimate the SD and inflate d.
        n_hc, n_asd = len(hc), len(asd)
        pooled_sd = np.sqrt(
            ((n_hc - 1) * np.var(hc, ddof=1) + (n_asd - 1) * np.var(asd, ddof=1))
            / (n_hc + n_asd - 2)
        )
        cohens_d = (np.mean(hc) - np.mean(asd)) / pooled_sd
        p_values_low[i, j] = p_value
        t_stats_low[i, j] = t_stat
        cohens_d_low[i, j] = np.abs(cohens_d)

# FDR (Benjamini-Hochberg) correction across all connections. A single NaN
# p-value (e.g. a connection constant in both groups) would propagate through
# fdrcorrection's cumulative minimum and turn every corrected value into NaN,
# so only finite p-values are corrected; untestable cells stay NaN.
p_values_flat = p_values_low.flatten()
p_values_corrected_flat = np.full_like(p_values_flat, np.nan)
finite = np.isfinite(p_values_flat)
_, p_values_corrected_flat[finite] = fdrcorrection(p_values_flat[finite], alpha=0.05)
p_values_low_corrected = p_values_corrected_flat.reshape(180, 180)

In [None]:
# Binary mask (1/0) of low-band connections whose FDR-corrected p-value is
# <= 0.05, i.e. connections that differ significantly between the two groups.
filtered_matrix_low = (p_values_low_corrected <= 0.05).astype(int)
sns.heatmap(filtered_matrix_low)

In [None]:
# Unthresholded |Cohen's d| of every low-band connection.
sns.heatmap(data=cohens_d_low)

In [None]:
# |Cohen's d| kept only where the connection is FDR-significant (0 elsewhere).
sns.heatmap(filtered_matrix_low * cohens_d_low)

In [None]:
# Direction of each difference: -1 where t < 0 (ASD mean > HC mean), +1 otherwise.
# Same convention as the high and ultralow bands (no sign reversal).
t_sign_low = np.where(t_stats_low < 0, -1, 1)
# Signed effect size, restricted to FDR-significant connections.
sns.heatmap(t_sign_low * filtered_matrix_low * cohens_d_low)

## 3-3 Ultralow frequency

In [None]:
# Per-connection comparison of HC (ultralow_contribution) vs ASD
# (ultralow_contribution_ASD) in the ultralow-frequency band, for all
# 180 x 180 cells:
#   * Levene's test picks Student's t (p > 0.05, equal variances) or Welch's t;
#   * t_stats_ultralow > 0 means the HC mean exceeds the ASD mean;
#   * cohens_d_ultralow stores |Cohen's d| based on the pooled standard deviation.
p_values_ultralow = np.zeros((180, 180))
t_stats_ultralow = np.zeros((180, 180))
cohens_d_ultralow = np.zeros((180, 180))
for i in range(180):
    for j in range(180):
        hc = ultralow_contribution[:, i, j]
        asd = ultralow_contribution_ASD[:, i, j]
        _, p_levene = levene(hc, asd)
        equal_var = bool(p_levene > 0.05)
        t_stat, p_value = ttest_ind(hc, asd, equal_var=equal_var)
        # The pooled SD weights each group's *sample* variance (ddof=1) by n - 1.
        # np.var defaults to ddof=0, which would underestimate the SD and inflate d.
        n_hc, n_asd = len(hc), len(asd)
        pooled_sd = np.sqrt(
            ((n_hc - 1) * np.var(hc, ddof=1) + (n_asd - 1) * np.var(asd, ddof=1))
            / (n_hc + n_asd - 2)
        )
        cohens_d = (np.mean(hc) - np.mean(asd)) / pooled_sd
        p_values_ultralow[i, j] = p_value
        t_stats_ultralow[i, j] = t_stat
        cohens_d_ultralow[i, j] = np.abs(cohens_d)

# FDR (Benjamini-Hochberg) correction across all connections. A single NaN
# p-value (e.g. a connection constant in both groups) would propagate through
# fdrcorrection's cumulative minimum and turn every corrected value into NaN,
# so only finite p-values are corrected; untestable cells stay NaN.
p_values_flat = p_values_ultralow.flatten()
p_values_corrected_flat = np.full_like(p_values_flat, np.nan)
finite = np.isfinite(p_values_flat)
_, p_values_corrected_flat[finite] = fdrcorrection(p_values_flat[finite], alpha=0.05)
p_values_ultralow_corrected = p_values_corrected_flat.reshape(180, 180)

In [None]:
# Binary mask (1/0) of ultralow-band connections whose FDR-corrected p-value
# is <= 0.05, i.e. connections that differ significantly between the two groups.
filtered_matrix_ultralow = (p_values_ultralow_corrected <= 0.05).astype(int)
sns.heatmap(filtered_matrix_ultralow)

In [None]:
# Unthresholded |Cohen's d| of every ultralow-band connection.
sns.heatmap(data=cohens_d_ultralow)

In [None]:
# |Cohen's d| kept only where the connection is FDR-significant (0 elsewhere).
sns.heatmap(filtered_matrix_ultralow * cohens_d_ultralow)

In [None]:
# Direction of each difference: -1 where t < 0 (ASD mean > HC mean), +1 otherwise.
t_sign_ultralow = np.where(t_stats_ultralow < 0, -1, 1)
# Signed effect size, restricted to FDR-significant connections.
sns.heatmap(t_sign_ultralow * filtered_matrix_ultralow * cohens_d_ultralow)

# 04 Load atlas metadata

In [None]:
# Region metadata for the atlas (all hemispheres); later cells look up
# 'regionLongName' by 'regionID'.
atlas_info = pd.read_csv(
    '../data/coordinates/HCP-MMP1_UniqueRegionList.csv',
    encoding='unicode_escape',
)

In [None]:
# Replace every literal newline inside cell values with a space (regex match on '\n').
atlas_info = atlas_info.replace(to_replace=r'\n', value=' ', regex=True)

In [None]:
# Display the cleaned atlas metadata table (notebook echo of the last expression).
atlas_info

# 05 Visualization on a glass brain

## 5-1 High frequency

In [None]:
# Select up to top_k FDR-significant high-band connections with the largest
# |Cohen's d| and encode the direction of each difference in a 180 x 180 mask:
#   +1 -> HC > ASD (t >= 0), -1 -> ASD > HC (t < 0), 0 -> not selected.
# rows/cols keep the selection in descending-|d| order for the report cell.
top_k = 100
significant_elements = cohens_d_high * filtered_matrix_high
flattened_indices = np.argsort(-significant_elements, axis=None)[:top_k]
# If fewer than top_k connections survive FDR correction, the slice above also
# contains non-significant cells (value 0); drop them so they are not marked
# as group differences.
flattened_indices = flattened_indices[filtered_matrix_high.flat[flattened_indices] == 1]
filtered_matrix_high_mask = np.zeros_like(filtered_matrix_high)
rows, cols = np.unravel_index(flattened_indices, filtered_matrix_high.shape)
filtered_matrix_high_mask[rows, cols] = t_sign_high[rows, cols]

In [None]:
# Signed selection mask of the high band (+1 HC > ASD, -1 ASD > HC).
sns.heatmap(data=filtered_matrix_high_mask)

In [None]:
# original -> symmetric: this cell reports the 180 x 180 (pre-mirroring) selection.
# Prints one CSV-style line per selected high-band connection, in descending
# |d| order:  region_i,region_j,group,p,|d|
#   group: 'HC' (mask +1), 'ASD' (mask -1) or 'None' (mask 0)
#   p:     uncorrected t-test p-value, |d|: absolute Cohen's d (3 decimals each)
# Region names: 'regionLongName' of regionID = index + 1, underscores replaced by
# spaces, then the last two characters dropped (presumably a hemisphere suffix
# such as '_L' -- confirm against the CSV).
for i, j in zip(rows, cols):
    # Report only connections that pass FDR correction (the criterion used to
    # build the significance mask), so a top-k list padded with
    # non-significant cells is never printed.
    if not p_values_high_corrected[i, j] <= 0.05:
        continue
    p_value = p_values_high[i, j]
    cohen = cohens_d_high[i, j]
    row = atlas_info[atlas_info['regionID'] == i+1][['regionLongName']].values[0][0]
    col = atlas_info[atlas_info['regionID'] == j+1][['regionLongName']].values[0][0]
    if filtered_matrix_high_mask[i, j] > 0:
        description = 'HC'
    elif filtered_matrix_high_mask[i, j] < 0:
        description = 'ASD'
    else:
        description = 'None'
    print(','.join([
        row.replace('_', ' ')[:-2],
        col.replace('_', ' ')[:-2],
        description,
        str(round(p_value, 3)),
        str(round(cohen, 3)),
    ]))

In [None]:
# Copy the 180 x 180 high-band mask into both diagonal 180 x 180 blocks of a
# 360 x 360 matrix (presumably one block per hemisphere), with no connections
# between blocks. Then plot it and export an interactive glass-brain view.
# Weights are negated before plotting, which inverts the colour mapping
# ("reversed_color" in the file name).
answer = np.zeros((360, 360))
for offset in (0, 180):
    answer[offset:offset + 180, offset:offset + 180] = filtered_matrix_high_mask
sns.heatmap(answer)

view = plotting.view_connectome(-answer, coordinates, node_size=3.0)
view.save_as_html('reversed_color_symmetric_ASD_ROI_180_high_freq_sign.html')

## 5-2 Low frequency

In [None]:
# Select up to top_k FDR-significant low-band connections with the largest
# |Cohen's d| and encode the direction of each difference in a 180 x 180 mask:
#   +1 -> HC > ASD (t >= 0), -1 -> ASD > HC (t < 0), 0 -> not selected.
# rows/cols keep the selection in descending-|d| order for the report cell.
top_k = 100
significant_elements = cohens_d_low * filtered_matrix_low
flattened_indices = np.argsort(-significant_elements, axis=None)[:top_k]
# If fewer than top_k connections survive FDR correction, the slice above also
# contains non-significant cells (value 0); drop them so they are not marked
# as group differences.
flattened_indices = flattened_indices[filtered_matrix_low.flat[flattened_indices] == 1]
filtered_matrix_low_mask = np.zeros_like(filtered_matrix_low)
rows, cols = np.unravel_index(flattened_indices, filtered_matrix_low.shape)
# Same sign rule as the high/ultralow bands (t < 0 -> -1, else +1).
filtered_matrix_low_mask[rows, cols] = np.where(t_stats_low[rows, cols] < 0, -1, 1)

In [None]:
# Signed selection mask of the low band (+1 HC > ASD, -1 ASD > HC).
sns.heatmap(data=filtered_matrix_low_mask)

In [None]:
# original -> symmetric: this cell reports the 180 x 180 (pre-mirroring) selection.
# Prints one CSV-style line per selected low-band connection, in descending
# |d| order:  region_i,region_j,group,p,|d|
#   group: 'HC' (mask +1), 'ASD' (mask -1) or 'None' (mask 0)
#   p:     uncorrected t-test p-value, |d|: absolute Cohen's d (3 decimals each)
# Region names: 'regionLongName' of regionID = index + 1, underscores replaced by
# spaces, then the last two characters dropped (presumably a hemisphere suffix
# such as '_L' -- confirm against the CSV).
for i, j in zip(rows, cols):
    # Report only connections that pass FDR correction (the criterion used to
    # build the significance mask). The previous raw-p <= 0.05 filter was
    # looser and did not match the high/ultralow reports.
    if not p_values_low_corrected[i, j] <= 0.05:
        continue
    p_value = p_values_low[i, j]
    cohen = cohens_d_low[i, j]
    row = atlas_info[atlas_info['regionID'] == i+1][['regionLongName']].values[0][0]
    col = atlas_info[atlas_info['regionID'] == j+1][['regionLongName']].values[0][0]
    if filtered_matrix_low_mask[i, j] > 0:
        description = 'HC'
    elif filtered_matrix_low_mask[i, j] < 0:
        description = 'ASD'
    else:
        description = 'None'
    print(','.join([
        row.replace('_', ' ')[:-2],
        col.replace('_', ' ')[:-2],
        description,
        str(round(p_value, 3)),
        str(round(cohen, 3)),
    ]))

In [None]:
# Copy the 180 x 180 low-band mask into both diagonal 180 x 180 blocks of a
# 360 x 360 matrix (presumably one block per hemisphere), with no connections
# between blocks. Then plot it and export an interactive glass-brain view.
# Weights are negated before plotting, which inverts the colour mapping
# ("reversed_color" in the file name).
answer = np.zeros((360, 360))
for offset in (0, 180):
    answer[offset:offset + 180, offset:offset + 180] = filtered_matrix_low_mask
sns.heatmap(answer)

view = plotting.view_connectome(-answer, coordinates, node_size=3.0)
view.save_as_html('reversed_color_symmetric_ASD_ROI_180_low_freq_sign.html')

## 5-3 Ultralow frequency

In [None]:
# Select up to top_k FDR-significant ultralow-band connections with the largest
# |Cohen's d| and encode the direction of each difference in a 180 x 180 mask:
#   +1 -> HC > ASD (t >= 0), -1 -> ASD > HC (t < 0), 0 -> not selected.
# rows/cols keep the selection in descending-|d| order for the report cell.
# (The leftover per-connection debug print has been removed.)
top_k = 100
significant_elements = cohens_d_ultralow * filtered_matrix_ultralow
flattened_indices = np.argsort(-significant_elements, axis=None)[:top_k]
# If fewer than top_k connections survive FDR correction, the slice above also
# contains non-significant cells (value 0); drop them so they are not marked
# as group differences.
flattened_indices = flattened_indices[filtered_matrix_ultralow.flat[flattened_indices] == 1]
filtered_matrix_ultralow_mask = np.zeros_like(filtered_matrix_ultralow)
rows, cols = np.unravel_index(flattened_indices, filtered_matrix_ultralow.shape)
filtered_matrix_ultralow_mask[rows, cols] = t_sign_ultralow[rows, cols]

In [None]:
# Signed selection mask of the ultralow band (+1 HC > ASD, -1 ASD > HC).
sns.heatmap(data=filtered_matrix_ultralow_mask)

In [None]:
# original -> symmetric: this cell reports the 180 x 180 (pre-mirroring) selection.
# Prints one CSV-style line per selected ultralow-band connection, in
# descending |d| order:  region_i,region_j,group,p,|d|
#   group: 'HC' (mask +1), 'ASD' (mask -1) or 'None' (mask 0)
#   p:     uncorrected t-test p-value, |d|: absolute Cohen's d (3 decimals each)
# Region names: 'regionLongName' of regionID = index + 1, underscores replaced by
# spaces, then the last two characters dropped (presumably a hemisphere suffix
# such as '_L' -- confirm against the CSV).
for i, j in zip(rows, cols):
    # Report only connections that pass FDR correction (the criterion used to
    # build the significance mask), so a top-k list padded with
    # non-significant cells is never printed.
    if not p_values_ultralow_corrected[i, j] <= 0.05:
        continue
    p_value = p_values_ultralow[i, j]
    cohen = cohens_d_ultralow[i, j]
    row = atlas_info[atlas_info['regionID'] == i+1][['regionLongName']].values[0][0]
    col = atlas_info[atlas_info['regionID'] == j+1][['regionLongName']].values[0][0]
    if filtered_matrix_ultralow_mask[i, j] > 0:
        description = 'HC'
    elif filtered_matrix_ultralow_mask[i, j] < 0:
        description = 'ASD'
    else:
        description = 'None'
    print(','.join([
        row.replace('_', ' ')[:-2],
        col.replace('_', ' ')[:-2],
        description,
        str(round(p_value, 3)),
        str(round(cohen, 3)),
    ]))

In [None]:
# Copy the 180 x 180 ultralow-band mask into both diagonal 180 x 180 blocks of
# a 360 x 360 matrix (presumably one block per hemisphere), with no
# connections between blocks. Then plot it and export an interactive
# glass-brain view. Weights are negated before plotting, which inverts the
# colour mapping ("reversed_color" in the file name).
answer = np.zeros((360, 360))
for offset in (0, 180):
    answer[offset:offset + 180, offset:offset + 180] = filtered_matrix_ultralow_mask
sns.heatmap(answer)

view = plotting.view_connectome(-answer, coordinates, node_size=3.0)
view.save_as_html('reversed_color_symmetric_ASD_ROI_180_ultralow_freq_sign.html')