In [1]:
# One seed per repetition of the train/test split; random_seeds[run] is the
# seed handed to hbr_data_split further down in this notebook.
random_seeds = [42, 100, 0, 10, 13, 20, 3, 18, 23, 105]
# Index of the repetition processed by this execution; it also names the
# normative-modelling output folder ('Run_<run>').
run = 0

In [2]:
# Cluster account name; interpolated into the package and results paths.
username = 'meganorm-mznasrabadi'
# One entry per dataset. The insertion order matters: later cells translate a
# numeric site code into a dataset name with list(datasets.keys())[site].
# 'base_dir' is the dataset root; 'task' and 'ending' are consumed by
# merge_datasets_with_glob — presumably as file-name filters; TODO confirm.
datasets = {
    'BTNRH': {
        'base_dir': '/project/meganorm/Data/BTNRH/BIDS',
        'task': "task-rest",
        "ending" : "meg.fif"
    },
    'CAMCAN': {
        'base_dir': '/project/meganorm/Data/camcan/BIDS',
        'task': "task-rest",
        "ending" : "meg.fif"
    },
    'NIMH': {
        'base_dir': '/project/meganorm/Data/NIMH',
        'task': "task-rest",
        "ending" : "meg.ds"
    },
    'OMEGA': {
        'base_dir': '/project/meganorm/Data/Omega',
        'task': "task-rest",
        "ending" : "meg.ds"
    },
    'HCP': {
        'base_dir': '/project/meganorm/Data/HCP',
        'task': "",
        "ending" : "4-Restin/4D"
    }
    }

# Root of the MEGaNorm checkout; the import cell chdir()s here so that the
# project-local 'utils' and 'plots' packages resolve.
package_path = f'/home/{username}/MEGaNorm/'

In [3]:
import os
os.chdir(package_path)
from utils.parallel import submit_jobs, check_jobs_status, collect_results
from utils.nm import hbr_data_split, estimate_centiles, evaluate_mace, shapiro_stat, abnormal_probability
from plots.plots import plot_nm_range_site2, plot_comparison, plot_neurooscillochart, plot_age_dist2, plot_growthcharts, plot_quantile_gauge, box_plot_auc
from utils.nm import model_quantile_evaluation, calculate_oscilochart, prepare_prediction_data, cal_stats_for_gauge
from utils.IO import merge_datasets_with_regex, merge_fidp_demo, merge_datasets_with_glob
import pandas as pd
import json
from pcntoolkit.normative_parallel import execute_nm, rerun_nm, collect_nm
import warnings
import pickle  
import numpy as np
from pcntoolkit.util.utils import z_to_abnormal_p, anomaly_detection_auc
from scipy.stats import false_discovery_control
warnings.filterwarnings("ignore")

# Configuration

In [4]:
def make_config(project, path=None):
    """Build the preprocessing / feature-extraction configuration.

    Parameters
    ----------
    project : str
        Project name; used as the stem of the JSON file written when
        ``path`` is given.
    path : str or None, optional
        Directory in which ``<project>.json`` is written. Nothing is written
        when ``None`` (the default).

    Returns
    -------
    dict
        The configuration dictionary. Note that the JSON copy stores the
        tuple-valued band definitions as lists.
    """

    config = dict()

    # preprocess configurations =================================================
    # Sensor layout; choices: all, lobe, None
    config["which_layout"] = "all"

    # which sensor type should be used
    # choices: meg, mag, grad, eeg, opm
    config["which_sensor"] = "meg"
    # config['fs'] = 1000

    # ICA configuration
    config['ica_n_component'] = 30
    config['ica_max_iter'] = 800
    config['ica_method'] = "fastica"

    # lower and upper cutoff frequencies of the band-pass filter
    config['cutoffFreqLow'] = 1
    config['cutoffFreqHigh'] = 45

    config["resampling_rate"] = 1000
    config["digital_filter"] = True
    config["notch_filter"] = False

    config["apply_ica"] = True

    config["auto_ica_corr_thr"] = 0.9

    # options are "average", "REST", and None
    config["rereference_method"] = "average"

    # variance threshold across time
    config["mag_var_threshold"] = 4e-12
    config["grad_var_threshold"] = 4000e-13
    config["eeg_var_threshold"] = 40e-6
    # flatness threshold across time
    config["mag_flat_threshold"] = 10e-15
    config["grad_flat_threshold"] = 10e-15
    config["eeg_flat_threshold"] = 40e-6
    # variance threshold across channels
    config["zscore_std_thresh"] = 15  # TODO: change this

    # segmentation ==============================================
    # start time of the raw data to use, in seconds; skips the beginning of the
    # recording to avoid possible eye blinks in eyes-closed resting state
    config['segments_tmin'] = 20
    # end time of the raw data to use, in seconds (same motivation as above)
    config['segments_tmax'] = -20
    # length of MEG segments in seconds
    config['segments_length'] = 10
    # amount of overlap between MEG segments in seconds
    config['segments_overlap'] = 2

    # PSD ==============================================
    # spectral estimation method
    config['psd_method'] = "welch"
    # amount of overlap between windows in Welch's method
    config['psd_n_overlap'] = 1
    config['psd_n_fft'] = 2
    # number of samples in psd
    config["psd_n_per_seg"] = 2

    # fooof analysis configurations ==============================================
    # desired frequency range to run FOOOF
    config['fooof_freqRangeLow'] = 3
    config['fooof_freqRangeHigh'] = 40
    # which mode should be used for fitting; choices (knee, fixed)
    config["aperiodic_mode"] = "knee"
    # acceptable peak width limits in fooof analysis
    config["fooof_peak_width_limits"] = [1.0, 12.0]
    # absolute threshold for detecting peaks
    config['fooof_min_peak_height'] = 0
    # relative threshold for detecting peaks
    config['fooof_peak_threshold'] = 2

    # feature extraction ==========================================================
    # canonical frequency bands
    config['freq_bands'] = {
        'Theta': (3, 8),
        'Alpha': (8, 13),
        'Beta': (13, 30),
        'Gamma': (30, 40),
        # 'Broadband': (3, 40)
    }

    # individualized frequency range around the main peak of each band
    config['individualized_band_ranges'] = {
        'Theta': (-2, 3),
        'Alpha': (-2, 3),  # change to (-4,2)
        'Beta': (-8, 9),
        'Gamma': (-5, 5),
    }

    # least acceptable R squared of fitted models
    config['min_r_squared'] = 0.9

    # which feature families are extracted
    config['feature_categories'] = {
        "Offset": False,
        "Exponent": False,
        "Peak_Center": False,
        "Peak_Power": False,
        "Peak_Width": False,
        "Adjusted_Canonical_Relative_Power": True,
        "Adjusted_Canonical_Absolute_Power": False,
        "Adjusted_Individualized_Relative_Power": False,
        "Adjusted_Individualized_Absolute_Power": False,
        "OriginalPSD_Canonical_Relative_Power": False,
        "OriginalPSD_Canonical_Absolute_Power": False,
        "OriginalPSD_Individualized_Relative_Power": False,
        "OriginalPSD_Individualized_Absolute_Power": False,
    }

    config["fooof_res_save_path"] = None

    config["random_state"] = 42

    if path is not None:
        # Context manager guarantees the handle is closed even if dumping fails.
        with open(os.path.join(path, project + ".json"), "w") as out_file:
            json.dump(config, out_file, indent=6)

    return config

In [5]:
project = "_natureCommunicationPaper"

# Root folder that holds every output of this project.
project_dir = f'/home/{username}/Results/{project}/'

# Script submitted once per subject by the (currently commented-out) parallel
# feature-extraction cell.
mainParallel_path = os.path.join(package_path, 'src', 'mainParallel.py')

features_dir = os.path.join(project_dir, 'Features')
features_log_path = os.path.join(features_dir, 'log')
features_temp_path = os.path.join(features_dir, 'temp')

# Normative-modelling outputs are kept in one folder per repetition.
nm_processing_dir = os.path.join(project_dir, 'NM', 'Run_' + str(run))

job_configs = {'log_path': features_log_path, 'module': 'mne', 'time': '1:00:00', 'memory': '20GB',
               'partition': 'normal', 'core': 1, 'node': 1, 'batch_file_name': 'batch_job'}

# exist_ok=True replaces the isdir()/makedirs() pairs: it is idempotent on
# re-run and has no window between the check and the creation.
for required_dir in (features_log_path, features_temp_path, nm_processing_dir):
    os.makedirs(required_dir, exist_ok=True)

# Also writes <project_dir>/<project>.json; project_dir exists at this point
# because it is a parent of the folders created above.
configs = make_config(project, project_dir)

subjects = merge_datasets_with_glob(datasets)

# f-IDPs extraction

In [6]:
### Parallel feature extraction  

# # Running Jobs
# start_time = submit_jobs(mainParallel_path, features_dir, subjects, 
#                 features_temp_path, job_configs=job_configs, config_file=os.path.join(project_dir, project+'.json'))
# # Checking jobs
# failed_jobs = check_jobs_status(username, start_time)

# falied_subjects = {failed_job:subjects[failed_job] for failed_job in failed_jobs}

# while len(failed_jobs)>0:
#     # Re-running Jobs
#     start_time = submit_jobs(mainParallel_path, features_dir, falied_subjects, 
#                 features_temp_path, job_configs=job_configs, config_file=os.path.join(project_dir, project+'.json'))
#     # Checking jobs
#     failed_jobs = check_jobs_status(username, start_time)

# collect_results(features_dir, subjects, features_temp_path, file_name='all_features', clean=False)

# Train-Test split

In [7]:
### Data preparation for Normative Modeling
# Dataset names and their root folders, both in the order of the `datasets` dict.
dataset_names = [name for name in datasets]
data_base_dirs = [datasets[name]["base_dir"] for name in dataset_names]
merged_data, data_patient = merge_fidp_demo(
    data_base_dirs, features_dir, dataset_names, include_patients=False
)

# Split with the seed that belongs to the current repetition.
biomarker_num = hbr_data_split(
    merged_data,
    nm_processing_dir,
    drop_nans=True,
    batch_effects=['sex', 'site'],
    random_seed=random_seeds[run],
    train_split=0.5,
)

# Everything after the first three columns is treated as a biomarker column.
biomarker_names = list(merged_data.columns[3:])

In [None]:
# Summary statistics for one site, addressed by its position in `datasets`.
site = 2
site_rows = merged_data[merged_data.site == site]
female_rows = site_rows[site_rows.sex == 1]

print("Dataset name:", list(datasets)[site])
print("size: ", len(site_rows), "participants")
print("mean age: ", site_rows["age"].mean())
print("std age: ", site_rows["age"].std())
print("female num: ", len(female_rows), "participants")

In [9]:
### Setting up NM configs

# Interpreter handed to execute_nm for the submitted batch jobs.
python_path = '/project/meganorm/Software/Miniconda3/envs/mne/bin/python' 

# Candidate HBR model configurations keyed by name. Only the uncommented entry
# is active. The boolean-looking options are strings ('True'/'False') on
# purpose here; they are forwarded unchanged to execute_nm.
hbr_configs = {
                # 'homo_Gaussian_linear':{'model_type':'linear', 'likelihood':'Normal', 'linear_sigma':'False',
                #                         'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'}, 
                # 'homo_Gaussian_bspline':{'model_type':'bspline', 'likelihood':'Normal', 'linear_sigma':'False',
                #                         'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'}, 
                # 'homo_SHASH_linear':{'model_type':'linear', 'likelihood':'SHASHb', 'linear_sigma':'False',
                #                     'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'}, 
                # 'homo_SHASH_bspline':{'model_type':'bspline', 'likelihood':'SHASHb', 'linear_sigma':'False',
                #                     'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'}, 
                # 'hetero_Gaussian_linear':{'model_type':'linear', 'likelihood':'Normal', 'linear_sigma':'True',
                #                         'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'},
                # 'hetero_Gaussian_bspline':{'model_type':'bspline', 'likelihood':'Normal', 'linear_sigma':'True',
                #                         'random_slope_mu':'False', 'linear_epsilon':'False', 'linear_delta':'False'},
                # 'hetero_SHASH_linear':{'model_type':'linear', 'likelihood':'SHASHb', 'linear_sigma':'True',
                #                     'random_slope_mu':'False', 'linear_epsilon':'True', 'linear_delta':'True'},
                'hetero_SHASH_bspline':{'model_type':'bspline', 'likelihood':'SHASHb', 'linear_sigma':'True',
                                        'random_slope_mu':'False', 'linear_epsilon':'True', 'linear_delta':'True'},
            }

# The scalers are the string 'None', not the None object.
inscaler='None' 
outscaler='None' 
batch_size = 1
# Note the leading underscore: other cells pass the literal 'estimate' instead.
outputsuffix = '_estimate'

# Training responses / covariates written into the run folder.
respfile = os.path.join(nm_processing_dir, 'y_train.pkl')
covfile = os.path.join(nm_processing_dir, 'x_train.pkl')

# Held-out test responses / covariates.
testrespfile_path = os.path.join(nm_processing_dir, 'y_test.pkl')
testcovfile_path = os.path.join(nm_processing_dir, 'x_test.pkl')

# Batch-effect files for the training and test sets.
trbefile = os.path.join(nm_processing_dir, 'b_train.pkl')
tsbefile = os.path.join(nm_processing_dir, 'b_test.pkl')

# Resources requested for each submitted job.
memory = '2gb'
duration = '5:00:00'
cluster_spec = 'slurm'

# Running NM

In [10]:
# for method in hbr_configs.keys():
method = 'hetero_SHASH_bspline'
# Both paths keep a trailing slash, as in the original cell.
processing_dir = os.path.join(nm_processing_dir, method) + '/'
nm_log_path = os.path.join(processing_dir, 'log') + '/'

# exist_ok=True makes the cell idempotent and removes the check-then-create race.
os.makedirs(processing_dir, exist_ok=True)
os.makedirs(nm_log_path, exist_ok=True)

# execute_nm(processing_dir, python_path,
#             'NM', covfile, respfile, batch_size, memory, duration, alg='hbr', 
#             log_path=nm_log_path, binary=True, testcovfile_path=testcovfile_path, 
#             testrespfile_path=testrespfile_path,trbefile=trbefile, tsbefile=tsbefile, 
#             model_type=hbr_configs[method]['model_type'], likelihood=hbr_configs[method]['likelihood'],  
#             linear_sigma=hbr_configs[method]['linear_sigma'], random_slope_mu=hbr_configs[method]['random_slope_mu'],
#             linear_epsilon=hbr_configs[method]['linear_epsilon'], linear_delta=hbr_configs[method]['linear_delta'], 
#             savemodel='True', inscaler=inscaler, outscaler=outscaler, outputsuffix=outputsuffix, 
#             interactive='auto', cluster_spec=cluster_spec, nuts_sampler="nutpie", n_cores_per_batch="2")

In [None]:
# collect_nm(processing_dir, "NM", collect=True, binary=True, batch_size=1)

In [12]:
import pandas as pd
from scipy.stats import skew, kurtosis

def aggregate_metrics_across_runs(path, method_name, biomarker_names, valcovfile_path,
                                valrespfile_path, valbefile,  metrics = ["skewness", "kurtosis", "W"], 
                                num_runs=10, quantiles=[0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99],
                                outputsuffix='estimate'):
    """Collect per-biomarker diagnostic metrics of one model across runs.

    Parameters
    ----------
    path : str
        Processing directory of run 0; every occurrence of ``"Run_0"`` in it is
        replaced by ``"Run_<i>"`` to address run ``i``.
    method_name : str
        Sub-folder of each run directory that holds the model outputs.
    biomarker_names : list of str
        Biomarker names, in the column order of the z-score table.
    valcovfile_path, valrespfile_path, valbefile : str
        Validation covariate / response / batch-effect files; only forwarded to
        ``evaluate_mace`` (metric ``"MACE"``).
    metrics : list of str
        Any of ``"skewness"``, ``"kurtosis"``, ``"W"`` and ``"MACE"``.
    num_runs : int
        Number of runs to aggregate (run indices ``0 .. num_runs - 1``).
    quantiles : list of float
        Quantiles forwarded to ``evaluate_mace``.
    outputsuffix : str
        Output suffix forwarded to ``evaluate_mace``.

    Returns
    -------
    dict
        ``{metric: {biomarker_name: [value of run 0, value of run 1, ...]}}``.

    Raises
    ------
    ValueError
        If ``metrics`` contains an unsupported name.
    """
    supported_metrics = ("skewness", "kurtosis", "W", "MACE")
    unknown = [metric for metric in metrics if metric not in supported_metrics]
    if unknown:
        # Previously an unknown name surfaced as an IndexError deep in the loop.
        raise ValueError(f"Unsupported metric(s) {unknown}; choose from {list(supported_metrics)}")

    data = {metric: {biomarker_name: [] for biomarker_name in biomarker_names}
            for metric in metrics}

    for run in range(num_runs):
        run_path = path.replace("Run_0", f"Run_{run}")

        # NOTE: pickle files are assumed to be produced by this pipeline; never
        # point this at untrusted files.
        with open(os.path.join(run_path, method_name, 'Z_estimate.pkl'), 'rb') as z_file:
            z_scores = pickle.load(z_file)

        for metric in metrics:
            if metric == "MACE":
                # Uses the method_name argument (the original read the
                # notebook-level global `method` here).
                models_dir = os.path.join(run_path, method_name, 'Models')
                values = [evaluate_mace(models_dir, valcovfile_path,
                                        valrespfile_path, valbefile, model_id=ind,
                                        quantiles=quantiles,
                                        outputsuffix=outputsuffix)
                          for ind in range(len(biomarker_names))]
            elif metric == "W":
                with open(os.path.join(run_path, 'x_test.pkl'), 'rb') as cov_file:
                    cov = pickle.load(cov_file)
                values = list(shapiro_stat(z_scores, cov))
            elif metric == "skewness":
                values = list(skew(z_scores))
            else:  # "kurtosis"
                values = list(kurtosis(z_scores))

            for counter, name in enumerate(biomarker_names):
                data[metric][name].append(values[counter])

    return data




# Model Diagnostic

In [13]:
# Aggregate the default diagnostic metrics over all runs of the selected model.
metrics_values = aggregate_metrics_across_runs(nm_processing_dir, method, biomarker_names,
                                               testcovfile_path, testrespfile_path, tsbefile)

# Derived from the configured folders instead of a hard-coded absolute path;
# with the settings used in this notebook both spell the same location.
metrics_summary_path = os.path.join(nm_processing_dir, method, "metrics_summary.pkl")
with open(metrics_summary_path, "wb") as file:
    pickle.dump(metrics_values, file)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
def plot_metrics(metrics_path, biomarker_names, which_features,
                  feature_new_name=None):
    """Draw one box plot per metric stored in a pickled metrics summary.

    Parameters
    ----------
    metrics_path : str
        Pickle file holding ``{metric: {feature: [value per run, ...]}}``.
    biomarker_names : list of str
        Currently unused; kept so existing calls stay valid.
    which_features : list of str
        Features (columns) to plot, in plotting order.
    feature_new_name : list of str, optional
        Display names, one per entry of ``which_features``. When omitted or
        empty, the original feature names are used as labels.

    Raises
    ------
    ValueError
        If ``feature_new_name`` is given with a length different from
        ``which_features``.
    """
    # One colour per plotted feature.
    colors = ["#9E6240", "#819595", "#5F0F40", "#0F4C5C"]

    # NOTE: only load summaries written by this notebook; pickle is not safe
    # for untrusted files.
    with open(metrics_path, "rb") as file:
        metrics_dic = pickle.load(file)

    # The theme does not depend on the metric, so it is set once.
    sns.set_theme(style="ticks", palette="pastel")

    for metric, per_feature in metrics_dic.items():
        df_temp = pd.DataFrame(per_feature).loc[:, which_features]

        if feature_new_name:
            if len(feature_new_name) != len(which_features):
                raise ValueError("feature_new_name must have one entry per entry of which_features")
            df_temp.columns = feature_new_name

        # Long format: one row per (feature, run) pair.
        df_temp = df_temp.melt(var_name='Variable', value_name='Value')

        sns.boxplot(x='Variable', y='Value', data=df_temp, palette=colors)
        sns.despine(offset=0, trim=True)
        plt.xlabel("Frequency Bands")
        plt.ylabel(metric.title())

        plt.show()

# Whole-head relative power of the four canonical bands, relabelled by band name.
which_features = [f"Adjusted_Canonical_Relative_Power{band}_all"
                  for band in ("Theta", "Alpha", "Beta", "Gamma")]
plot_metrics(
    metrics_summary_path,
    biomarker_names,
    which_features,
    feature_new_name=["Theta", "Alpha", "Beta", "Gamma"],
)

In [None]:
# Bare expression: Jupyter only echoes it when it is the last statement of a
# cell, so here it merely checks that the key exists.
metrics_values["kurtosis"]["Adjusted_Canonical_Relative_PowerBeta_all"]

# Z-scores of run 9, located relative to the configured folders instead of a
# hard-coded absolute path (identical location with this notebook's settings).
inspected_run = 9
z_scores_path = os.path.join(nm_processing_dir.replace("Run_0", f"Run_{inspected_run}"),
                             method, "Z_estimate.pkl")

with open(z_scores_path, 'rb') as file:
    z_scores = pickle.load(file)

# Second-to-last column, largest z-scores first.
sorted(z_scores.iloc[:,-2].tolist(), reverse=True)
# np.where(np.array(metrics_values.iloc[:,2].tolist())==min(metrics_values.iloc[:,2].tolist()))

In [4]:
### Evaluating quantiles using MACE

# Compares the configured models on the held-out test files.
# NOTE(review): outputsuffix is the literal 'estimate' here, whereas the
# notebook-level `outputsuffix` variable is '_estimate' — confirm which form
# model_quantile_evaluation expects.
mace, best_models, bio_ids = model_quantile_evaluation(hbr_configs, nm_processing_dir, testcovfile_path, 
                              testrespfile_path, tsbefile, biomarker_num, plot=False, outputsuffix='estimate')

plot_comparison(nm_processing_dir, hbr_configs, biomarker_num)

# NM range with markers

In [None]:
# Plotting ranges
# # for config in hbr_configs.keys():
processing_path = os.path.join(nm_processing_dir, method)

# Site codes follow the order of `datasets` (the same convention the site
# summary cell relies on), so the map is derived rather than typed by hand.
site_codes = {name: code for code, name in enumerate(datasets)}

q = estimate_centiles(processing_path, biomarker_num, quantiles=[0.05, 0.25, 0.5, 0.75, 0.95],
                      batch_map={0: {'Male': 0, 'Female': 1}, 1: site_codes},
                      age_range=[6, 80])
plot_nm_range_site2(processing_path, nm_processing_dir)


# Age distribution (plot)

In [None]:


# Healthy participants and patients in one table for the age histogram.
all_participants = pd.concat([merged_data, data_patient])

# Same folder as before ('<package_path>pics/'), built from package_path
# instead of repeating the absolute path.
plot_age_dist2(all_participants, site_names=list(datasets.keys()),
               save_path=os.path.join(package_path, "pics", ""))

# Test on clinical data

In [None]:
# random_seeds = [0]
# Prefix of the clinical prediction input files written into every run folder.
prefix = "clinicalpredict_"

for i in range(len(random_seeds)):

    nm_processing_dir_temp = nm_processing_dir.replace("Run_0", f"Run_{i}")
    processing_dir_temp = processing_dir.replace("Run_0", f"Run_{i}")

    prepare_prediction_data(data_patient.drop('diagnosis', axis=1),
                                nm_processing_dir_temp, 
                                drop_nans=True, 
                                batch_effects=['sex', 'site'], 
                                prefix=prefix)

    # Loop-local names: the original overwrote the notebook-level
    # testrespfile_path / testcovfile_path / tsbefile, so re-running an earlier
    # cell afterwards silently evaluated on the clinical files of the last run.
    clinical_respfile = os.path.join(nm_processing_dir_temp, prefix + 'y_test.pkl')
    clinical_covfile = os.path.join(nm_processing_dir_temp, prefix + 'x_test.pkl')
    clinical_befile = os.path.join(nm_processing_dir_temp, prefix + 'b_test.pkl')

    # NOTE(review): log_path is the log folder of the notebook's own run for
    # every iteration — confirm that sharing one log folder is intended.
    execute_nm(processing_dir_temp, python_path,
            'NM', clinical_covfile, clinical_respfile, batch_size, memory, duration, alg='hbr', 
            log_path=nm_log_path, binary=True, tsbefile=clinical_befile, func="predict", 
            model_type=hbr_configs[method]['model_type'], likelihood=hbr_configs[method]['likelihood'],  
            linear_sigma=hbr_configs[method]['linear_sigma'], random_slope_mu=hbr_configs[method]['random_slope_mu'],
            linear_epsilon=hbr_configs[method]['linear_epsilon'], linear_delta=hbr_configs[method]['linear_delta'], 
            savemodel='True', inscaler=inscaler, outscaler=outscaler, outputsuffix="clinicalpredict", inputsuffix=outputsuffix,
            interactive='auto', cluster_spec=cluster_spec, nuts_sampler="nutpie", n_cores_per_batch="2")

# Abnormal probability index

In [80]:
def abnormal_probability(processing_dir, nm_processing_dir, site_id, n_permutation=1000,
                         composite_columns=(0, 2, 3)):
    """AUC of separating patients from same-site healthy test participants.

    NOTE(review): this definition shadows the ``abnormal_probability`` imported
    from ``utils.nm`` at the top of the notebook — confirm the override is
    intended.

    Parameters
    ----------
    processing_dir : str
        Model folder containing ``Z_clinicalpredict.pkl`` (patients) and
        ``Z_estimate.pkl`` (healthy test set).
    nm_processing_dir : str
        Run folder containing ``b_test.pkl`` (test-set batch effects).
    site_id : int
        Only healthy participants with this ``site`` value are used as controls.
    n_permutation : int
        Forwarded to ``anomaly_detection_auc``.
    composite_columns : sequence of int, optional
        Column positions averaged into one extra composite column appended
        after the per-biomarker columns. The default reproduces the previously
        hard-coded choice ``[0, 2, 3]``.

    Returns
    -------
    tuple
        ``(p_val, auc)``: the p-values after ``false_discovery_control`` and
        the AUCs, one entry per column including the appended composite.
    """

    def _load_pickle(file_path):
        # Only files produced by this pipeline should be loaded here.
        with open(file_path, "rb") as file:
            return pickle.load(file)

    z_patient = _load_pickle(os.path.join(processing_dir, "Z_clinicalpredict.pkl"))
    z_healthy = _load_pickle(os.path.join(processing_dir, "Z_estimate.pkl"))
    b_healthy = _load_pickle(os.path.join(nm_processing_dir, "b_test.pkl"))

    # Controls: healthy test participants recorded at the requested site.
    z_healthy = z_healthy.iloc[np.where(b_healthy["site"] == site_id)[0], :]

    p_patient = z_to_abnormal_p(z_patient)
    p_healthy = z_to_abnormal_p(z_healthy)

    # Append the mean abnormality over the selected columns as a last column.
    selected = list(composite_columns)
    p_patient = np.hstack([p_patient, p_patient[:, selected].mean(axis=1).reshape(-1, 1)])
    p_healthy = np.hstack([p_healthy, p_healthy[:, selected].mean(axis=1).reshape(-1, 1)])

    # Patients are labelled 1, healthy controls 0.
    p = np.concatenate([p_patient, p_healthy])
    labels = np.concatenate([np.ones(p_patient.shape[0]), np.zeros(p_healthy.shape[0])])

    auc, p_val = anomaly_detection_auc(p, labels, n_permutation=n_permutation)

    p_val = false_discovery_control(p_val)

    return p_val, auc

In [None]:
# Site whose healthy test participants serve as controls.
site_id = 3
p_vals, aucs = [], []

for i in range(len(random_seeds)):

    nm_processing_dir_temp = nm_processing_dir.replace("Run_0", f"Run_{i}")
    processing_dir_temp = processing_dir.replace("Run_0", f"Run_{i}")

    p_val, auc = abnormal_probability(processing_dir_temp,
                                    nm_processing_dir_temp, 
                                    site_id,
                                    n_permutation=1000)

    p_vals.append(p_val)
    aucs.append(auc)

# One row per run.
p_vals = pd.DataFrame(np.vstack(p_vals))
aucs = pd.DataFrame(np.vstack(aucs))

# The notebook-level abnormal_probability appends one composite column after
# the per-biomarker columns, so assigning four names to the full frame fails
# with a length mismatch. Keep the leading band columns for the box plot.
# NOTE(review): assumes the first four columns are Theta, Alpha, Beta, Gamma in
# this order, as the original assignment did — confirm.
band_names = ["Theta", "Alpha", "Beta", "Gamma"]
aucs = aucs.iloc[:, :len(band_names)]
aucs.columns = band_names

## AUC box plot

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def box_plot_auc(df, save_path):
    """Box plot with overlaid points of the AUCs and save it as ``AUCs.svg``.

    NOTE(review): this definition shadows the ``box_plot_auc`` imported from
    ``plots.plots`` at the top of the notebook — confirm the override is
    intended.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per plotted group and one row per run.
    save_path : str
        Folder in which ``AUCs.svg`` is written.
    """
    # Long format for seaborn: columns 'variable' and 'value'.
    data_long = pd.melt(df)

    plt.figure(figsize=(6, 5))
    sns.boxplot(x='variable', y='value', data=data_long, color="lightgray")
    sns.stripplot(x='variable', y='value', data=data_long, color='black', marker='o', size=6, alpha=0.7, jitter=True)

    # Customize the plot
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.ylabel('AUC', fontsize=16)
    plt.xlabel("")
    plt.grid()

    # Hide the frame on all four sides.
    ax = plt.gca()
    for side in ("right", "top", "left", "bottom"):
        ax.spines[side].set_visible(False)

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, "AUCs.svg"), dpi=600, format="svg")

# With an empty save_path, os.path.join(save_path, "AUCs.svg") is just
# "AUCs.svg", i.e. the figure lands in the current working directory.
box_plot_auc(aucs, save_path="")

In [None]:
# Keep only the patients whose diagnosis is "parkinson".
# NOTE(review): data_patient is rebound in place, so cells above that read
# data_patient see the filtered table if they are re-run after this cell.
data_patient = data_patient.iloc[np.where(data_patient["diagnosis"] == "parkinson")[0], :]
print(data_patient.shape)

# NOTE(review): z_patient is not assigned at notebook scope in any visible
# cell (it is only a local inside abnormal_probability). On a fresh kernel this
# line raises NameError — presumably it was loaded from Z_clinicalpredict.pkl
# in a since-removed cell; confirm and restore the load.
# The assignment also requires z_patient to have exactly as many rows as the
# filtered data_patient.
z_patient.index = data_patient.index
# Columns from position 4 onward are taken as the feature columns — assumes the
# first four hold demographics/diagnosis; TODO confirm.
parkinson_patient_feat = data_patient.iloc[:,4:]
print(parkinson_patient_feat.shape)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)

# Theta and beta z-scores of the patients; the column position is looked up by
# feature name in the patient feature table.
theta_feature = "Adjusted_Canonical_Relative_PowerTheta_all"
beta_feature = "Adjusted_Canonical_Relative_PowerBeta_all"
a = list(z_patient.iloc[:, np.where(parkinson_patient_feat.columns == theta_feature)[0][0]])
b = list(z_patient.iloc[:, np.where(parkinson_patient_feat.columns == beta_feature)[0][0]])

# |z| at or below this value counts as the normal range on either axis.
z_cut = 0.68

plt.figure(figsize=(8, 8))

plt.ylim((-4, 4))
plt.xlim((-4, 4))

# Legend order and the single source of truth for the colour of each category.
# (Previously the points of 'Normal beta - High theta' were drawn in olive
# while the legend showed green.)
order = [
    ('High beta - Low theta', 'red'),
    ('High theta - Low beta', 'purple'),
    ('High beta - Normal theta', 'blue'),
    ('Normal theta - Low beta', 'orange'),
    ('Normal beta - High theta', 'green'),
    ('Normal beta - Low theta', 'teal'),
    ('Low beta - Low theta', 'pink'),
    ('High beta - High theta', 'mediumvioletred'),
    ('Normal range', 'black')
]
color_of = dict(order)

# Category for every (beta level, theta level) combination.
label_by_level = {
    ('High', 'Low'): 'High beta - Low theta',
    ('Low', 'High'): 'High theta - Low beta',
    ('High', 'Normal'): 'High beta - Normal theta',
    ('Low', 'Normal'): 'Normal theta - Low beta',
    ('Normal', 'High'): 'Normal beta - High theta',
    ('Normal', 'Low'): 'Normal beta - Low theta',
    ('Low', 'Low'): 'Low beta - Low theta',
    ('High', 'High'): 'High beta - High theta',
    ('Normal', 'Normal'): 'Normal range',
}


def _z_level(z):
    """Classify one z-score as 'High', 'Low' or 'Normal' (boundaries are Normal)."""
    if z > z_cut:
        return 'High'
    if z < -z_cut:
        return 'Low'
    # Also reached for NaN, which the original labelled 'Normal range' as well.
    return 'Normal'


# Classifying each axis separately means a point lying exactly on a boundary is
# normal on that axis only; before, it fell through to 'Normal range' even when
# the other axis was clearly abnormal.
labels = [label_by_level[(_z_level(beta), _z_level(theta))] for theta, beta in zip(a, b)]
colors = [color_of[label] for label in labels]

# Legend handles in the fixed order defined above.
handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label)
    for label, color in order
]

# Plot the scatter plot
plt.scatter(a, b, color=colors)

# Add the gray region and lines marking the normal range
plt.fill_betweenx(y=[-z_cut, z_cut], x1=-z_cut, x2=z_cut, color='gray', alpha=0.5, label="|z| < 0.68")
plt.hlines(y=[-z_cut, z_cut], xmin=-z_cut, xmax=z_cut, colors='black', linestyles='--', linewidth=1.5)
plt.vlines(x=[-z_cut, z_cut], ymin=-z_cut, ymax=z_cut, colors='black', linestyles='--', linewidth=1.5)

# Set axis ticks
ticks = [-3, -z_cut, 0, z_cut, 3]
plt.xticks(ticks)
plt.yticks(ticks)

# Labeling
plt.xlabel('Theta z-scores', fontsize=16)
plt.ylabel('Beta z-scores', fontsize=16)

# Style the plot
plt.grid(alpha=0.5)
for side in ("right", "top", "left", "bottom"):
    plt.gca().spines[side].set_visible(False)

# Add the legend with the correct order
plt.legend(handles=handles, fontsize=13)

# Finalize and save the plot
plt.tight_layout()
plt.savefig("normal_var.png", dpi=400)


In [192]:
import os
import pickle
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.pyplot as plt
import scipy.stats as st
import seaborn as sns
import pandas as pd
import plotly.graph_objects as go



def plot_quantile_gauge(current_value, q1, q3, percentile_5, percentile_95, percentile_50, 
                        title="Quantile-Based Gauge", min_value=0, max_value=1, show_legend=False, bio_name=None, save_path=""):
    """
    Draw a gauge whose coloured bands are quantile ranges and save it as a PNG.

    The gauge shows ``current_value`` together with its difference from the
    median, marks the median with a black threshold line, and is written to
    ``<save_path>/<bio_name>.png``.

    Parameters:
    - current_value (float): The value to display.
    - q1 (float): The 25th percentile.
    - q3 (float): The 75th percentile.
    - percentile_5 (float): The 5th percentile.
    - percentile_95 (float): The 95th percentile.
    - percentile_50 (float): The 50th percentile (median); delta reference and threshold line.
    - title (str): Currently unused; the chart title is taken from ``bio_name``.
    - min_value (float): Lower end of the gauge axis (default 0).
    - max_value (float): Upper end of the gauge axis (default 1).
    - show_legend (bool): Whether to add a legend describing the bands (default False).
    - bio_name (str): Chart title and stem of the output file name.
    - save_path (str): Folder in which the PNG is written.
    """
    # One entry per band, lowest band first:
    # (colour of the displayed number, colour of the gauge band, legend text)
    bands = [
        ("rgba(115, 90, 63, 1)", "rgba(115, 90, 63, 1)", "0-5th Percentile (Extremely Low)"),      # brown
        ("rgba(255, 215, 0, 1)", "rgba(255, 215, 0, 0.6)", "5th-25th Percentile (Below Normal)"),   # gold
        ("rgba(34, 139, 34, 1)", "rgba(34, 139, 34, 0.7)", "25th-75th Percentile (Normal)"),        # forest green
        ("rgba(255, 99, 71, 1)", "rgba(255, 99, 71, 0.6)", "75th-95th Percentile (Above Normal)"),  # tomato red
        ("rgba(128, 0, 128, 1)", "rgba(128, 0, 128, 0.9)", "95th-100th Percentile (Extremely High)"),  # purple
    ]
    # Band edges from the lower to the upper end of the axis.
    edges = [min_value, percentile_5, q1, q3, percentile_95, max_value]

    # Band the current value falls into (note the strict / non-strict mix).
    if current_value < percentile_5:
        band_index = 0
    elif current_value < q1:
        band_index = 1
    elif current_value <= q3:
        band_index = 2
    elif current_value <= percentile_95:
        band_index = 3
    else:
        band_index = 4
    value_color = bands[band_index][0]

    # Smaller fonts leave room for the legend.
    number_font_size, delta_font_size = (75, 30) if show_legend else (150, 50)

    axis_span = max_value - min_value
    gauge_steps = [
        {'range': [lower, upper], 'color': band[1]}
        for lower, upper, band in zip(edges[:-1], edges[1:], bands)
    ]

    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=current_value,
        number={'font': {'size': number_font_size, 'family': 'Arial', 'color': value_color}},
        delta={'reference': percentile_50, 'position': "top", 'font': {'size': delta_font_size}},
        gauge={
            'axis': {
                'range': [min_value, max_value],
                'tickfont': {'size': 30, 'family': 'Arial', 'color': 'black'},
                'showticklabels': True,
                'tickwidth': 2,
                'tickcolor': "lightgrey",
                # eleven evenly spaced ticks across the axis
                'tickvals': [round(min_value + i * axis_span / 10, 2) for i in range(11)],
            },
            'bar': {'color': "rgb(255, 69, 58)"},
            'steps': gauge_steps,
            'threshold': {
                'line': {'color': "black", 'width': 6},  # black marker at the median
                'thickness': 0.75,
                'value': percentile_50,
            },
        },
        title={
            'text': bio_name,
            'font': {'size': 50, 'family': 'Arial', 'color': 'black'}
        }
    )
    fig = go.Figure(indicator)

    if show_legend:
        # Invisible scatter traces only provide the legend entries.
        for _, band_color, legend_text in bands:
            fig.add_trace(go.Scatter(x=[None], y=[None], mode="markers",
                                     marker=dict(size=12, color=band_color),
                                     name=legend_text))

    # Layout: white background, legend centred below the chart, axes hidden.
    bottom_margin = 100 if show_legend else 30
    fig.update_layout(
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(t=50, b=bottom_margin, l=30, r=30),
        showlegend=show_legend,
        width=1100,
        height=700,
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5,
            font=dict(size=14)
        ),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False)
    )

    # fig.show()
    # The image is written to disk rather than displayed.
    fig.write_image(os.path.join(save_path, f"{bio_name}.png"))

# INOCs

In [None]:
# Quantiles file used for the gauges.
# NOTE(review): this path points at a different project folder
# ('natureArticle_new_pcn_no_scalar') than the rest of the notebook — confirm.
q_path = "/home/meganorm-mznasrabadi/Results/natureArticle_new_pcn_no_scalar/NM/Run_0/hetero_SHASH_bspline/Quantiles_estimate.pkl"
feature = ['Adjusted_Canonical_Relative_PowerTheta_all', "Adjusted_Canonical_Relative_PowerAlpha_all", 'Adjusted_Canonical_Relative_PowerBeta_all', "Adjusted_Canonical_Relative_PowerGamma_all"]
features_list = list(merged_data.columns)[3:]

# Participant whose values are shown on the gauges.
sub_index = "sub-042"
subject_row = merged_data.loc[sub_index]
statistics = cal_stats_for_gauge(
    q_path,
    feature,
    features_list,
    site_id=subject_row["site"],
    gender_id=subject_row["sex"],
    # age is scaled by 100 here — presumably stored as age/100; TODO confirm
    age=subject_row["age"]*100,
)

names = ["Theta", "Alpha", "Beta", "Gamma"]


for i, name in enumerate(feature):
    print(names[i])
    # Gamma gets a narrower axis than the other bands.
    max_value = 0.2 if names[i] == "Gamma" else 1

    # Positions 0..4 are used as the 5th, 25th, 50th, 75th and 95th percentile.
    band_stats = statistics[name]
    plot_quantile_gauge(
        merged_data.loc[sub_index, name],
        band_stats[1],
        band_stats[3],
        band_stats[0],
        band_stats[4],
        band_stats[2],
        title="",
        max_value=max_value,
        show_legend=False,
        bio_name=names[i],
        save_path="/home/meganorm-mznasrabadi/MEGaNorm/pics/gauges",
    )

In [199]:
# sub-PD1674
# sub-PD1551
# sub-PD1517
# sub-MNI0079
# sub-PD0512
# sub-PD1487

In [157]:
def plot_neurooscillochart(data, age_slices, save_path):
    """Draw stacked-bar "Chrono-NeuroOscilloChart" panels for males and females.

    Two vertically stacked panels sharing the x-axis are produced: the upper
    one from ``data['Male']`` and the lower one from ``data['Female']``. Each
    bar segment is annotated with its height as a whole-number percentage.

    NOTE(review): this notebook-level definition shadows the
    ``plot_neurooscillochart`` imported from ``plots.plots`` in the import
    cell. It relies on ``sns``, ``plt``, ``ListedColormap``, ``pd`` and ``os``
    being available in the notebook namespace.

    Parameters
    ----------
    data : dict
        Must contain the keys ``'Male'`` and ``'Female'``. Each value is a
        mapping from a series label to a sequence holding one item per age
        slice; ``item[0]`` becomes the bar height and ``item[1]`` the error
        bar, both multiplied by 100 (presumably mean and standard deviation
        expressed as fractions -- confirm against ``calculate_oscilochart``).
    age_slices : iterable of numbers
        Lower bound of each age bin; bins are labelled ``"start-start+5"``.
    save_path : str or None
        Directory in which ``Chrono-NeuroOscilloChart.png`` is written at
        600 dpi. When ``None`` the figure is shown instead of being saved.

    Returns
    -------
    None
    """
    # One palette shared by both panels so the series colours always match.
    palette = ['lightgrey', 'gray', 'dimgrey', 'lightslategray']

    # Age-range tick labels, one per slice.
    ages = [f"{i}-{i+5}" for i in age_slices]

    sns.set_theme(style="whitegrid")

    fig, axes = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

    def plot_gender_data(ax, gender_data, title, legend=True, colors=None):
        """Plot one panel of stacked, annotated bars on ``ax``."""
        # Convert to percentages. The 1.96 factor presumably turns a standard
        # deviation into a 95% interval half-width -- TODO confirm.
        means = {k: [item[0] * 100 for item in v] for k, v in gender_data.items()}
        stds = {k: [item[1] * 100 * 1.96 for item in v] for k, v in gender_data.items()}

        df_means = pd.DataFrame(means, index=ages)
        df_stds = pd.DataFrame(stds, index=ages)

        # ListedColormap cannot be built from None, so honour the declared
        # default by falling back to pandas' own colours in that case.
        my_cmap = ListedColormap(colors, name="my_cmap") if colors is not None else None

        bar_plot = df_means.plot(kind='bar', yerr=df_stds, capsize=4, stacked=True, ax=ax, alpha=0.7,
                                 colormap=my_cmap)

        # Annotate each bar segment, slightly above its centre.
        for p in bar_plot.patches:
            width, height = p.get_width(), p.get_height()
            x, y = p.get_xy()
            bar_plot.text(x + width / 2,
                          y + height / 2 + 2,
                          f'{height:.0f}%',
                          ha='center',
                          va='center', fontsize=14)

        ax.set_title(title, fontsize=18)
        ax.set_xlabel('Age Ranges', fontsize=16)

        if legend:
            ax.legend(loc='upper right', bbox_to_anchor=(1.1, 1))
        else:
            # Guard against there being no legend to remove.
            existing_legend = ax.get_legend()
            if existing_legend is not None:
                existing_legend.remove()

        ax.grid(True, axis='y', linestyle='--', linewidth=0.5)
        ax.grid(False, axis='x')
        ax.tick_params(axis='x', labelsize=14)
        # Values are printed on the bars, so the y tick labels are blanked.
        ax.set_yticklabels([])

    plot_gender_data(axes[0], data['Male'], "Males' Chrono-NeuroOscilloChart",
                     colors=palette)

    plot_gender_data(axes[1], data['Female'], "Females' Chrono-NeuroOscilloChart", legend=False,
                     colors=palette)

    axes[1].set_xlabel('Age Ranges', fontsize=14)
    # pyplot call: applies to the current axes.
    plt.xticks(rotation=45)
    fig.tight_layout()

    if save_path is not None:
        # Save through the figure object rather than pyplot's implicit
        # "current figure" so the right figure is always written.
        fig.savefig(os.path.join(save_path, 'Chrono-NeuroOscilloChart.png'), dpi=600)
    else:
        plt.show()


# PNOCs

In [None]:
# Calculate the oscillograms from the stored quantile estimates, then chart them.

# Numeric codes for each sex and the model index used for each frequency band.
gender_ids = dict(Male=0, Female=1)
frequency_band_model_ids = dict(Theta=0, Alpha=2, Beta=4, Gamma=6)

quantiles_path = os.path.join(processing_dir, 'Quantiles_estimate.pkl')
oscilograms, age_slices = calculate_oscilochart(
    quantiles_path,
    gender_ids,
    frequency_band_model_ids,
)

# The chart is written into the same processing directory.
plot_neurooscillochart(oscilograms, age_slices, processing_dir)

# Plot growth charts

In [None]:
# Each band name labels two consecutive IDP indices (eight labels for indices 0-7).
band_labels = [band for band in ('Theta', 'Alpha', 'Beta', 'Gamma') for _ in range(2)]

plot_growthcharts(
    processing_dir,
    idp_indices=[*range(8)],
    idp_names=band_labels,
)