# Cove ADHD-ASD analysis: learning latent constructs of brain activity using connectivity maps and variational autoencoders (VAEs)

In this notebook, we walk through the analysis of neurophysiological and neuroimaging data from children with autism spectrum disorder (ASD), children with attention deficit hyperactivity disorder (ADHD), and typically developing (TD) children. The notebook has four goals:
1. Preprocess magnetoencephalography (MEG) and functional magnetic resonance imaging (fMRI) data collected from children with ASD, ADHD, or typical development histories, using best-practice approaches via commercially friendly software implementations.
2. Extract connectivity networks.
3. Fit neural network models (variational autoencoders, VAEs) to the spatiotemporal connectivity networks.
4. Relate the learned latent distributions to clinical outcomes and identify subgroups of patients within and/or across diagnostic categories.

# First, let's compile our clinical data

This information will give us some intuition about the distributions of clinical severities among the patients, and also how they relate to each other at a high level (e.g., we can use PCA to explore the overall distributions of Strengths and Weaknesses of ADHD symptoms and Normal behavior (SWAN) evaluation scores). We can step through each of the populations and save a condensed clinical sheet for them.
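
As a rough illustration of the PCA idea mentioned above, the sketch below recodes a handful of made-up SWAN-style ratings into ordinal values and projects them onto two principal components with plain NumPy. The toy ratings, the `RATING_TO_CODE` mapping, and the `recode` and `pca_project` helpers are hypothetical stand-ins for illustration; the notebook itself works on the real clinical sheets with pandas and scikit-learn.

```python
import numpy as np

# Hypothetical ordinal coding: lower codes mean "further above average"
RATING_TO_CODE = {
    "Far above": 0, "Above": 1, "Slightly above": 2, "Average": 3,
    "Slightly below": 4, "Below": 5, "Far below": 6,
}

def recode(ratings):
    """Turn a 2D list of rating strings (subjects x items) into an integer array."""
    return np.array([[RATING_TO_CODE[r] for r in row] for row in ratings])

def pca_project(X, n_components=2):
    """Project the rows of X onto the leading principal components via SVD."""
    Xc = X - X.mean(axis=0)                       # center each item
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = Xc @ Vt[:n_components].T             # subject coordinates
    explained = (S ** 2) / np.sum(S ** 2)         # variance ratio per component
    return scores, explained[:n_components]

# Made-up ratings for four subjects on three items
toy_ratings = [
    ["Average", "Above", "Average"],
    ["Below", "Far below", "Below"],
    ["Slightly above", "Average", "Above"],
    ["Far below", "Below", "Slightly below"],
]
codes = recode(toy_ratings)
scores, explained = pca_project(codes)
print(scores.shape, explained)
```

Plotting `scores[:, 0]` against `scores[:, 1]`, colored by diagnosis, gives the kind of overview scatter used later in this section.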

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import seaborn as sns
from functools import reduce
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


clinical_files = glob.glob("/home/coveneuro-leif/Downloads/**/Clinical*", recursive=True)
clinical_files = [c for c in clinical_files if ".csv" in c]
clinical_dfs = [pd.read_csv(c, encoding="utf-16") for c in clinical_files]
clinical_main_df = pd.concat(clinical_dfs)

clinical_main_df.columns
clinical_main_df['subject_id'].unique()
clinical_main_df.groupby(["subject_id", "primary_diagnosis"]).count()

### SWAN assessment
clinical_SWAN = clinical_main_df.loc[clinical_main_df['form_name'].isin(["SWAN"]), :]

dfs = []
for qnum in np.arange(18):
    question = f"P{qnum+1}"
    clinical_SWAN_Q = clinical_SWAN.loc[clinical_SWAN['field_name'].isin([f"SWAN{question}"]), :]
    
    clinical_SWAN_Q.index = clinical_SWAN_Q['subject_id']
    title = clinical_SWAN_Q['field_desc'].values[0]

    if qnum==0:
        clinical_SWAN_Q = pd.DataFrame(clinical_SWAN_Q[['primary_diagnosis', 'field_value']].values, columns=['primary_diagnosis', title], index=clinical_SWAN_Q.index)
    else:
        clinical_SWAN_Q = pd.DataFrame(clinical_SWAN_Q[['field_value']].values, columns=[title], index=clinical_SWAN_Q.index)
    dfs.append(clinical_SWAN_Q)

# Recode SWAN ratings into ordinal values
clinical_SWAN_final = pd.concat(dfs, axis=1)
clinical_SWAN_final_facto = clinical_SWAN_final.copy()
for c in clinical_SWAN_final.columns[1:]:
    clinical_SWAN_final[c] = pd.Categorical(clinical_SWAN_final[c].values, 
                                                      categories = ['Far above', 'Above', 'Slightly above', 'Average', 'Slightly below',
                                                                    'Below', 'Far below', 'Blank'])
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Far above", c]         = 0
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Above", c]             = 1
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Slightly above", c]    = 2
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Average", c]           = 3
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Slightly below", c]    = 4
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Below", c]             = 5
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Far below", c]         = 6
    clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto[c]=="Blank", c]             = 7
    
clinical_SWAN_final_facto.dropna(axis=0, inplace=True)
clinical_SWAN_final_facto[clinical_SWAN_final_facto.columns[1:]] = clinical_SWAN_final_facto.iloc[:, 1:].astype(int)

qnum = 17
fig, ax = plt.subplots(1, 1)
plt.suptitle(f"(P{qnum+1}) "+clinical_SWAN_final.columns[qnum+1])
plt.subplots_adjust(bottom=0.3)
sns.histplot(data=clinical_SWAN_final, x = clinical_SWAN_final.columns[qnum+1], hue = "primary_diagnosis", 
             kde=True, multiple="dodge", stat="percent", common_norm=False, ax = ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax.set_xlabel('')

### Exploratory plotting of SWAN scores and their clusters
clf = KMeans(n_clusters=3)
clf.fit(clinical_SWAN_final_facto.iloc[:, 1:])

pca = PCA(n_components=2)
pca_out = pca.fit_transform(clinical_SWAN_final_facto.iloc[:, 1:])
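# Use separate figures so the diagnosis coloring does not paint over the cluster coloring
plt.figure()
plt.scatter(pca_out[:, 0], pca_out[:, 1], c = clf.labels_)
plt.figure()
plt.scatter(pca_out[:, 0], pca_out[:, 1], c = clinical_SWAN_final_facto.iloc[:, 0].factorize()[0])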

from scipy.stats import mannwhitneyu, ttest_ind
from itertools import combinations
from statsmodels.stats.multitest import fdrcorrection
unique_pairs = list(combinations(["Typically-Developing", "ADHD", "ASD"], 2))

pvals = dict()
stats = dict()
for pop1, pop2 in unique_pairs:
    pvals_temp = []
    stats_temp = []
    for c in clinical_SWAN_final_facto.columns[1:]:
        # Run each t-test once and keep both the statistic and the p-value
        group1 = clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto['primary_diagnosis']==pop1, c].values
        group2 = clinical_SWAN_final_facto.loc[clinical_SWAN_final_facto['primary_diagnosis']==pop2, c].values
        stat, pval = ttest_ind(group1, group2)
        pvals_temp.append(pval)
        stats_temp.append(stat)
    pvals.update({pop1+"_"+pop2: np.array(pvals_temp)})
    stats.update({pop1+"_"+pop2: np.array(stats_temp)})

stats_df = pd.DataFrame(stats)
pvals_df = pd.DataFrame(pvals)
pvals_corr = fdrcorrection(pvals_df.values.ravel(), alpha = 0.05)[1].reshape(pvals_df.shape)
pvals_corr_df = pd.DataFrame(pvals_corr, index = clinical_SWAN_final_facto.columns[1:], columns = pvals_df.columns)

fig, ax = plt.subplots(1, 1, figsize=(20,10))
plt.subplots_adjust(left=0.3, bottom=0.3)
ax.set_title("log10(p-values) and test statistics of pairwise t-tests - SWAN items\nLower=first group is higher-functioning\nCells with values significant after FDR")
sns.heatmap(np.log10(pvals_corr_df), mask = pvals_corr_df>0.05, annot=stats_df.values)

### Demographics
data_table = []
clin_pops = ["Typically-Developing", "ADHD"]

# Build one column of summary statistics per population in clin_pops,
# so that the table has as many columns as clin_pops has entries
for pop in clin_pops: # "Typically-Developing", "ASD", "ADHD"
    clinical_file = f"/home/coveneuro-leif/Downloads/Clinical_Demographics_{pop}/Clinical_Demographics_{pop}.csv"
    data_dir = f"/home/coveneuro-leif/Downloads/Imaging_MEG-RestingState_{pop}/"

    # Load data
    clinical_data = pd.read_csv(clinical_file, encoding="utf-16")
    # clinical_data.columns

    # Get data filenames. Also get shortened subject_id's that correspond to the formatting in the clinical datasheet
    all_files = glob.glob(data_dir+"/*")
    all_files_subs = ["_".join(a.split("/")[-1].split("_")[:3]) for a in all_files]

    # For grabbing the correct rows from the main df
    selector = clinical_data['subject_id'].isin(all_files_subs)
    clinical_data_meg = clinical_data.loc[selector, :]

    # Demographics
    # Sex (as a percentage, to match the "% Female" row label)
    demos_sex = clinical_data_meg.loc[clinical_data_meg['field_name'].isin(['GENDER']),:]
    prc_female = 100 * (demos_sex['field_value']=="Female").mean()

    # Age
    demos_age = clinical_data_meg.loc[clinical_data_meg['field_name'].isin(['AGE_AT_ENROLLMENT']),:]
    ages = demos_age['field_value'].astype(int)
    age_median = ages.median()
    age_iqr_lo = ages.quantile(0.25)
    age_iqr_hi = ages.quantile(0.75)

    # plt.figure()
    # plt.hist(demos_age['field_value'].astype(int))
    # plt.title(f"Age-{pop}")
    # plt.savefig(f"hist-{pop}.png")

    data_list = [prc_female, age_median, age_iqr_lo, age_iqr_hi]
    data_table.append(data_list)

data_table_df = pd.DataFrame(data_table).T
data_table_df.index = ["% Female", "Age (median)", "Age (IQR-lower)", "Age (IQR-upper)"]
data_table_df.columns = clin_pops

# %% Clinical scores

clinical_file = f"/home/coveneuro-leif/Downloads/Clinical_Demographics_{pop}/Clinical_Demographics_{pop}.csv"
data_dir = f"/home/coveneuro-leif/Downloads/Imaging_MEG-RestingState_{pop}/"

# Load data
clinical_data = pd.read_csv(clinical_file, encoding="utf-16")
# clinical_data.columns


# Next, let's get into preprocessing and network extraction

Here, we will load the files that we want to analyze (MEG first, then fMRI). Load all libraries and grab MEG files:

In [None]:
# Lots of the below from the MNE tutorial here: https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html
import numpy as np
import pandas as pd
import os
import glob
import mne
import mne_connectivity
from mne.preprocessing import ICA, corrmap
from mne.preprocessing import (
    compute_proj_ecg,
    compute_proj_eog,
    create_ecg_epochs,
    create_eog_epochs,
)
import autoreject
import matplotlib.pyplot as plt
from mne import write_source_spaces

# Load data - assumes shuttle_MEG_files.py has been run on the main_dir target
main_dir    = "/home/coveneuro-leif/Downloads/Imaging_MEG-RestingState_Typically-Developing/"
file_idx    = 0
all_files_TD   = glob.glob(main_dir + "/*")
all_files_TD   = [a for a in all_files_TD if ".ds" in a]

# Load data - assumes shuttle_MEG_files.py has been run on the main_dir target
main_dir    = "/home/coveneuro-leif/Downloads/Imaging_MEG-RestingState_ASD/"
file_idx    = 0
all_files_ASD   = glob.glob(main_dir + "/*")
all_files_ASD   = [a for a in all_files_ASD if ".ds" in a]

# Load data - assumes shuttle_MEG_files.py has been run on the main_dir target
main_dir    = "/home/coveneuro-leif/Downloads/Imaging_MEG-RestingState_ADHD/"
file_idx    = 0
all_files_ADHD   = glob.glob(main_dir + "/*")
all_files_ADHD   = [a for a in all_files_ADHD if ".ds" in a]

all_files = all_files_TD+all_files_ASD+all_files_ADHD


Now let's iterate through and perform a first processing pass of all MEG files:
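
The batch loop in the next cell cuts each ICA component time series into overlapping windows before saving them. A minimal sketch of that windowing logic is shown here in isolation. The `window_starts` helper and the 5-minute recording length are assumptions for illustration; the window length (15000 samples) and overlap (3750 samples) match the values used in the loop, which at the resampled rate of 250 Hz correspond to 60 s windows with 15 s of overlap.

```python
def window_starts(n_samples, win_len=15000, overlap=3750):
    """Return start indices of overlapping windows that fit inside the signal."""
    step = win_len - overlap
    starts = []
    st = 0
    # Same stopping rule as the batch loop: keep a window only if st + win_len < n_samples
    while st + win_len < n_samples:
        starts.append(st)
        st += step
    return starts

# Hypothetical 5-minute recording at 250 Hz (75000 samples)
starts = window_starts(75000)
print(len(starts), starts)
```

Note that the strict inequality means a signal of exactly `win_len` samples yields no window at all, which may or may not be the intended behavior.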

In [None]:
# %% Overall workflow

# Load data
# Get overlapping channel names across the max number of samples
# Recognizing that many are not shared (but only off by 1 or 2 usually,
# except for one weirdo sample that uses ENTIRELY different channel names)
chs  = list()
for idx_f, f in enumerate(all_files):
    # break
    raw         = mne.io.read_raw_ctf(f, preload=True).pick_types("mag")
    chs.append(pd.DataFrame(data=[1]*len(raw.info['ch_names']), index=raw.info['ch_names']))
chs_ = pd.concat(chs, axis=1)
chs_keep = list(chs_.index[chs_.sum(axis=1)>=58].values)
chs_keep = [c for c in chs_keep if "MRP21" not in c]
# chs_keep_4602 = [c.replace("1706", "4602").replace("2104", "4122").replace("3000", "4122") for c in chs_keep]

# raw.info.rename_channels(mapping={c: c.split("-")[0] for c in raw.info['ch_names']})
chs_keep_2 = [c.split("-")[0] for c in chs_keep]
n_comps = 15

# Perform batch processing up to the point of ICA
for idx_f, f in enumerate(all_files):
    '''
    idx_f = 13 
    f=all_files[idx_f]
    '''
    # try:
    # Read in, select MAG channels only, resample (less memory etc., doesn't need to be at 600)
    # raw         = mne.io.read_raw_ctf(f, preload=True).resample(250).pick_channels(chs_keep).pick_types("mag")
    # raw         = mne.io.read_raw_ctf(f, preload=True).resample(250).pick_channels(chs_keep_2).pick_types("mag")
    raw         = mne.io.read_raw_ctf(f, preload=True).resample(250)
    raw.info.rename_channels(mapping={c: c.split("-")[0] for c in raw.info['ch_names']})
    raw         = raw.pick_types("mag").pick_channels(chs_keep_2)
    # raw.info['ch_names']
    # Filter data - notch and then BP
    # The 120 is arguably redundant here - remove in future to save time?
    filtered    = raw.copy().notch_filter([60, 120])
    filtered    = filtered.copy().filter(l_freq=1, h_freq = 100)
    filtered.save(f"{f.replace('_meg.ds', '_filtered_meg.fif')}", overwrite=True)
    
    # ICA
    ica         = ICA(n_components=n_comps, max_iter=100, random_state=97)
    ica.fit(filtered, verbose="error")
    ica.save(f"{f.replace('_meg.ds', '_ica.fif')}", overwrite=True)
    ica_txed    = ica._transform_raw(filtered, 0, len(filtered.times)) # Sneaky
    
    new_dir     = f.replace(".ds", "_scan")#"/".join(f.split("/")[:-1]) + f"/scan_{idx_f}"
    
    # Make a new directory to save processed data if it doesn't already exist
    os.makedirs(new_dir, exist_ok=True)
    
    for i in range(n_comps):
        fig, ax = plt.subplots(1,1, figsize=(1.2, 1.2))
        ica.plot_components(picks=i, axes=ax, show=False, title="",
                            outlines=None, sensors=False, contours=0)
        ax.set_title("")
         # TODO MEGnet cuts timeseries into epochs but uses the same spatial map repeatedly
         # NOT recalculated per-epoch
        timeseries = ica_txed[i, :]
        
        win_len     = 15000
        overlap     = 3750
        start_times = []
        st          = 0
        k           = 0
        while st+win_len<len(timeseries):
            # We append first and increment next to ensure there is (win_len) 
            # of room remaining to index into later!
            start_times.append(st)
            # print(st, st+win_len)
            
            plt.savefig(new_dir+f"/map_{i}_epoch_{k}.png")
            np.save(new_dir+f"/timeseries_{i}_epoch_{k}", timeseries[st:(st+win_len)])
            
            st += win_len-overlap
            k += 1
            
        plt.close()
    # except:
        # pass
    # ica         = mne.preprocessing.read_ica(f"{f.replace('_meg.ds', '_ica.fif')}")
    # ica.plot_sources(filtered) # Clicking automatically excludes those components
    # break


Here, we manually step through the files and review the ICA decompositions. The workflow is split across the next two cells: run the first cell to load the filtered data and the ICA object that was saved in the previous loop, then visually inspect the components for data quality. Once the bad components have been marked, run the second cell to reconstruct and save the cleaned data.
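
Conceptually, reconstructing the data after marking bad components amounts to remixing the remaining ICA sources. The NumPy sketch below shows only that linear-algebra idea. The random `mixing` matrix, the `sources`, and the `remove_components` helper are hypothetical, and MNE's `ICA.apply` does more than this internally (for example, it works through a PCA step), so treat it as an intuition aid rather than a replacement.

```python
import numpy as np

def remove_components(mixing, sources, exclude):
    """Rebuild sensor data from ICA sources while leaving out excluded components."""
    excluded = set(exclude)
    keep = [i for i in range(sources.shape[0]) if i not in excluded]
    # Sensor data is modeled as mixing @ sources; dropping columns/rows removes components
    return mixing[:, keep] @ sources[keep, :]

rng = np.random.default_rng(0)
n_channels, n_comps, n_times = 6, 3, 1000
mixing = rng.normal(size=(n_channels, n_comps))   # hypothetical mixing matrix
sources = rng.normal(size=(n_comps, n_times))     # hypothetical component time series
data = mixing @ sources

cleaned = remove_components(mixing, sources, exclude=[0])
# What was removed is exactly the sensor-space contribution of component 0
artifact = np.outer(mixing[:, 0], sources[0])
print(np.allclose(data - cleaned, artifact))
```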

In [None]:
# %% Perform single processing to (1) get bad ICs, (2) reconstruct low-dim data, (3) save reconstructed data

ix          = 58
idx_f, f    = ix, all_files[ix]
filtered    = mne.io.read_raw_fif(f"{f.replace('_meg.ds', '_filtered_meg.fif')}", preload=True)
ica         = mne.preprocessing.read_ica(f"{f.replace('_meg.ds', '_ica.fif')}")
ica.plot_sources(filtered) # Clicking automatically excludes those components


In [None]:
reconst = filtered.copy()
ica.apply(reconst)
reconst.save(f"{f.replace('_meg.ds', '_ica15_reconst_meg.fif')}", overwrite=True)
# reconst.plot()

Next, we run autoreject to remove bad epochs.
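
To make the epoch-rejection step more concrete, here is a simplified stand-in written in NumPy: it cuts a continuous recording into fixed 3 s epochs and drops any epoch whose peak-to-peak amplitude exceeds a fixed threshold. This is not the autoreject algorithm, which learns thresholds per channel by cross-validation and can also repair epochs. The helper names, the toy data, and the threshold of 20 are assumptions for illustration.

```python
import numpy as np

def fixed_length_epochs(data, sfreq, duration=3.0):
    """Cut (n_channels, n_times) data into non-overlapping epochs of `duration` seconds."""
    n_per_epoch = int(round(duration * sfreq))
    n_epochs = data.shape[1] // n_per_epoch
    trimmed = data[:, : n_epochs * n_per_epoch]
    # Result has shape (n_epochs, n_channels, n_per_epoch)
    return trimmed.reshape(data.shape[0], n_epochs, n_per_epoch).transpose(1, 0, 2)

def drop_by_peak_to_peak(epochs, threshold):
    """Keep epochs whose largest channel-wise peak-to-peak amplitude is below threshold."""
    ptp = epochs.max(axis=2) - epochs.min(axis=2)   # (n_epochs, n_channels)
    good = ptp.max(axis=1) < threshold
    return epochs[good], good

rng = np.random.default_rng(1)
sfreq = 250
data = rng.normal(scale=1.0, size=(4, 10 * sfreq))  # 10 s of toy data, 4 channels
data[2, 800:810] += 50.0                            # inject an artifact into the second epoch
epochs = fixed_length_epochs(data, sfreq)
clean, good = drop_by_peak_to_peak(epochs, threshold=20.0)
print(epochs.shape, clean.shape, good)
```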

In [None]:
# %% Iterate over all reconstructed files and perform autoreject

# Epoch each ICA-cleaned recording and run autoreject, skipping files that are already done
for idx_f, f in enumerate(all_files):
    if os.path.isfile(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}")==True:
        continue
    else:
        print(idx_f)
        try:
            reconst = mne.io.read_raw_fif(f"{f.replace('_meg.ds', '_ica15_reconst_meg.fif')}")
            
            # Make into epochs for autoreject
            epochs = mne.make_fixed_length_epochs(reconst, duration=3, preload=True)#, picks=mag_channels)
            
            # Autoreject
            ar = autoreject.AutoReject(n_jobs=-1, verbose=True)
            ar.fit(epochs)
            epochs_ar, reject_log = ar.transform(epochs, return_log=True)
            epochs_ar.save(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}", overwrite=True)
            # epochs_ar.plot()
            # break
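        except Exception as err:
            # Skip files that are missing or fail, but report them instead of failing silently
            print(f"Skipping {f}: {err}")
            continue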


Now we perform source modeling. This is one of the most powerful aspects of MEG, owing to its high-density data collection and its relative insensitivity to volume conduction (cf. EEG, where this is a major problem).
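
For intuition about what the inverse operator does, the sketch below builds a toy regularized minimum-norm inverse, K = Gᵀ(GGᵀ + λ²I)⁻¹, using the same rule λ² = 1/SNR² that the next cell uses to set `lambda2`. Everything else is an assumption for illustration: the random `gain` matrix stands in for a real lead field, the noise covariance is taken to be the identity, and there is no depth weighting, orientation constraint, or dSPM normalization, all of which MNE handles in its own implementation.

```python
import numpy as np

def minimum_norm_operator(gain, lambda2):
    """Toy regularized minimum-norm inverse: K = G^T (G G^T + lambda2 * I)^-1."""
    n_channels = gain.shape[0]
    gram = gain @ gain.T
    return gain.T @ np.linalg.inv(gram + lambda2 * np.eye(n_channels))

snr = 3.0
lambda2 = 1.0 / snr ** 2          # regularization from the assumed SNR

rng = np.random.default_rng(2)
n_channels, n_sources = 8, 20
gain = rng.normal(size=(n_channels, n_sources))   # hypothetical lead field
true_sources = np.zeros(n_sources)
true_sources[5] = 1.0                             # one active source
measured = gain @ true_sources

K = minimum_norm_operator(gain, lambda2)
estimate = K @ measured                           # one value per source location
print(K.shape, estimate.shape)
```

A larger `lambda2` (lower assumed SNR) shrinks the estimates more strongly, which is why a smaller SNR is typically chosen for single epochs or raw data than for averaged responses.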

In [None]:

for idx_f, f in enumerate(all_files):
    # break
    if os.path.isfile(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}")==False:
        continue
    else:
        
        epo         = mne.read_epochs(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}") 
        info        = epo.info#raw.info
        
        noise_cov   = mne.make_ad_hoc_cov(info)
        parc        = "aparc"  # the parcellation to use, e.g., 'aparc' 'aparc.a2009s'
        loose       = dict(surface=0.2, volume=1.0)
        snr         = 3.0  # use smaller SNR for raw data
        lambda2     = 1.0 / snr**2
        
        data_path = mne.datasets.sample.data_path()
        subject = "sample"
        data_dir = data_path / "MEG" / subject
        subjects_dir = data_path / "subjects"
        bem_dir = subjects_dir / subject / "bem"
        
        # Set file names
        fname_mixed_src = bem_dir / f"{subject}-oct-6-mixed-src.fif"
        fname_aseg = subjects_dir / subject / "mri" / "aseg.mgz"
        
        fname_model = bem_dir / f"{subject}-5120-bem.fif"
        fname_bem = bem_dir / f"{subject}-5120-bem-sol.fif"
        
        fname_evoked = data_dir / f"{subject}_audvis-ave.fif"
        fname_trans = data_dir / f"{subject}_audvis_raw-trans.fif"
        # fname_fwd = data_dir / f"{subject}_audvis-meg-oct-6-mixed-fwd.fif"
        fname_cov = data_dir / f"{subject}_audvis-shrunk-cov.fif"
        
        
        # import os.path as op
        # from mne.datasets import fetch_fsaverage
        # fs_dir = fetch_fsaverage(verbose=True)
        # subjects_dir = op.dirname(fs_dir)
        # subject = 'fsaverage'
        # trans = 'fsaverage'
        # src = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
        # fname_bem = op.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')
        
        labels_vol = [
            "Left-Amygdala",
            "Left-Thalamus-Proper",
            "Left-Cerebellum-Cortex",
            "Brain-Stem",
            "Right-Amygdala",
            "Right-Thalamus-Proper",
            "Right-Cerebellum-Cortex",
        ]
        
        ''' # Run the commented stuff here IF YOU WANT TO REMAKE THE INVERSE OPERATOR AND SAVE IT
        # Get a surface-based source space, here with few source points for speed in this demonstration, in general you should use oct6 spacing!
        src = mne.setup_source_space(
            subject, spacing="oct6", add_dist=False, subjects_dir=subjects_dir
        )
        vol_src = mne.setup_volume_source_space(
            subject,
            mri=fname_aseg,
            pos=10.0,
            bem=fname_model,
            volume_label=labels_vol,
            subjects_dir=subjects_dir,
            add_interpolator=False,  # just for speed, usually this should be True
            verbose=True,
        )
        
        # Generate the mixed source space
        src += vol_src
        write_source_spaces(fname_mixed_src, src, overwrite=True)
        
        
        fwd         = mne.make_forward_solution(info, trans=fname_trans, 
                                                src=fname_mixed_src, 
                                                bem=fname_bem, meg=True, eeg=False, 
                                                mindist=0.0, ignore_ref=False, 
                                                n_jobs=None, verbose=None)
        inv         = mne.minimum_norm.make_inverse_operator(info, fwd, 
                                                             noise_cov, depth=None, 
                                                             loose=loose, verbose=True)
        del fwd
        mne.minimum_norm.write_inverse_operator("/home/coveneuro-leif/Documents/scripts/inv", inv, overwrite=True)
        '''
        epo.info['ch_names']
        epo.rename_channels({c: c.split("-")[0] for c in epo.info['ch_names']})
        # inv.ch_names
        inv         = mne.minimum_norm.read_inverse_operator("/home/coveneuro-leif/Documents/scripts/inv")
        stc         = mne.minimum_norm.apply_inverse_epochs(epo, 
                                                            inv, 
                                                            lambda2, 
                                                            method="dSPM", 
                                                            pick_ori=None)
        # stc_vec     = mne.minimum_norm.apply_inverse_epochs(epochs_ar, 
        #                                                     inv, 
        #                                                     lambda2, 
        #                                                     method="dSPM", 
        #                                                     pick_ori="vector")
        src = inv["src"]
        # brain = stc_vec[0].plot(
        #     hemi="both",
        #     src=inv["src"],
        #     views="coronal",
        #     # initial_time=initial_time,
        #     subjects_dir=subjects_dir,
        #     brain_kwargs=dict(silhouette=True),
        #     smoothing_steps=7,
        # )
        # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi
        labels_parc = mne.read_labels_from_annot(subject, parc=parc, subjects_dir=subjects_dir)
        labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir)
        
        label_ts = [mne.extract_label_time_course(
            stc[i], labels_parc, src, mode="mean", allow_empty=True
        ) for i in range(len(stc))]
        
        np.save(f"{f.replace('_meg.ds', 'label_ts.npy')}", label_ts)


At last! Let's extract some networks. Note here that we use the "coherence" method which, if we were using EEG, would be inappropriate due to volume conduction effects. For MEG, it is okay. We iterate over the previously-saved .npy source space timeseries data and save a series of matrices. We also create a dataset linked up to the clinical data from earlier.
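
As a small worked example of the coherence measure itself, the sketch below estimates coherence between two toy signals, |Sxy| / sqrt(Sxx · Syy), and averages it within frequency bands. It uses Hann-windowed segments as a simple stand-in for the multitaper estimate that `spectral_connectivity_epochs` computes; the `coherence` and `band_mean` helpers, the segment length, and the simulated signals are assumptions for illustration.

```python
import numpy as np

def coherence(x, y, sfreq, n_per_seg=256):
    """Coherence |Sxy| / sqrt(Sxx * Syy), averaged over Hann-windowed segments."""
    n_segs = len(x) // n_per_seg
    window = np.hanning(n_per_seg)
    xs = x[: n_segs * n_per_seg].reshape(n_segs, n_per_seg) * window
    ys = y[: n_segs * n_per_seg].reshape(n_segs, n_per_seg) * window
    fx, fy = np.fft.rfft(xs, axis=1), np.fft.rfft(ys, axis=1)
    sxy = (fx * np.conj(fy)).mean(axis=0)         # averaged cross-spectrum
    sxx = (np.abs(fx) ** 2).mean(axis=0)          # averaged auto-spectra
    syy = (np.abs(fy) ** 2).mean(axis=0)
    freqs = np.fft.rfftfreq(n_per_seg, d=1.0 / sfreq)
    return freqs, np.abs(sxy) / np.sqrt(sxx * syy)

def band_mean(freqs, coh, fmin, fmax):
    """Average coherence over the frequency bins inside [fmin, fmax]."""
    mask = (freqs >= fmin) & (freqs <= fmax)
    return coh[mask].mean()

rng = np.random.default_rng(3)
sfreq = 250
t = np.arange(60 * sfreq) / sfreq
shared = np.sin(2 * np.pi * 10 * t)               # common 10 Hz rhythm
x = shared + rng.normal(size=t.size)              # two "regions" with independent noise
y = shared + rng.normal(size=t.size)
freqs, coh = coherence(x, y, sfreq)
alpha = band_mean(freqs, coh, 8, 12)
beta = band_mean(freqs, coh, 12, 25)
print(alpha, beta)
```

Because the two signals share only the 10 Hz rhythm, the alpha-band average should come out clearly higher than the beta-band average.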

In [None]:
# %% Connectivity analysis

# Network analysis
from mne_connectivity import spectral_connectivity_epochs
from mne_connectivity.viz import plot_connectivity_circle
from mne.viz import circular_layout
from itertools import product

data_path = mne.datasets.sample.data_path()
subject = "sample"
data_dir = data_path / "MEG" / subject
subjects_dir = data_path / "subjects"
type_ = "source"
sfreq = 250
bands = {"theta": (4, 8),
         "alpha": (8, 12),
         "beta": (12, 25),
         "broad": (1, 100)}

# pop = "Typically-Developing"

for type_ in ["source"]:#, "sensor"]:
    # break
    for band in bands:
        fmin = bands[band][0]
        fmax = bands[band][1]
        
        class_labels = []
        conn_raveled_list = []
        for idx_f, f in enumerate(all_files):
            pop = f.split("/")[-2].split("_")[-1]
            
            if type_=="source":
                if os.path.isfile(f"{f.replace('_meg.ds', 'label_ts.npy')}")==False:
                    continue
                else:
                    label_ts = list(np.load(f"{f.replace('_meg.ds', 'label_ts.npy')}"))
            else:
                if os.path.isfile(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}")==False:
                    continue
                else:
                    label_ts = mne.read_epochs(f"{f.replace('_meg.ds', '_ica15_reconst_meg_ar_epo.fif')}")
                    label_ts = label_ts.pick("mag")
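            # The class label is appended once further down, after the connectivity matrix
            # is saved (appending here as well would add a spurious extra label per file)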
            
            conn_data_list = []
            for l in label_ts:
                conn = spectral_connectivity_epochs([l], method='coh', sfreq=sfreq, mode='multitaper', 
                                                    fmin=fmin, fmax=fmax, faverage=False, tmin=None, tmax=None, 
                                                    n_jobs=-1)
                conn_data = conn.get_data(output="dense").mean(axis=2)[:, :, np.newaxis]#[:, :, 0]#
                # conn_raw = conn.get_data(output="dense")
                conn_data_list.append(conn_data)
            
            conn_raw = np.concatenate(conn_data_list, axis=2)
            
            if type_=="source":
                fpath = f"/home/coveneuro-leif/Downloads/dl_project_{band}/{f.split('/')[-1].split('.')[0]}_{band}.npy"
                np.save(fpath, conn_raw)
                if pop=="Typically-Developing":
                    class_labels.append(1)
                elif pop=="ADHD":
                    class_labels.append(2)
                elif pop=="ASD":
                    class_labels.append(3)

            # First, we reorder the labels based on their location in the left hemi
            labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir)
            labels_vol = [
                "Left-Amygdala",
                "Left-Thalamus-Proper",
                "Left-Cerebellum-Cortex",
                "Brain-Stem",
                "Right-Amygdala",
                "Right-Thalamus-Proper",
                "Right-Cerebellum-Cortex",
            ]
            
            label_names = [label.name for label in labels] + labels_vol
            lh_labels = [name for name in label_names if name.endswith("lh")]
            
            # Get the y-location of the label
            label_ypos = list()
            for name in lh_labels:
                idx = label_names.index(name)
                ypos = np.mean(labels[idx].pos[:, 1])
                label_ypos.append(ypos)
            
            # Reorder the labels based on their location
            lh_labels = [label for (yp, label) in sorted(zip(label_ypos, lh_labels))]
            rh_labels = [label[:-2] + "rh" for label in lh_labels]
            
            # Save the plot order and create a circular layout
            node_order = list()
            node_order.extend(lh_labels[::-1])  # reverse the order
            node_order.extend(rh_labels)
            node_order.extend(labels_vol)
            
            node_angles = circular_layout(label_names, node_order, start_pos=90, group_boundaries=[0, len(label_names) / 2])
            
            ''' Plotting connectivity circle(s)
            # Plot the graph using node colors from the FreeSurfer parcellation. We only
            # show the 300 strongest connections.
            fig, ax = plt.subplots(figsize=(8, 8), facecolor="black", subplot_kw=dict(polar=True))
            plot_connectivity_circle(
                conn_data,
                label_names,
                n_lines=300,
                node_angles=node_angles,
                # node_colors=label_colors,
                title="All-to-All Connectivity left-Auditory " "Condition (PLI)",
                ax=ax,
            )
            fig.tight_layout()
            '''
            
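            if type_=="source":
                # Rows/columns of the connectivity matrix follow the order used when
                # extracting label_ts (parcellation labels, then volume labels), i.e.
                # label_names, not the y-sorted plotting order in lh_labels/rh_labels
                conn_df = pd.DataFrame(data = conn_raw.mean(axis=2), 
                                       index = label_names, 
                                       columns = label_names)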
            else:
                conn_df = pd.DataFrame(data = conn_data, 
                                       index = label_ts.info['ch_names'], 
                                       columns = label_ts.info['ch_names'])
            
            pairs = [f"{x}_{y}" for x, y in list(product(conn_df.columns, conn_df.index))]
            conn_raveled = pd.DataFrame(data = conn_df.values.ravel(),
                                        index=pairs).T
            conn_raveled.index = ["_".join(f.split("/")[-1].split("_")[:3])]
            
            conn_raveled_list.append(conn_raveled)
        
        conn_raveled_df = pd.concat(conn_raveled_list)
        conn_raveled_df.to_csv(f"{f.replace('_meg.ds', f'_conn_df_{type_}_{band}.csv')}")
        np.save(f"/home/coveneuro-leif/Downloads/dl_project/labels_{pop}.npy", class_labels, allow_pickle=True)
        # plt.plot(conn_raveled_df.T)
        
        ''' # Misc/ICA etc.
        ica.plot_components()
        ica.plot_sources(filtered)
        filtered.plot()
        ica.plot_overlay(filtered, exclude=[0,1], picks="mag")
        ica.exclude = [0, 1]
        reconst = filtered.copy()
        ica.apply(reconst)
        
        # Comparing pre- vs post-ICA
        filtered.plot()
        reconst.plot()
        '''
        
        ## %% Clinical - MEG tie up
        clinical = pd.read_csv("/home/coveneuro-leif/clinical_SWAN.csv")
        clinical.index = clinical['subject_id']
        # sfreq = epochs_ar.info['sfreq']
        # final_dataset = pd.concat([clinical, conn_raveled_df])
        final_dataset = pd.merge(left=clinical, right=conn_raveled_df,
                                  left_index=True, right_index=True)
        
        final_dataset = final_dataset.loc[:, (final_dataset != 0).any(axis=0)]
        final_dataset.to_csv("/".join(all_files[0].split("/")[:-1])+f"/connectivity_{type_}_{band}_fixed.csv")


Now we do the same for fMRI. We specifically use the ANTsPy libraries because they are open and available for commercial use. Libraries such as NiPype, while outstanding, rely on closed-source or copyleft tools like FSL and SPM (copyleft meaning that your codebase is required to be open if you use their code). For commercial entities this is an important consideration: given the work required to develop and tune a preprocessing workflow, it is better to build on commercially usable tools from the beginning.

In [None]:
import glob
import time

import ants
import antspynet
import matplotlib.pyplot as plt
import nibabel as nib
import nilearn
import numpy as np
import pandas as pd
import seaborn as sns
from nilearn import datasets, image, plotting
from nilearn.connectome import ConnectivityMeasure, GroupSparseCovarianceCV
from nilearn.datasets import MNI152_FILE_PATH
from nilearn.image import clean_img, index_img, mean_img
from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
from nilearn.masking import compute_epi_mask, apply_mask
from nilearn.plotting import plot_epi, plot_roi
from scipy.interpolate import interp1d
from sklearn.covariance import GraphicalLassoCV

pops = ["ADHD", "ASD", "Typically-Developing"]

pop = pops[0]
MR_dfs = []
for pop in pops:
    anat = [i.replace("\\", "/") for i in glob.glob(f"/home/coveneuro-leif/Downloads/Imaging_MR-T1_{pop}/*.nii.gz")]
    anat_df = pd.DataFrame(data = np.c_[["_".join(i.split("/")[-1].split("_")[0:4]) for i in anat], anat], columns = ["subid", "path"])
    fxal = [i.replace("\\", "/") for i in glob.glob(f"/home/coveneuro-leif/Downloads/Imaging_fMRI-RestingState_{pop}/*.nii")]
    fxal_df = pd.DataFrame(data = np.c_[["_".join(i.split("/")[-1].split("_")[0:4]) for i in fxal], fxal], columns = ["subid", "path"])
    MR_merged = pd.merge(left=anat_df, right=fxal_df, left_on="subid", right_on="subid")
    MR_merged.columns = ["subid", "path_structural", "path_functional"]
    MR_dfs.append(MR_merged)

MR_merged_final = pd.concat(MR_dfs)
T1_files = MR_merged_final["path_structural"].values
fmri_files = MR_merged_final["path_functional"].values

Time for fMRI preprocessing, which ANTsPy makes fairly painless. We strip the skull (with a deep-learning brain extraction from ANTsPyNet for the T1 image, and an intensity-based EPI mask from nilearn for the functional image), then correct for motion, register the functional image to the corresponding anatomical image, and finally register to a template to ensure standardization. We use publicly available developmental fMRI templates, as children's heads are clearly different from adults' heads.
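
The functional skull-stripping step boils down to multiplying every volume of the 4D run by a 3D brain mask. The sketch below shows that broadcasting trick on toy arrays. The `apply_brain_mask` helper, the array sizes, and the hand-drawn mask are hypothetical; in the notebook the mask comes from nilearn's `compute_epi_mask`.

```python
import numpy as np

def apply_brain_mask(fmri_data, mask):
    """Zero out voxels outside a 3D mask in every volume of a 4D (x, y, z, t) array."""
    if fmri_data.shape[:3] != mask.shape:
        raise ValueError("mask must match the spatial shape of the fMRI data")
    # A trailing axis lets the 3D mask broadcast across the time dimension
    return fmri_data * mask[..., None]

rng = np.random.default_rng(4)
fmri_data = rng.normal(loc=100.0, size=(4, 4, 3, 5))  # toy 4D run with 5 volumes
mask = np.zeros((4, 4, 3))
mask[1:3, 1:3, :] = 1                                 # hypothetical "brain" region
stripped = apply_brain_mask(fmri_data, mask)
print(stripped.shape, int((stripped[..., 0] != 0).sum()))
```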

In [None]:
fmri_file = fmri_files[1]  # Path to the fMRI file

# motions = {}
bads = []
for ix_f, fmri_file in enumerate(fmri_files):
    # break
    # try:
    print(f"\n\n\n {ix_f}\n\n\n")
    st = time.time()
    
    # Get matching T1/fMRI files
    fmri_subid = "_".join(fmri_file.split("/")[-1].split("_")[:4])
    t1_file = T1_files[np.array([fmri_subid in t for t in T1_files])][0]
    pop = fmri_file.split("/")[-2].split("_")[-1]
    print(f"Sorted file IDs at {time.time()-st} sec")
    
    # TODO SUPER HELPFUL
    # https://cdn.jamanetwork.com/ama/content_public/journal/jamanetworkopen/939092/zoi230095supp1_prod_1678113687.23714.pdf?Expires=1733297702&Signature=zr4d70M4ydKr5cOcCCR3DfGfrEaVRNu7Rw9crO5YwDeyMurZeptnWwc-ADevB0BI~g4bVdfPwXt~rZHbH7deS5-2FCk2JcyRagma~CBQybTjKz1T-Q6ALWUGJKtLSoDbV5SPY0-qLcSmbUoxFcF5-Q8ImRwLHrVSYQ9YcSH27v1LtxlvhUGxVdSIpxJXIoIou0U8rJ0nIww33lyNahEdU8E1P7mVAt2bpEo8pk8wvVxxxBVamrDcTui1oNH7mPYLx6P5X9O7MzEkoM1hYX3n4g99fi6UIF-S~77hTgwwtN6OpsXcRJg6MdrkYqxCJXw-aeDy8-uZVl006s5AEzZ4-A__&Key-Pair-Id=APKAIE5G5CRDK6RD3PGA
    
    ## Load files
    fmri = ants.image_read(fmri_file)
    fmri_nib = nib.load(fmri_file)  # Load the fMRI image using nibabel for skull stripping
    t1 = ants.image_read(t1_file)
    # ants.plot(t1)
    print(f"Read files at {time.time()-st} sec")
    
    ## Strip skull
    fmri_mask = compute_epi_mask(fmri_nib)
    fmri_data = fmri_nib.get_fdata()
    mask_data = fmri_mask.get_fdata()
    skull_stripped_data = fmri_data*mask_data[..., None]
    skull_stripped_img = nib.Nifti1Image(skull_stripped_data, fmri_nib.affine, fmri_nib.header)
    fmri_brain_path = 'skull_stripped_fmri.nii.gz'
    nib.save(skull_stripped_img, fmri_brain_path)
    fmri_brain = ants.image_read(fmri_brain_path)
    # ants.plot(ants.slice_image(fmri_brain, axis=3, idx=0))
    
    seg = antspynet.brain_extraction(t1, modality="t1", verbose=True)
    # ants.plot(t1, overlay=seg, overlay_alpha=0.5)
    t1_brain = t1*seg
    print(f"Stripped skulls at {time.time()-st} sec")
    
    ## Motion correction 
    ref_vol = ants.slice_image(fmri_brain, axis=3, idx=fmri_brain.shape[3]//2)
    fmri_corrected = ants.motion_correction(fmri_brain, target_image=ref_vol)
    motion_params = fmri_corrected['motion_parameters']
        
    # with open(f"{fmri_subid}_motion_params.pkl", "rb") as file:
    #     motpars = pickle.load(file)
    print(f"Motion corrected at {time.time()-st} sec")
    
    ## Registration: T1<>Template
    template_path = "/home/coveneuro-leif/Downloads/pediatric_templates/nihpd_asym_04.5-18.5_t1w.nii"
    mni_template = ants.image_read(template_path)
    seg_template = antspynet.brain_extraction(mni_template, modality="t1", verbose=True)
    mni_template = mni_template*seg_template
    t1_to_mni = ants.registration(fixed=mni_template, moving=t1_brain, type_of_transform="Affine")
    print(f"Registered T1-template at {time.time()-st} sec")
    
    ## Registration: Functional<>Anatomical
    # TODO following line: select middle volume, apply across all (see CGPT)
    middle_idx = fmri_corrected['motion_corrected'].shape[-1] // 2
    fmri_single_vol = ants.slice_image(fmri_corrected['motion_corrected'], axis=3, idx=middle_idx)
    # ants.plot(fmri_single_vol)
    fmri_to_t1 = ants.registration(fixed=t1_brain, moving=fmri_single_vol, type_of_transform="Affine") # "SyN"
    coregistered_bold = fmri_to_t1['warpedmovout']
    print(f"Registered fMRI-T1 at {time.time()-st} sec")

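    # Resample the template to 3 mm so that the output lives on a template-space grid
    template_resamp = ants.resample_image(mni_template, (3.0, 3.0, 3.0), use_voxels=False)
    # ANTs applies the LAST transform in the list first: fMRI->T1, then T1->template
    fmri_aligned = ants.apply_transforms(fixed=template_resamp,
                                         moving=fmri_corrected['motion_corrected'],
                                         transformlist=t1_to_mni['fwdtransforms']+fmri_to_t1['fwdtransforms'],
                                         interpolator='linear',
                                         imagetype=3)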

    ants.image_write(fmri_aligned, f"{fmri_subid}_{pop}_fmri_aligned.nii.gz")


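We perform final cleaning (detrending, standardization, and band-pass filtering between 0.008 and 0.1 Hz) and then extract the mean timeseries of brain activity for each region of the Harvard-Oxford cortical atlas.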
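
To build some intuition for what a 0.008 to 0.1 Hz band-pass does to a timeseries sampled at TR = 1.5 s, here is a minimal sketch on a synthetic signal. It uses a simple FFT mask in NumPy purely for illustration, which is an assumption of this example and not necessarily the filter that nilearn applies internally. The signal, `n_vols`, and the helper `fft_bandpass` are all hypothetical.

```python
import numpy as np

t_r = 1.5                      # repetition time in seconds (value used in this notebook)
n_vols = 400                   # hypothetical number of volumes
t = np.arange(n_vols) * t_r

# Synthetic signal: slow drift (one cycle per scan) + in-band (0.05 Hz) + fast (0.2 Hz)
drift = np.sin(2 * np.pi * t / (n_vols * t_r))
in_band = np.sin(2 * np.pi * 0.05 * t)
fast = np.sin(2 * np.pi * 0.2 * t)
signal = drift + in_band + fast

def fft_bandpass(x, t_r, high_pass, low_pass):
    """Zero every frequency bin outside [high_pass, low_pass] (in Hz)."""
    freqs = np.fft.rfftfreq(x.shape[0], d=t_r)
    spectrum = np.fft.rfft(x)
    keep = (freqs >= high_pass) & (freqs <= low_pass)
    spectrum[~keep] = 0.0
    return np.fft.irfft(spectrum, n=x.shape[0])

filtered = fft_bandpass(signal, t_r, high_pass=0.008, low_pass=0.1)
print(np.max(np.abs(filtered - in_band)))  # residual after removing drift and fast noise
```

Only the 0.05 Hz component survives, which is the frequency range typically associated with resting-state fluctuations.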

In [None]:
import os

# Load an atlas with labeled regions (e.g., Harvard-Oxford atlas)
atlas = datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')
atlas_filename = atlas.maps  # Path to the atlas labels
labels = atlas.labels        # Region names

bads = []
# ix_f = 97+125 had a compression error (due to memory issues before?)
# ix_f = 97+126+7 had a compression error (due to memory issues before?)
for ix_f, fmri_file in enumerate(fmri_files):
    print(f"\n {ix_f}")
    st = time.time()
    
    # Get matching T1/fMRI files
    fmri_subid = "_".join(fmri_file.split("/")[-1].split("_")[:4])
    t1_file = T1_files[np.array([fmri_subid in t for t in T1_files])][0]
    pop = fmri_file.split("/")[-2].split("_")[-1]
    print(f"Sorted file IDs at {time.time()-st} sec")
    
    fmri_aligned_img = image.load_img(f"{fmri_subid}_{pop}_fmri_aligned.nii.gz")
    # fmri_aligned_img = image.load_img("/home/coveneuro-leif/PND03_HSC_0019_01_ADHD_fmri_aligned.nii.gz")
    try:
        if not os.path.isfile(f"{fmri_subid}_{pop}_fmri_aligned_cleaned.nii.gz"):
            if fmri_aligned_img.shape[-1]<100:  # Skip scans with fewer than 100 volumes
                continue
            fmri_cleaned = clean_img(
                fmri_aligned_img,
                detrend=True,             # Remove slow drifts
                standardize=True,         # Standardize each voxel's time series to z-scores
                low_pass=0.1,             # Low-pass filter (frequency in Hz, adjust as needed)
                high_pass=0.008,          # High-pass filter (frequency in Hz, adjust as needed)
                t_r=1.5                   # Repetition time (TR) in seconds, adjust to match your data
            )
            nib.save(fmri_cleaned, f"{fmri_subid}_{pop}_fmri_aligned_cleaned.nii.gz")
            # nib.save(fmri_cleaned, "/home/coveneuro-leif/PND03_HBK_0265_03_ASD_fmri_aligned_cleaned.nii.gz")
        else:
            print("Sample already cleaned")
            pass
        
        # Load preprocessed fMRI image
        fmri_cleaned = nib.load(f"{fmri_subid}_{pop}_fmri_aligned_cleaned.nii.gz")
        
        # Set up the NiftiLabelsMasker
        masker = NiftiLabelsMasker(labels_img=atlas_filename, standardize=True, detrend=True)
        
        # Extract timeseries for each region
        timeseries_data = masker.fit_transform(fmri_cleaned)
        np.save(f"{fmri_subid}_{pop}_fmri_aligned_cleaned_timeseries.npy", timeseries_data)
        print("Timeseries extracted\n")
        # timeseries_data = np.load(f"{fmri_subid}_{pop}_fmri_aligned_cleaned_timeseries.npy")
    except Exception:  # Record the failed file and move on to the next one
        bads.append(f"{fmri_subid}_{pop}_fmri_aligned_cleaned.nii.gz")
        

At long last, we can extract the connectivity matrices:
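
As a warm-up, here is a minimal NumPy sketch of sliding-window (dynamic) connectivity on toy data: slice the region timeseries into overlapping windows and compute one correlation matrix per window. It uses `np.corrcoef` as a stand-in for nilearn's connectivity estimator, and `sliding_window_connectivity` and `toy_timeseries` are hypothetical names for this example only.

```python
import numpy as np

def sliding_window_connectivity(timeseries, window_size, step_size):
    """Return an array of shape (n_windows, n_regions, n_regions)."""
    n_timepoints = timeseries.shape[0]
    # Number of full windows that fit into the timeseries
    n_windows = (n_timepoints - window_size) // step_size + 1
    mats = []
    for w in range(n_windows):
        start = w * step_size
        window = timeseries[start:start + window_size, :]
        # Columns are regions, so correlate across columns
        mats.append(np.corrcoef(window, rowvar=False))
    return np.stack(mats)

rng = np.random.default_rng(0)
toy_timeseries = rng.standard_normal((60, 5))   # 60 timepoints, 5 regions
conn = sliding_window_connectivity(toy_timeseries, window_size=10, step_size=1)
print(conn.shape)  # (51, 5, 5)
```

Each slice along the first axis is a symmetric matrix with ones on the diagonal; stacking them over time gives the 3D array that the models consume later.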

In [None]:

from nilearn.connectome import ConnectivityMeasure

window_size = 10
step_size = 1

def get_windows(window_size, step_size, timeseries_data):
    # Number of full windows that fit: floor((T - window_size) / step_size) + 1
    num_windows = (timeseries_data.shape[0]-window_size)//step_size + 1
    windows = np.array([timeseries_data[(i*step_size):(i*step_size+window_size), :] for i in range(num_windows)])
    return windows

# TODO use "multi" to re-save arrays
kind = "multi" # single, multi
raveled_mats = []
pops = []
subids = []

for ix_f, fmri_file in enumerate(fmri_files):
    # break
    try:
        print(f"\n\n\n {ix_f}\n\n\n")
        st = time.time()
        
        # Get matching T1/fMRI files
        fmri_subid = "_".join(fmri_file.split("/")[-1].split("_")[:4])
        t1_file = T1_files[np.array([fmri_subid in t for t in T1_files])][0]
        pop = fmri_file.split("/")[-2].split("_")[-1]
        print(f"Sorted file IDs at {time.time()-st} sec")
        
        timeseries_data = np.load(f"/home/coveneuro-leif/{fmri_subid}_{pop}_fmri_aligned_cleaned_timeseries.npy")
        connectivity_measure = ConnectivityMeasure(kind="correlation")
        
        if kind=="single":
            connectome_matrix = connectivity_measure.fit_transform([timeseries_data])[0]
            np.fill_diagonal(connectome_matrix, 0)
            raveled = np.tril(connectome_matrix).ravel()
            # plotting.plot_matrix(correlation_matrix, labels=labels, colorbar=True, vmax=0.8, vmin=-0.8)
            
            if raveled.shape[0]==2304: # To avoid dimension issues - we only lose 9/474
                raveled_mats.append(raveled)
                pops.append(pop)
                subids.append(fmri_subid)
            
        else:
            windows = get_windows(window_size, step_size, timeseries_data)
            connectome_matrix = connectivity_measure.fit_transform(windows)
            connectome_matrix = connectome_matrix.transpose(1,2,0)
            
            if (connectome_matrix.shape[0]==48) & (connectome_matrix.shape[1]==48) & (connectome_matrix.shape[2]>=40):
                np.save(f"/home/coveneuro-leif/Downloads/dl_project_fmri/{fmri_subid}_{pop}_fmri_aligned_cleaned_timeseries_connMat.npy", connectome_matrix)
            else:
                pass
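    except Exception as err:
        # Report which file failed and why, rather than failing silently
        print(f"Skipping {ix_f}: {err}")
        continue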


That's it for preprocessing fMRI! We are now ready to start having some data modeling fun.

# Analyzing the data

In this part, we will focus on building the models and analyzing the connectivity matrices that we generated above after all of our preprocessing efforts. First, we'll import the relevant libraries.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pandas as pd
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.linear import SCCA_IPLS
import pingouin as pg
from scipy.stats import spearmanr
import mne
from sklearn.decomposition import SparsePCA, PCA, KernelPCA
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype']="none"


Let's define some custom functions to help us out. 

1. plot_kde() is a helper to plot kernel density estimates (KDEs) over top of data points, to help with visualization.
2. scorer() is a custom scoring function that we'll be using later during model development.

Also, we set the torch and numpy seeds.

In [None]:
def plot_kde(V, ax, colour = 'k', label=None, limits=2, n_levels=2):
    # V is a matrix where columns = variables and rows = samples
    # ax is the axis to attach the plot to
    from scipy.stats import gaussian_kde

    kernel = gaussian_kde(V.T)
    xmin, ymin, xmax, ymax = np.r_[V.min(axis = 0)-limits, V.max(axis = 0)+limits]
    
    # Perform the kernel density estimate
    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    f = np.reshape(kernel(positions).T, xx.shape)
    
    # cfset = ax.contourf(xx, yy, f, cmap='Blues')
    ax.contour(xx, yy, f, colors=colour, levels = n_levels, linewidths = 2, label=label)
    
    return ax

# Custom scoring function
def scorer(estimator, X):
    dim_corrs = estimator.score(X)
    return dim_corrs.mean()

torch.manual_seed(42)
np.random.seed(42)


Let's define the model and dataloader for the fMRI data first. The convolutional (variational) autoencoder takes in the whole 3D connectivity array (region × region × time window) from each individual. The dataloader serves the first 40 windows of each array, to ensure consistency across individuals, and the encoder squashes each array down to a latent code, giving a 2D tensor of shape (batch_size, latent_dim). Once the model is trained, we can run just the encoder to extract latent codes for new individuals.
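
Before reading the model, it helps to trace how the input shape shrinks through the encoder, since the size of the linear layers depends on it. The sketch below is plain Python arithmetic, assuming a 48 × 48 × 40 input, two blocks of a 3×3×3 convolution (stride 1, padding 1) followed by 2×2×2 max pooling, and 32 output channels. The helper names `conv_out` and `trace_encoder` are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_encoder(spatial, temporal, n_blocks=2):
    """Follow (rows, cols, time) through conv(k=3, p=1) + maxpool(k=2) blocks."""
    shape = (spatial, spatial, temporal)
    for _ in range(n_blocks):
        shape = tuple(conv_out(s, kernel=3, stride=1, padding=1) for s in shape)  # conv keeps size
        shape = tuple(conv_out(s, kernel=2, stride=2) for s in shape)              # pooling halves it
    return shape

fmri_shape = trace_encoder(spatial=48, temporal=40)
flat_dim = 32 * fmri_shape[0] * fmri_shape[1] * fmri_shape[2]  # 32 channels after the last conv
print(fmri_shape, flat_dim)  # (12, 12, 10) 46080
```

The flattened size is what the linear layers producing the latent mean and log-variance need as their input dimension.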

In [None]:
class Conv3DAutoencoderFMRI(nn.Module):
    def __init__(self, latent_dim=2):
        super(Conv3DAutoencoderFMRI, self).__init__()
        self.latent_dim = latent_dim

        # Encoder
        # x = inputs.to("cpu")
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),  # Output: (16, 48, 48, 40)
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),  # Output: (16, 24, 24, 20)

            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # Output: (32, 24, 24, 20)
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)  # Output: (32, 12, 12, 10)
        )
        # Flattened encoder output: 32 channels x 12 x 12 regions x 10 time steps
        ldim = 32*12*12*10
        self.fc_mu = nn.Linear(ldim, latent_dim)
        self.fc_logvar = nn.Linear(ldim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, ldim)
        
        # Decoder
        # x1=x
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(in_channels=32, out_channels=16, kernel_size=2, stride=2),  # Output: (16, 24, 24, 20)
            nn.ReLU(),

            nn.ConvTranspose3d(in_channels=16, out_channels=8, kernel_size=2, stride=2),  # Output: (8, 48, 48, 40)
            nn.ReLU(),

            nn.ConvTranspose3d(in_channels=8, out_channels=1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1)),  # Output: (1, 48, 48, 40)
            nn.Sigmoid()  # Output values between 0 and 1
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # Standard deviation
        eps = torch.randn_like(std)  # Random noise
        return mu + eps * std  # Reparameterization trick

    def decode(self, z):
        x = self.fc_decode(z)
        x = x.view(-1, 32, 12, 12, 10)  # Reshape to (batch, 32, 12, 12, 10)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar
    
class Custom3DDatasetFMRI(Dataset):
    def __init__(self, data_dir):
        """
        Args:
            data_dir (str): Directory of .npy connectivity arrays, each of shape (48, 48, n_windows) with n_windows >= 40.
        """
        self.data_dir = data_dir
        self.file_list = [f for f in os.listdir(data_dir) if f.endswith('.npy')] 

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        sample = np.load(file_path)[:, :, :40]  # Keep the first 40 windows
        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)  # Shape: (1, 48, 48, 40)
        
        return sample, file_path
# def train_autoencoder(model, dataloader, criterion, optimizer, num_epochs=10):
    

Time to train the model (already!). This is the fMRI model. It pulls the saved matrices that we generated previously through our dataloader class (above) and fits the AE model, saving a checkpoint every 100th epoch for reproducibility. We also loop over multiple latent dimension sizes so that they can be inspected later on if we wish. Finally, we run the models on CUDA if available, as it speeds up the computation considerably.
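
Note that the training loop in the next cell optimizes reconstruction error (MSE) only. For comparison, a textbook VAE objective adds a Kullback-Leibler (KL) term that pulls each latent distribution toward a standard normal prior. The NumPy sketch below shows that closed-form term; it is an illustration under stated assumptions and not what this notebook trains with. The weighting `beta` and both function names are hypothetical.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims, averaged over the batch."""
    per_sample = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
    return per_sample.mean()

def vae_loss(recon, target, mu, logvar, beta=1.0):
    """Reconstruction MSE plus a beta-weighted KL penalty."""
    recon_loss = np.mean((recon - target) ** 2)
    return recon_loss + beta * kl_to_standard_normal(mu, logvar)

# Toy batch: 2 samples, 4 latent dimensions
mu = np.ones((2, 4))
logvar = np.zeros((2, 4))
print(kl_to_standard_normal(mu, logvar))  # 2.0
```

When `mu` is zero and `logvar` is zero the penalty vanishes, so the term only costs something when the encoder moves away from the prior.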

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

data_dir = "/home/coveneuro-leif/Downloads/dl_project_fmri/"
dataset = Custom3DDatasetFMRI(data_dir)

# Create a DataLoader
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for latent_dim in [4, 8, 16, 32, 64]:
    num_epochs = 1001 # 1001 so that the 1000th epoch is also saved. THANKS PYTHON
    
    # Instantiate the model
    model = Conv3DAutoencoderFMRI(latent_dim=latent_dim)
    model.to(device)
    
    # Training setup
    criterion = nn.MSELoss()  # Mean squared error loss for reconstruction
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam optimizer

    model.train()

    for epoch in np.arange(1, num_epochs):
        running_loss = 0.0
    
        for inputs, _ in dataloader:
            # break
            inputs = inputs+inputs.permute(0,1,3,2,4)
            # Move data to GPU if available
            inputs = inputs.to(device)
            # break
            # Zero the gradients
            optimizer.zero_grad()
    
            # Forward pass
            outputs = model(inputs) # outputs[0].shape
            
            # Compute the loss (between the input and its reconstruction)
            loss = criterion(outputs[0], inputs) + criterion(outputs[0].permute(0,1,3,2,4), inputs)
            # loss = criterion(outputs[0], inputs)
            
            # Backpropagation and optimization
            loss.backward()
            optimizer.step()
    
            # Accumulate loss
            running_loss += loss.item() * inputs.size(0)
    
        if epoch%100==0:
            torch.save(model.state_dict(), f"/home/coveneuro-leif/Results/3DCNN_AE_fMRI_{latent_dim}_{epoch}_all.pt")
        
        # Print average loss for the epoch
        epoch_loss = running_loss / len(dataloader.dataset)
        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.6f}')


Next, let's get the anatomical labels for the fMRI dataset...this will be useful later.

In [None]:
from nilearn import datasets
atlas = datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')
atlas_filename = atlas.maps  # Path to the atlas labels
label_names = atlas.labels[1:]    # Region names, skipping the leading "Background" entry so indices match the 48 regions

Next, we can start preparing to plot data from the learned latent space. We define some helpers and grouping variables, and then compute pairwise group t-tests to identify the latent dimensions that differ significantly between groups. Let's grab the 1000th epoch and set the target latent_dim size to 4. Empirically, this combination leads to good visualizations and reconstruction performance without sacrificing separability between groups.
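
For readers who want to see what a two-sample t-test computes under its equal-variance assumption, here is a minimal NumPy sketch of the pooled t statistic on made-up numbers. The toy groups are hypothetical, and the Bonferroni threshold at the end is just one simple option for handling the 4 latents × 3 group pairs tested here, not something the notebook itself applies.

```python
import numpy as np

def pooled_t_statistic(a, b):
    """Two-sample t statistic assuming equal variances in both groups."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n_a, n_b = a.size, b.size
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    standard_error = np.sqrt(pooled_var * (1.0 / n_a + 1.0 / n_b))
    return (a.mean() - b.mean()) / standard_error

# Hypothetical latent values for two groups
group_td = [1.0, 2.0, 3.0, 4.0, 5.0]
group_adhd = [2.0, 4.0, 6.0, 8.0, 10.0]
t_stat = pooled_t_statistic(group_td, group_adhd)

# One simple correction for multiple comparisons: 4 latents x 3 group pairs
n_tests = 4 * 3
bonferroni_alpha = 0.05 / n_tests
print(t_stat, bonferroni_alpha)
```

A negative statistic means the first group has the lower mean on that latent dimension.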

In [None]:
# Plot a cascade of latent traversal
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # Standard deviation
    eps = torch.randn_like(std)  # Random noise
    return mu + eps * std  # Reparameterization trick

# We saved these earlier during training...dealer's choice as to which can be used
# Epochs must be in increments of 100 only 
latent_dim = 4
epoch = 1000
model = Conv3DAutoencoderFMRI(latent_dim=latent_dim)
model.to(device)
model.load_state_dict(torch.load(f"/home/coveneuro-leif/Results/3DCNN_AE_fMRI_{latent_dim}_{epoch}_all.pt", map_location=device))

model.eval()
means = []
vars_ = []
names_data = []
groups_data = []
for data, name in dataloader:
    data = torch.Tensor(data).to(device)  # Send data to GPU if available
    out = model.encode(data)
    means.append(out[0].cpu().detach().numpy())
    vars_.append(out[1].cpu().detach().numpy())
    names_data.extend(name)
    
    for n in name:
        if "Typically-Developing" in n:
            groups_data.append("Typically-Developing")
        elif "ASD" in n:
            groups_data.append("ASD")
        elif "ADHD" in n:
            groups_data.append("ADHD")
        else:
            groups_data.append("NA")

clinical = pd.read_csv("/home/coveneuro-leif/clinical_SWAN.csv")
names_data = np.array(names_data)
# all_files = glob.glob("/home/coveneuro-leif/Downloads/ae_data/*.npy")
all_names = ["_".join(a.split("/")[-1].split(".")[0].split("_")[:3]) for a in names_data]
all_means = np.concatenate(means)
all_groups = np.array(groups_data)

# Numerical group names so that plotting can be done more easily
all_groups_num = []
for a in all_groups:
    if "Typically-Developing" in a:
        all_groups_num.append(1)
    if "ADHD" in a:
        all_groups_num.append(2)
    if "ASD" in a:
        all_groups_num.append(3)
    if "NA" in a:
        all_groups_num.append(np.nan)
        
all_groups_num = np.array(all_groups_num)        
all_means_df = pd.DataFrame(data=all_means.astype(float), index=all_names, columns=[f"lat{i}" for i in range(all_means.shape[1])])

# TODO use this to sort latents for plotting timeseries
from scipy.stats import ttest_ind
stats_results = []
for i in range(latent_dim):
    var = f"lat{i}"
    # plt.figure()
    # plt.boxplot(positions=[0], x=all_means_df[var][all_groups_num==1])
    # plt.boxplot(positions=[1], x=all_means_df[var][all_groups_num==2])
    # plt.boxplot(positions=[2], x=all_means_df[var][all_groups_num==3])
    stat_1_2, pval_1_2 = ttest_ind(all_means_df[var][all_groups_num==1], all_means_df[var][all_groups_num==2])
    stat_2_3, pval_2_3 = ttest_ind(all_means_df[var][all_groups_num==2], all_means_df[var][all_groups_num==3])
    stat_1_3, pval_1_3 = ttest_ind(all_means_df[var][all_groups_num==1], all_means_df[var][all_groups_num==3])
    stats_results.append((pval_1_2, pval_2_3, pval_1_3, stat_1_2, stat_2_3, stat_1_3))
stats_results_df = pd.DataFrame(data=stats_results, 
                                columns=["p-TD_ADHD", "p-ADHD_ASD", "p-TD_ASD", "t-TD_ADHD", "t-ADHD_ASD", "t-TD_ASD"])

Next, we will traverse the latent space to see how the connectivity matrices change as a single latent variable changes. The most significant latent dimensions also correspond to notable changes in the patterns of the matrices!
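
The mechanics of a latent traversal are easy to see with a toy decoder: hold a latent code fixed, add an offset to one dimension, decode, and subtract the unperturbed reconstruction. The sketch below swaps in a random linear map for the decoder, which is an assumption made only to keep the example self-contained; the real decoder is nonlinear, so its differences will not scale this neatly. `decoder_weights` and `decode` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, n_features = 4, 6
decoder_weights = rng.standard_normal((latent_dim, n_features))  # toy linear "decoder"

def decode(z):
    """Map latent codes of shape (batch, latent_dim) to features."""
    return z @ decoder_weights

z = rng.standard_normal((1, latent_dim))   # one latent code
baseline = decode(z)                       # unperturbed reconstruction

latent_perturb = 1                         # index of the latent dimension to traverse
add_vals = np.linspace(-2, 2, 5)           # offsets applied to that dimension
differences = []
for val in add_vals:
    offset = np.zeros(latent_dim)
    offset[latent_perturb] = val
    differences.append(decode(z + offset) - baseline)
differences = np.concatenate(differences)  # shape (5, n_features)
print(differences.shape)
```

Subtracting the baseline isolates what the perturbed dimension contributes, which is what makes the traversal panels comparable to one another.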

In [None]:
mu, logvar = model.encode(data)
z = reparameterize(mu, logvar)
pred_original = model.decode(z).cpu().detach().numpy()

# Perform latent space traversal in statistically-identified ROIs 
num_traversal = 5
perturb_max = 2
latent_perturb = 1

fig, ax = plt.subplots(1, num_traversal) # One panel per traversal step
plt.subplots_adjust(wspace=0.5, left=0.05, right=0.99)
add_vals = np.linspace(-perturb_max, perturb_max, num_traversal)

for i in range(num_traversal):
    test_arr = np.zeros(all_means.shape[1])
    test_arr[latent_perturb] = add_vals[i]
    
    # pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy()
    pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy() - pred_original
    
    # Set aside highest and lowest traversal for finding the nodes that change the most (NEXT STAGE)
    if i==0:
        mat_neg = pred_ax
    elif i==(num_traversal-1):
        mat_pos = pred_ax
        
    ax[i].imshow(pred_ax.squeeze()[0][:, :, 0], vmax=0.5, vmin=-0.5, cmap="jet")
    ax[i].set_xticks(np.arange(len(label_names)), label_names, rotation=90, fontsize=4)
    ax[i].set_yticks(np.arange(len(label_names)), label_names, rotation=00, fontsize=4)
    
main_nodes = mat_pos - mat_neg
main_nodes = main_nodes.squeeze()[0][:, :, 0]
plt.imshow(main_nodes)

Let's plot an example of a reconstructed connectivity matrix next to the corresponding original. Note that the overall patterns are pretty close, although in this specific example the symmetry is imperfect. Performance tends to be better for MEG, and for fMRI models with higher latent dimension counts.

In [None]:
# And now the magic - grab the indices of the highest-magnitude differences as you
# traverse your statistically-defined latent of interest
# TODO n.b. focus on perhaps statistically identifying the best candidate(s) for this step, 
# rather than relying solely on magnitude?
max_r, max_c = np.unravel_index(np.argmax(np.abs(main_nodes)), main_nodes.shape)

# Plot reconstructed connectivity matrices
predicted = model(data)[0]
predicted = predicted.cpu().detach().numpy()
data_for_viz = data
data_for_viz = data_for_viz+data_for_viz.permute(0,1,3,2,4)
data_for_viz = data_for_viz.cpu().detach().numpy()

fig, ax = plt.subplots(1, 2)
plt.subplots_adjust(left=0.10, right=0.99, top=0.95, bottom=0.20, wspace=0.5)
ax[0].set_title("Predicted")
ax[1].set_title("Original")
ax[0].imshow(predicted.squeeze()[0][:, :, 0], vmin=0, vmax=1)
ax[1].imshow(data_for_viz.squeeze()[0][:, :, 0], vmin=0, vmax=1)
ax[0].set_xticks(np.arange(len(label_names)), label_names, rotation=90, fontsize=8)
ax[0].set_yticks(np.arange(len(label_names)), label_names, rotation=00, fontsize=8)
ax[1].set_xticks(np.arange(len(label_names)), label_names, rotation=90, fontsize=8)
ax[1].set_yticks(np.arange(len(label_names)), label_names, rotation=00, fontsize=8)


Because these are 3D volumes that are being reconstructed, we can also plot the estimated/reconstructed timeseries for individual entries of the volume (i.e., individual region pairs)! Pretty cool. Let's plot the entry that changed the most during the traversal above...

You will also note immediately that the VAE acts as an adaptive lowpass filter on the data. This could be useful as a preprocessing step for neurophysiological/neuroimaging timeseries data at scale (i.e., across many channels at once).

In [None]:
# Plot timeseries of reconstructed 3D volumes
fig, ax = plt.subplots(1, num_traversal) # One panel per traversal step
row = max_r
col = max_c
row_name = label_names[row]
col_name = label_names[col]
ts_combined_title = f"Connectivity: {row_name} × {col_name}"
plt.suptitle(ts_combined_title)

for i in range(num_traversal):
    test_arr = np.zeros(all_means.shape[1])
    test_arr[latent_perturb] = add_vals[i]
    
    # Extract most significant latent timeseries
    pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy()
    # pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy() - pred_original
    ts_pred = pred_ax.squeeze()[0][row, col, :]
    ts_real = data_for_viz.squeeze()[0][row, col, :]
    
    # Set aside highest and lowest traversal for finding the nodes that change the most (NEXT STAGE)
    ax[i].plot(ts_pred, label="Predicted")
    ax[i].plot(ts_real, label="Original")


Next, let's further embed the latent space from the VAE such that it only occupies 2 dimensions. This is for plotting purposes. We'll use SparsePCA for now, although alternatives include regular PCA, t-SNE, or UMAP (preferred for scenarios with large sample sizes or many dimensions because it is more expressive).
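
To make the projection step concrete, here is a minimal sketch of ordinary PCA (not SparsePCA, which adds a sparsity penalty on the components) implemented with a NumPy SVD on toy latents. `pca_embed` and `toy_latents` are hypothetical names used only for this illustration.

```python
import numpy as np

def pca_embed(latents, n_components=2):
    """Project rows of `latents` onto their top principal components."""
    centered = latents - latents.mean(axis=0)
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    embedded = centered @ vt[:n_components].T
    explained = singular_values**2 / np.sum(singular_values**2)
    return embedded, explained[:n_components]

rng = np.random.default_rng(0)
# 100 toy samples with 4 latent dimensions of decreasing spread
toy_latents = rng.standard_normal((100, 4)) * np.array([5.0, 2.0, 1.0, 0.5])
embedded_toy, explained = pca_embed(toy_latents)
print(embedded_toy.shape, explained)
```

The explained-variance ratios give a quick check on how much of the latent structure a 2D plot can actually show.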

In [None]:
group = []
for t in clinical['primary_diagnosis'].values:
    if t=="Typically-Developing":
        group.append(0)
    elif t=="ADHD":
        group.append(1)
    elif t=="ASD":
        group.append(2)
# clinical.insert(0, "group", group)

# Add some accounting of where the means are w.r.t. the clinical values - useful later on!!!
all_means_df.insert(0, "num_index", np.arange(all_means_df.shape[0]))
out = pd.merge(left=clinical, right=all_means_df, left_on="subject_id", right_index=True)
out.index = out['num_index']
out.drop("num_index", axis=1, inplace=True) # Remove column to avoid messing up indices later and, hey it's served its purpose by now

limits = 2e-3
n_levels = 2
embedded = SparsePCA(n_components=2).fit_transform(all_means)


Plot all embeddings for the data.

In [None]:

fig, ax = plt.subplots(1,1)
ax.scatter(embedded[:, 0], embedded[:, 1], c=all_groups_num)
plot_kde(V=embedded[all_groups_num==1, :], ax=ax, colour='purple', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded[all_groups_num==2, :], ax=ax, colour='teal', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded[all_groups_num==3, :], ax=ax, colour='yellow', label=None, limits=limits, n_levels=n_levels)


Plot only the embeddings for which we have corresponding clinical scores.

In [None]:
fig, ax = plt.subplots(1,1)
embedded_short = embedded.copy()[out.index, :]
all_means_short = all_means.copy()[out.index, :]
all_groups_num_short = all_groups_num[out.index]
ax.scatter(embedded_short[:, 0], embedded_short[:, 1], c=all_groups_num_short)
plot_kde(V=embedded_short[all_groups_num_short==1, :], ax=ax, colour='purple', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded_short[all_groups_num_short==2, :], ax=ax, colour='teal', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded_short[all_groups_num_short==3, :], ax=ax, colour='yellow', label=None, limits=limits, n_levels=n_levels)


Finally, let's take a peek at some canonical correlation analyses (CCA). This technique enables us to inspect the shared latent space between the VAE embeddings and the clinical scores. In essence, it lets us see whether the VAE embeddings capture relevant clinical phenomena!
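
For intuition, classical (unregularized) CCA can be written in a few lines of NumPy: whiten each view, then take the singular values of the whitened cross-covariance, which are the canonical correlations. This sketch is not the sparse CCA fitted in the cell that follows; the small ridge term `reg`, the function name, and the toy views are all assumptions of this example.

```python
import numpy as np

def cca_correlations(X, Y, reg=1e-8):
    """Canonical correlations between two views (rows = samples)."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    n = X.shape[0]
    Sxx = X.T @ X / (n - 1) + reg * np.eye(X.shape[1])
    Syy = Y.T @ Y / (n - 1) + reg * np.eye(Y.shape[1])
    Sxy = X.T @ Y / (n - 1)

    def inv_sqrt(S):
        # Inverse matrix square root via eigendecomposition of a symmetric matrix
        w, V = np.linalg.eigh(S)
        return V @ np.diag(1.0 / np.sqrt(w)) @ V.T

    whitened_cross_cov = inv_sqrt(Sxx) @ Sxy @ inv_sqrt(Syy)
    return np.linalg.svd(whitened_cross_cov, compute_uv=False)

rng = np.random.default_rng(0)
shared = rng.standard_normal(200)  # signal present in both toy views
view_1 = np.column_stack([shared + 0.1 * rng.standard_normal(200),
                          rng.standard_normal(200),
                          rng.standard_normal(200)])
view_2 = np.column_stack([shared + 0.1 * rng.standard_normal(200),
                          rng.standard_normal(200)])
corrs = cca_correlations(view_1, view_2)
print(corrs)
```

The first canonical correlation is high because both toy views contain the shared signal, while the second reflects only chance alignment between the noise columns.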

In [None]:

rs = 20

# TODO they advise against using LASSO (i.e., L1) with alpha=0 for "numerical reasons" - probably unstable
linear_cca = SCCA_IPLS(latent_dimensions = 3, alpha=1, l1_ratio=1, epochs=10000, random_state=rs)
train_view_1 = out.iloc[:, 2:20].values
train_view_2 = out.iloc[:, 20:].values

# linear_cca.fit((train_view_1.astype(float), train_view_2))
# linear_cca.pairwise_correlations((train_view_1.astype(float), train_view_2))
# weights = linear_cca.weights_

# pwcorrs = pg.pairwise_corr(out, columns=[list(out.columns[2:20]), list(out.columns[20:])], method="spearman", padjust="fdr_bh")

# Define grid of potential regularization parameters
c1 = [0.05, 0.1, 0.3, 0.7, 0.9]
c2 = [0.05, 0.1, 0.3, 0.7, 0.9]
c3 = [2, 3, 4]
param_grid = {'l1_ratio': c1,
              'alpha': c2,
              'latent_dimensions': c3}

cv = 5

# Conduct grid search
ridge = GridSearchCV(SCCA_IPLS(random_state=rs), param_grid=param_grid,
                     cv=cv, verbose=True, scoring=scorer).fit((train_view_1, train_view_2)).best_estimator_

projections = ridge.transform((train_view_1, train_view_2))
correlation = ridge.score((train_view_1, train_view_2))
view_1_weights = ridge.weights_[0]
view_2_weights = ridge.weights_[1]

# Loadings biplot
plt.figure()
plt.plot([-1, 1], [0, 0], ls=":", c='k')
plt.plot([0, 0], [-1, 1], ls=":", c='k')
view_loadings = ridge.loadings_([train_view_1, train_view_2])
plt.scatter(view_loadings[0][:, 0], view_loadings[0][:, 1])
plt.scatter(view_loadings[1][:, 0], view_loadings[1][:, 1])
for i in range(view_loadings[0].shape[0]):
    plt.arrow(0, 0, view_loadings[0][i, 0], view_loadings[0][i, 1], color="blue")
    plt.text(view_loadings[0][i, 0], view_loadings[0][i, 1], out.columns[2:20][i], ha="center")
for i in range(view_loadings[1].shape[0]):
    plt.arrow(0, 0, view_loadings[1][i, 0], view_loadings[1][i, 1], color="orange")

# Weights biplots(?)
plt.figure()
plt.scatter(view_1_weights[:, 0], view_1_weights[:, 1])
plt.scatter(view_2_weights[:, 0], view_2_weights[:, 1])

plt.figure()
all_cols = out.columns
plt.scatter(projections[0][:, 0], projections[1][:, 0], c = out.loc[:, out.columns[2]])

plt.figure()
plt.scatter(embedded_short[:, 0], embedded_short[:, 1], c = out.loc[:, out.columns[3]])

# Compile results - correlations between latents and SWAN scores
res = spearmanr(embedded_short, out.iloc[:, 2:20])
res = spearmanr(all_means_short, out.iloc[:, 2:20])
res_xcorr_r = pd.DataFrame(data=res[0][latent_dim:, :latent_dim], index=out.columns[2:20], columns=[f"latent_{i}" for i in range(latent_dim)])
res_xcorr_r.insert(0, "factor", ["Attention"]*9 + ["Hyperactivity"]*9)
res_xcorr_p = pd.DataFrame(data=res[1][latent_dim:, :latent_dim], index=out.columns[2:20], columns=[f"latent_{i}" for i in range(latent_dim)])
res_xcorr_p.insert(0, "factor", [0]*9 + [1]*9)

os.makedirs(f"/home/coveneuro-leif/Results/VAE_fMRI_{latent_dim}lt_{epoch}epo/", exist_ok=True)
res_xcorr_r.to_csv(f"/home/coveneuro-leif/Results/VAE_fMRI_{latent_dim}lt_{epoch}epo/latent_clinical_correlations.csv")
res_xcorr_p.to_csv(f"/home/coveneuro-leif/Results/VAE_fMRI_{latent_dim}lt_{epoch}epo/latent_clinical_pvals.csv")

# TODO the correlations for latents 2 and 4 recapitulate the hyperactivity items
# of the SWAN as defined in the 2-factor model of the SWAN (https://www.frontiersin.org/journals/psychiatry/articles/10.3389/fpsyt.2024.1330716)
# Latent 4 is more focused on hyperactivity, latent 2 is involved with both, including attention


You can see clearly that some of the latent dimensions (as stored in the res_xcorr_r and res_xcorr_p dataframes) correspond strongly to the different factors of the SWAN as defined in the peer-reviewed literature. Compare this with the MEG case: relationships with fMRI tend to be somewhat weaker but nevertheless capture the overall factor structure in at least a subset of the VAE latent dimensions.

# Okay, time for Part II: MEG

The MEG analysis is a bit different, but mainly in the sense that the models have different dimensions. This is because the connectivity matrices have different shapes, both spatially and temporally. In practice, this just means that the internals of the model need to have slightly different shapes, and the dataloader needs to index the data in a slightly different way.
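
One way to see why the decoder internals differ is to trace the output size of the transposed convolutions. The plain-Python sketch below assumes no output padding and unit dilation, two upsampling layers with kernel 2 and stride 2, and a final layer with stride 1 and padding 1. The final kernel is (3, 3, 3) in the fMRI model and (6, 6, 3) in the MEG model defined next. The helper names are hypothetical.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output length of a transposed convolution (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel

def trace_decoder(shape, final_kernel):
    """Follow (rows, cols, time) through two upsampling layers and the final layer."""
    for _ in range(2):
        shape = tuple(conv_transpose_out(s, kernel=2, stride=2) for s in shape)  # doubles each axis
    return tuple(
        conv_transpose_out(s, kernel=k, stride=1, padding=1)
        for s, k in zip(shape, final_kernel)
    )

fmri_out = trace_decoder((12, 12, 10), final_kernel=(3, 3, 3))
meg_out = trace_decoder((18, 18, 10), final_kernel=(6, 6, 3))
print(fmri_out)  # (48, 48, 40)
print(meg_out)   # (75, 75, 40)
```

The larger final kernel is what lets the MEG decoder grow from 72 back up to 75 along the spatial axes, compensating for the rows and columns dropped by pooling in the encoder.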

In [None]:
class Conv3DAutoencoder(nn.Module):
    def __init__(self, latent_dim=2):
        super(Conv3DAutoencoder, self).__init__()
        self.latent_dim = latent_dim
        # inputs_ = inputs.cpu()#.detach().numpy()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),  # Output: (16, 75, 75, 40)
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),  # Output: (16, 37, 37, 20)

            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # Output: (32, 37, 37, 20)
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)  # Output: (32, 18, 18, 10)
        )
        
        ldim = 32*18*18*10
        self.fc_mu = nn.Linear(ldim, latent_dim)
        self.fc_logvar = nn.Linear(ldim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, ldim)
        # x = fc_mu(x.view(x.size(0), -1))
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(in_channels=32, out_channels=16, kernel_size=2, stride=2),  # Output: (16, 36, 36, 20)
            nn.ReLU(),

            nn.ConvTranspose3d(in_channels=16, out_channels=8, kernel_size=2, stride=2),  # Output: (8, 72, 72, 40)
            nn.ReLU(),

            nn.ConvTranspose3d(in_channels=8, out_channels=1, kernel_size=(6,6,3), stride=(1,1,1), padding=(1,1,1)),  # Output: (1, 75, 75, 40)
            nn.Sigmoid()  # Output values between 0 and 1
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # Standard deviation
        eps = torch.randn_like(std)  # Random noise
        return mu + eps * std  # Reparameterization trick

    def decode(self, z):
        x = self.fc_decode(z)
        x = x.view(-1, 32, 18, 18, 10)  # Reshape to (32, 18, 18, 10)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

class Custom3DDataset(Dataset):
    def __init__(self, data_dir):
        """
        Args:
            data_dir (str): Directory of .npy files, one connectivity array per file
                (75 x 75 x time; the first 40 time points are used).
        """
        self.data_dir = data_dir
        
        self.file_list = [f for f in os.listdir(data_dir) if f.endswith('.npy')] 

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # idx=0
        # file_path = os.path.join(data_dir, file_list[idx])
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        # file_path = "/home/coveneuro-leif/Downloads/dl_project_broad/PND03_HSC_0689_SE01MEG_task-Rest_run-1_meg_broad.npy"
        # sample = np.load(file_path)[:, :, :8]
        sample = np.load(file_path)[:, :, :40]

        # Convert to a PyTorch tensor and add a channel dimension
        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)  # Shape: (1, 75, 75, 40)
        
        return sample, file_path

Let's train the model. Note that in this case, there is an additional outer loop to iterate over frequency bands. This is helpful because it lets us inspect how well the groups separate at the band level. It is otherwise the same logic as the fMRI case.

In [None]:
for band in ["theta", "alpha", "beta"]:
    # band="alpha"
    data_dir = f"/home/coveneuro-leif/Downloads/dl_project_{band}/"
    dataset = Custom3DDataset(data_dir)
    
    batch_size = 16
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    
    for latent_dim in [4, 8, 16, 32, 64]:
        num_epochs = 1001 # 1001 so that the 1000th epoch is also saved. THANKS PYTHON
        
        # Instantiate the model
        model = Conv3DAutoencoder(latent_dim=latent_dim)
        model.to(device)
        
        # Training setup
        criterion = nn.MSELoss()  # Mean squared error loss for reconstruction
        optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam optimizer
        
        model.train()
        
        for epoch in np.arange(1, num_epochs): # Again, to make sure we can train for EXACTLY 1000 epochs (i.e., not save the 101st...201st...etc. Not that it would matter too much)
            running_loss = 0.0
        
            for inputs, _ in dataloader:
                # break
                inputs = inputs+inputs.permute(0,1,3,2,4)
                # Move data to GPU if available
                inputs = inputs.to(device)
                # break
                # Zero the gradients
                optimizer.zero_grad()
        
                # Forward pass
                outputs = model(inputs) # outputs[0].shape
                
                # Compute the loss (between the input and its reconstruction)
                loss = criterion(outputs[0], inputs) + criterion(outputs[0].permute(0,1,3,2,4), inputs)
                
                # Backpropagation and optimization
                loss.backward()
                optimizer.step()
        
                # Accumulate loss
                running_loss += loss.item() * inputs.size(0)
            
            if epoch%100==0:
                torch.save(model.state_dict(), f"/home/coveneuro-leif/Results/3DCNN_AE_MEG_{latent_dim}_{epoch}_{band}.pt")
        
            # Print average loss for the epoch
            epoch_loss = running_loss / len(dataloader.dataset)
            print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.6f}')
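
The loop above optimizes the two reconstruction terms only. For reference, the usual VAE objective also adds a KL term that pulls each latent distribution toward a standard normal. The sketch below shows that closed-form term in numpy; it is not part of the notebook's training code, and the function name, the `beta` weight, and the placeholder reconstruction loss are all hypothetical.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latents, averaged over the batch."""
    per_sample = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
    return float(per_sample.mean())

batch_size, latent_dim = 16, 8
mu = np.zeros((batch_size, latent_dim))
logvar = np.zeros((batch_size, latent_dim))

kl_zero = kl_to_standard_normal(mu, logvar)           # posterior equals the prior
kl_shifted = kl_to_standard_normal(mu + 1.0, logvar)  # every mean shifted by 1

beta = 1.0         # hypothetical weight on the KL term
recon_loss = 0.02  # placeholder for the MSE reconstruction loss
total_loss = recon_loss + beta * kl_shifted
print(kl_zero, kl_shifted, total_loss)
```

In a PyTorch loop the same expression would be written with tensor operations on the `mu` and `logvar` returned by the model's forward pass, so that gradients flow through it.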


Grab anatomical labels like we did before

In [None]:
data_path = mne.datasets.sample.data_path()
subject = "sample"
data_dir = data_path / "MEG" / subject
subjects_dir = data_path / "subjects"
labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir)
labels_vol = [
    "Left-Amygdala",
    "Left-Thalamus-Proper",
    "Left-Cerebellum-Cortex",
    "Brain-Stem",
    "Right-Amygdala",
    "Right-Thalamus-Proper",
    "Right-Cerebellum-Cortex",
]

label_names = [label.name for label in labels] + labels_vol
lh_labels = [name for name in label_names if name.endswith("lh")]

Latent space plotting as before for fMRI. You will note that the separation is generally better, and is strongest for the alpha and beta bands.

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
band = "alpha"
data_dir = f"/home/coveneuro-leif/Downloads/dl_project_{band}/"
dataset = Custom3DDataset(data_dir)
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Plot a cascade of latent traversal
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # Standard deviation
    eps = torch.randn_like(std)  # Random noise
    return mu + eps * std  # Reparameterization trick

# We saved these earlier during training...dealer's choice as to which can be used
# Epochs must be in increments of 100 only 
latent_dim = 16
epoch = 1000
model = Conv3DAutoencoder(latent_dim=latent_dim)
model.to(device)
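# The filename must match the pattern used when saving during training, which includes the band
model.load_state_dict(torch.load(f"/home/coveneuro-leif/Results/3DCNN_AE_MEG_{latent_dim}_{epoch}_{band}.pt", map_location=device))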

model.eval()
means = []
vars_ = []
names_data = []
groups_data = []
for data, name in dataloader:
    data = torch.Tensor(data).to(device)  # Send data to GPU if available
    out = model.encode(data)
    means.append(out[0].cpu().detach().numpy())
    vars_.append(out[1].cpu().detach().numpy())
    names_data.extend(name)
    
    for n in name:
        if "Typically-Developing" in n:
            groups_data.append("Typically-Developing")
        elif "ASD" in n:
            groups_data.append("ASD")
        elif "ADHD" in n:
            groups_data.append("ADHD")
        else:
            groups_data.append("NA")

clinical = pd.read_csv("/home/coveneuro-leif/clinical_SWAN.csv")
names_data = np.array(names_data)
# all_files = glob.glob("/home/coveneuro-leif/Downloads/ae_data/*.npy")
all_names = ["_".join(a.split("/")[-1].split(".")[0].split("_")[:3]) for a in names_data]
all_means = np.concatenate(means)
all_groups = np.array(groups_data)

# Numerical group names so that plotting can be done more easily
group_codes = {"Typically-Developing": 1, "ADHD": 2, "ASD": 3, "NA": np.nan}
all_groups_num = np.array([group_codes[a] for a in all_groups])
all_means_df = pd.DataFrame(data=all_means.astype(float), index=all_names, columns=[f"lat{i}" for i in range(all_means.shape[1])])

# TODO use this to sort latents for plotting timeseries
from scipy.stats import ttest_ind
stats_results = []
for i in range(latent_dim):
    var = f"lat{i}"
    # plt.figure()
    # plt.boxplot(positions=[0], x=all_means_df[var][all_groups_num==1])
    # plt.boxplot(positions=[1], x=all_means_df[var][all_groups_num==2])
    # plt.boxplot(positions=[2], x=all_means_df[var][all_groups_num==3])
    stat_1_2, pval_1_2 = ttest_ind(all_means_df[var][all_groups_num==1], all_means_df[var][all_groups_num==2])
    stat_2_3, pval_2_3 = ttest_ind(all_means_df[var][all_groups_num==2], all_means_df[var][all_groups_num==3])
    stat_1_3, pval_1_3 = ttest_ind(all_means_df[var][all_groups_num==1], all_means_df[var][all_groups_num==3])
    stats_results.append((pval_1_2, pval_2_3, pval_1_3, stat_1_2, stat_2_3, stat_1_3))
stats_results_df = pd.DataFrame(data=stats_results, 
                                columns=["p-TD_ADHD", "p-ADHD_ASD", "p-TD_ASD", "t-TD_ADHD", "t-ADHD_ASD", "t-TD_ASD"])
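
The loop above runs three t-tests per latent, which adds up to many uncorrected p-values. A commented-out line later in this notebook mentions `padjust="fdr_bh"`; the sketch below is a minimal numpy version of that Benjamini-Hochberg adjustment, applied to made-up p-values. The function name is hypothetical, and whether to correct per comparison or across the whole table is left as a judgment call.

```python
import numpy as np

def benjamini_hochberg(pvals):
    """Return Benjamini-Hochberg adjusted p-values in the original order."""
    p = np.asarray(pvals, dtype=float)
    n = p.size
    order = np.argsort(p)
    # Scale each sorted p-value by n / rank
    scaled = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity, working down from the largest p-value
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.clip(scaled, 0.0, 1.0)
    return adjusted

example_p = [0.001, 0.008, 0.039, 0.041, 0.60]  # made-up values
adj = benjamini_hochberg(example_p)
print(adj)  # [0.005, 0.02, 0.05125, 0.05125, 0.6]
```

With the notebook's table, one option would be to pass a single column such as `stats_results_df["p-TD_ADHD"].values` to the function.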


Plots of reconstructed connectivity matrices. Right out of the gate, it is clear that the matrices are more symmetrical and generally better behaved than their fMRI counterparts.

In [None]:
max_r, max_c = np.unravel_index(np.argmax(np.abs(main_nodes)), main_nodes.shape)

# Plot reconstructed connectivity matrices
predicted = model(data)[0]
predicted = predicted.cpu().detach().numpy()
data_for_viz = data
data_for_viz = data_for_viz+data_for_viz.permute(0,1,3,2,4)
data_for_viz = data_for_viz.cpu().detach().numpy()

fig, ax = plt.subplots(1, 2)
plt.subplots_adjust(left=0.10, right=0.99, top=0.95, bottom=0.20, wspace=0.5)
ax[0].set_title("Predicted")
ax[1].set_title("Original")
ax[0].imshow(predicted.squeeze()[0][:, :, 0], vmin=0, vmax=1)
ax[1].imshow(data_for_viz.squeeze()[0][:, :, 0], vmin=0, vmax=1)
ax[0].set_xticks(np.arange(len(label_names)), label_names, rotation=90, fontsize=8)
ax[0].set_yticks(np.arange(len(label_names)), label_names, rotation=0, fontsize=8)
ax[1].set_xticks(np.arange(len(label_names)), label_names, rotation=90, fontsize=8)
ax[1].set_yticks(np.arange(len(label_names)), label_names, rotation=0, fontsize=8)
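
The symmetry noted above can also be checked numerically rather than by eye. The sketch below measures the largest gap between a stack of matrices and its transpose, using random data as a stand-in; the symmetrization step mirrors `data + data.permute(0, 1, 3, 2, 4)` without the batch and channel axes. The function name and array sizes are assumptions.

```python
import numpy as np

def asymmetry(stack):
    """Largest absolute difference between each matrix and its transpose.

    stack has shape (rows, cols, time), with rows == cols.
    """
    return float(np.max(np.abs(stack - np.transpose(stack, (1, 0, 2)))))

rng = np.random.default_rng(0)
raw = rng.random((75, 75, 40))                    # stand-in for one unsymmetrized sample
symmetrized = raw + np.transpose(raw, (1, 0, 2))  # same idea as the permute trick

print("raw asymmetry:", asymmetry(raw))
print("symmetrized asymmetry:", asymmetry(symmetrized))
```

Applied to a reconstruction, for example `asymmetry(predicted.squeeze()[0])`, a value near zero would support the visual impression.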

Plots of reconstructed volumes. Again, it's pretty cool! You'll note right away that the reconstruction performs better here than it did with fMRI.

In [None]:
# Plot timeseries of reconstructed 3D volumes
fig, ax = plt.subplots(1, num_traversal) # One panel per traversal step
row = max_r
col = max_c
row_name = label_names[row]
col_name = label_names[col]
ts_combined_title = f"Connectivity: {row_name} × {col_name}"
plt.suptitle(ts_combined_title)

for i in range(num_traversal):
    test_arr = np.zeros(all_means.shape[1])
    test_arr[latent_perturb] = add_vals[i]
    
    # Extract most significant latent timeseries
    pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy()
    # pred_ax = model.decode(z+torch.Tensor(test_arr).to(device)).cpu().detach().numpy() - pred_original
    ts_pred = pred_ax.squeeze()[0][row, col, :]
    ts_real = data_for_viz.squeeze()[0][row, col, :]
    
    # Set aside highest and lowest traversal for finding the nodes that change the most (NEXT STAGE)
    ax[i].plot(ts_pred, label="Predicted")
    ax[i].plot(ts_real, label="Original")

Plots of embeddings using SparsePCA

In [None]:

group = []
for t in clinical['primary_diagnosis'].values:
    if t=="Typically-Developing":
        group.append(0)
    elif t=="ADHD":
        group.append(1)
    elif t=="ASD":
        group.append(2)
# clinical.insert(0, "group", group)

# Add some accounting of where the means are w.r.t. the clinical values - useful later on!!!
all_means_df.insert(0, "num_index", np.arange(all_means_df.shape[0]))
out = pd.merge(left=clinical, right=all_means_df, left_on="subject_id", right_index=True)
out.index = out['num_index']
out.drop("num_index", axis=1, inplace=True) # Remove column to avoid messing up indices later and, hey it's served its purpose by now

limits = 2e-3
n_levels = 2
embedded = SparsePCA(n_components=2).fit_transform(all_means)

Plotting all embeddings by group

In [None]:
fig, ax = plt.subplots(1,1)
ax.scatter(embedded[:, 0], embedded[:, 1], c=all_groups_num)
plot_kde(V=embedded[all_groups_num==1, :], ax=ax, colour='purple', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded[all_groups_num==2, :], ax=ax, colour='teal', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded[all_groups_num==3, :], ax=ax, colour='yellow', label=None, limits=limits, n_levels=n_levels)


Plotting where we have clinical scores only.

In [None]:
 
fig, ax = plt.subplots(1,1)
embedded_short = embedded.copy()[out.index, :]
all_means_short = all_means.copy()[out.index, :]
all_groups_num_short = all_groups_num[out.index]
ax.scatter(embedded_short[:, 0], embedded_short[:, 1], c=all_groups_num_short)
plot_kde(V=embedded_short[all_groups_num_short==1, :], ax=ax, colour='purple', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded_short[all_groups_num_short==2, :], ax=ax, colour='teal', label=None, limits=limits, n_levels=n_levels)
plot_kde(V=embedded_short[all_groups_num_short==3, :], ax=ax, colour='yellow', label=None, limits=limits, n_levels=n_levels)


Plotting CCA as we did with fMRI

In [None]:
rs = 20

# TODO they advise against using LASSO (i.e., L1) with alpha=0 for "numerical reasons" - probably unstable
linear_cca = SCCA_IPLS(latent_dimensions = 3, alpha=1, l1_ratio=1, epochs=10000, random_state=rs)
train_view_1 = out.iloc[:, 2:20].values
train_view_2 = out.iloc[:, 20:].values

linear_cca.fit((train_view_1.astype(float), train_view_2))
linear_cca.pairwise_correlations((train_view_1.astype(float), train_view_2))
weights = linear_cca.weights_

# pwcorrs = pg.pairwise_corr(out, columns=[list(out.columns[2:20]), list(out.columns[20:])], method="spearman", padjust="fdr_bh")

# Define grid of potential regularization parameters
c1 = [0.05, 0.1, 0.3, 0.7, 0.9]
c2 = [0.05, 0.1, 0.3, 0.7, 0.9]
c3 = [2, 3, 4]
param_grid = {'l1_ratio': c1,
              'alpha': c2,
              'latent_dimensions': c3}

cv = 5

# Conduct grid search
ridge = GridSearchCV(SCCA_IPLS(random_state=rs), param_grid=param_grid,
                     cv=cv, verbose=True, scoring=scorer).fit((train_view_1, train_view_2)).best_estimator_

projections = ridge.transform((train_view_1, train_view_2))
correlation = ridge.score((train_view_1, train_view_2))
view_1_weights = ridge.weights_[0]
view_2_weights = ridge.weights_[1]

# Loadings biplot
plt.figure()
view_loadings = ridge.loadings_([train_view_1, train_view_2])
plt.scatter(view_loadings[0][:, 0], view_loadings[0][:, 1])
plt.scatter(view_loadings[1][:, 0], view_loadings[1][:, 1])

# Weights biplots(?)
plt.figure()
plt.scatter(view_1_weights[:, 0], view_1_weights[:, 1])
plt.scatter(view_2_weights[:, 0], view_2_weights[:, 1])

plt.figure()
all_cols = out.columns
plt.scatter(projections[0][:, 0], projections[1][:, 0], c = out.loc[:, out.columns[2]])

plt.figure()
plt.scatter(embedded_short[:, 0], embedded_short[:, 1], c = out.loc[:, out.columns[3]])

# Compile results - correlations between latents and SWAN scores
# spearmanr on two 2-D inputs returns the full (latents + items) square matrix;
# the slices below pull out the items x latents block.
res = spearmanr(all_means_short, out.iloc[:, 2:20])
res_xcorr_r = pd.DataFrame(data=res[0][latent_dim:, :latent_dim], index=out.columns[2:20], columns=[f"latent_{i}" for i in range(latent_dim)])
res_xcorr_r.insert(0, "factor", [0]*9 + [1]*9)
res_xcorr_p = pd.DataFrame(data=res[1][latent_dim:, :latent_dim], index=out.columns[2:20], columns=[f"latent_{i}" for i in range(latent_dim)])
res_xcorr_p.insert(0, "factor", [0]*9 + [1]*9)

os.makedirs(f"/home/coveneuro-leif/Results/VAE_{latent_dim}lt_{epoch}epo/", exist_ok=True)
res_xcorr_r.to_csv(f"/home/coveneuro-leif/Results/VAE_{latent_dim}lt_{epoch}epo/latent_clinical_correlations.csv")
res_xcorr_p.to_csv(f"/home/coveneuro-leif/Results/VAE_{latent_dim}lt_{epoch}epo/latent_clinical_pvals.csv")


# TADAHH! That's all
We have walked through a comprehensive analysis of brain activity data, from preprocessing through to downstream analysis. There are many more ways to explore these rich datasets, and we intentionally glossed over some of the basic steps done at the very start (exploratory analysis, subjective initial plotting of data, and identifying trends) in service of providing a comprehensive overview.

## What did we learn?
We learned a lot of interesting things not only about ML but also about ADHD and ASD brain activity patterns, and how they relate to typically developing children and to clinical ratings of function throughout development. We generally infer from this dataset that the ADHD and ASD groups both tend to contain distinct subgroups that differ from each other and from typically developing children. Much more work will be required to explore these findings in depth, but they demonstrate the great promise that neuroimaging and neurophysiological data analysis have in characterizing developmental disorders, and may help point the way towards better therapies for these individuals. Thanks for following along!