In [None]:
root = '' # change as needed
data_root = root + 'data/'
save_path = root + 'results/inference/'
model_path = root + 'model' # path to the model reported in the paper (downloaded from Zenodo), you can change it to your own model
sc_file = data_root + 'SC_dbs80HARDIFULL.mat'
dbs_path = root + 'ds80_labels.csv'
yeo_path = root + 'dbs802Yeo7.csv'
!mkdir -p {save_path}

# 1) Read

In [None]:
import mat73
import numpy as np
import pandas as pd

In [None]:
def get_ts_hcp_task(data_root):
    '''
    Loads the recording file of every HCP condition (seven tasks plus rest).

    Args:
        data_root (str): Path prefix of the .mat files. File names are appended
            to it directly, so it must end with a path separator.

    Returns:
        time_series (dict): Maps each condition name to whatever mat73.loadmat
            returns for the corresponding file.
    '''
    # (condition name, file name) pairs, in the order they are inserted in the result
    task_files = (
        ('memory', 'hcp1003_WM_LR_dbs80.mat'),
        ('gambling', 'hcp1003_GAMBLING_LR_dbs80.mat'),
        ('motor', 'hcp1003_MOTOR_LR_dbs80.mat'),
        ('language', 'hcp1003_LANGUAGE_LR_dbs80.mat'),
        ('social', 'hcp1003_SOCIAL_LR_dbs80.mat'),
        ('relational', 'hcp1003_RELATIONAL_LR_dbs80.mat'),
        ('emotion', 'hcp1003_EMOTION_LR_dbs80.mat'),
        ('rest', 'hcp1003_REST1_LR_dbs80.mat'),
    )
    return {task: mat73.loadmat(data_root + filename) for task, filename in task_files}

# Load the raw recordings of every condition
data_ts = get_ts_hcp_task(data_root)
# Keep only the regional time series ('dbs80ts') of the subject entries that are
# dictionaries; entries of any other type are dropped
for task, recording in data_ts.items():
    data_ts[task] = [subject['dbs80ts'] for subject in recording['subject'] if type(subject) is dict]

In [None]:
# One row per (condition, subject): the capitalised condition name, the BOLD
# array and the position of the subject inside its condition list
rows = [
    {'cohort': cohort.capitalize(), 'bold': bold, 'subject_id': subject_idx}
    for cohort, recordings in data_ts.items()
    for subject_idx, bold in enumerate(recordings)
]
# The raw dictionary is no longer needed once the rows are built
del data_ts
time_series = pd.DataFrame(rows)
time_series

# 2) Signal filtering

In [None]:
from scipy.signal import butter, detrend, filtfilt

def demean(x, dim=0):
    """
    Removes the mean of `x` along one axis.

    The previous implementation tiled the mean `x.size` times, which only
    broadcasts correctly for 1-D input. Keeping the reduced axis makes the
    subtraction valid for arrays of any rank while returning the same values
    for 1-D input.

    Args:
        x (array-like): Input data.
        dim (int): Axis along which the mean is computed and removed.

    Returns:
        np.ndarray: Array with the same shape as `x` and zero mean along `dim`.
    """
    x = np.asarray(x)
    return x - np.mean(x, axis=dim, keepdims=True)

def BandPassFilter(boldSignal, f_low, f_high, TR, k, removeStrongArtefacts=True):
    """
    Applies the same Butterworth band-pass filter to every row of a BOLD array.

    Each row is detrended and demeaned, optionally clipped at +/- 3 standard
    deviations, and then filtered forwards and backwards with filtfilt.

    Args:
        boldSignal (np.ndarray): 2-D array, one row per region.
        f_low (float): Lower edge of the pass band.
        f_high (float): Upper edge of the pass band.
        TR (float): Sampling period; the Nyquist frequency is 1/(2*TR).
        k (int): Order passed to scipy.signal.butter.
        removeStrongArtefacts (bool): Whether to clip values beyond 3 standard deviations.

    Returns:
        np.ndarray: Filtered array with the shape of `boldSignal`. A row that
        contained NaN is left at zero except for its first sample, which is set to NaN.
    """
    n_regions, _ = boldSignal.shape
    nyquist = 1./(2.*TR)
    # Band edges expressed as fractions of the Nyquist frequency, as butter expects
    band = [f_low/nyquist, f_high/nyquist]
    bfilt, afilt = butter(k, band, btype='band', analog=False)
    # padlen chosen to get the same result as in Matlab
    pad = 3*(max(len(bfilt),len(afilt))-1)

    signal_filt = np.zeros(boldSignal.shape)
    for region in range(n_regions):
        raw = boldSignal[region, :]
        if np.isnan(raw).any():
            # Flag the region as problematic and move on
            print(f'############ Warning!!! BandPassFilter: NAN found at region {region} ############')
            signal_filt[region,0] = np.nan
            continue

        # demean is probably redundant here since detrend already removes the mean
        ts = demean(detrend(raw))

        if removeStrongArtefacts:
            # The lower threshold is computed after the upper clipping, on purpose:
            # the standard deviation is re-evaluated on the already clipped series
            upper = 3.*np.std(ts)
            ts[ts>upper] = upper
            lower = -3.*np.std(ts)
            ts[ts<lower] = lower

        signal_filt[region,:] = filtfilt(bfilt, afilt, ts, padlen=pad)
    return signal_filt

def AmplitudeFilter(time_series):
    """
    Scales an array so that its largest absolute value becomes 1.

    NaN entries are ignored when looking for the peak. BandPassFilter marks a
    problematic region by writing NaN into its first sample; with a plain
    max() that single marker would turn the whole recording into NaN.

    Args:
        time_series (np.ndarray): Signal to normalise.

    Returns:
        np.ndarray: `time_series` divided by its largest finite absolute value
        (NaN entries stay NaN).
    """
    peak = np.nanmax(np.abs(time_series))
    return time_series/peak

In [None]:
# Filters parameters
# Sampling period handed to BandPassFilter (Nyquist frequency = 1/(2*TR)).
# NOTE(review): presumably seconds; HCP acquisitions are commonly TR = 0.72 s, confirm 2.0 matches the data
TR = 2.0
k = 2                                # 2nd order butterworth filter
f_low = 0.008                        # lower edge of the pass band (i.e. the high-pass cutoff)
f_high = 0.08                        # upper edge of the pass band (i.e. the low-pass cutoff)

# Apply filters
# 1) band-pass every recording, 2) rescale each recording by its largest absolute value
time_series.loc[:,'bold'] = time_series['bold'].apply(BandPassFilter, args=(f_low, f_high, TR, k))
time_series.loc[:,'bold'] = time_series['bold'].apply(AmplitudeFilter)

# 3) Windowing

In [None]:
t_use = 50

def create_windows(ts, window_len=None):
    """
    Splits a (regions x time) array into consecutive, non-overlapping windows.

    Only the last `n_windows * window_len` samples are used, so leftover
    samples are discarded from the beginning of the recording.

    The previous implementation reshaped the (N, n_windows * t_use) slice
    directly into (n_windows, N, t_use). That reinterprets the memory layout
    instead of cutting along time, so with more than one window each "window"
    mixed chunks of different time spans of the same region. The time axis is
    now split first and the window axis moved to the front.

    Args:
        ts (np.ndarray): 2-D array with shape (N, t).
        window_len (int, optional): Samples per window. Defaults to the
            module-level `t_use`.

    Returns:
        np.ndarray: Array with shape (n_windows, N, window_len), where
        window `w` holds ts[:, start + w*window_len : start + (w+1)*window_len].
        If the recording is shorter than one window, n_windows is 0.
    """
    if window_len is None:
        window_len = t_use
    N, t = ts.shape
    n_windows = t // window_len
    if n_windows == 0:
        # ts[:, -0:] would select the whole array, so handle this case explicitly
        return np.empty((0, N, window_len), dtype=ts.dtype)
    start = t - n_windows * window_len
    windows = ts[:, start:].reshape(N, n_windows, window_len).transpose(1, 0, 2)
    return np.ascontiguousarray(windows)

# Cut every recording into windows and collect one record per window, plus an
# extra record (window_id 'mean') holding the element-wise mean of its windows
all_windows = []
for _, subject in time_series.iterrows():
    stack = create_windows(subject['bold'])
    meta = {'cohort': subject['cohort'], 'subject_id': subject['subject_id']}
    all_windows.extend({**meta, 'window_id': i, 'window': window} for i, window in enumerate(stack))
    all_windows.append({**meta, 'window_id': 'mean', 'window': stack.mean(axis=0)})
# The full-length recordings are not used any more
del time_series

# Turn the collected records into the working dataframe
windows = pd.DataFrame(all_windows)
del all_windows
windows

In [None]:
def get_filename(row):
    """
    Builds the image file stem of a window record.

    Args:
        row (mapping): Must provide 'cohort', 'subject_id' and 'window_id'.

    Returns:
        str: '<cohort>_<subject_id>' for the 'mean' window, otherwise
        '<cohort>_<subject_id>_<window_id>'.
    """
    stem = f"{row['cohort']}_{row['subject_id']}"
    if row['window_id'] == 'mean':
        return stem
    return f"{stem}_{row['window_id']}"

# Attach the file stem of every window as a new 'filename' column
windows = windows.assign(filename=windows.apply(get_filename, axis=1))

# 4) Time-series-to-image conversion

In [None]:
from matplotlib import pyplot as plt

# Divisor that converts array dimensions into the figure size given to matplotlib
ratio = 77
def plot_ts(ts, path):
    """
    Saves a 2-D array as a heat-map image without axes or padding.

    Args:
        ts (np.ndarray): 2-D array; rows are drawn vertically, columns horizontally.
            Colours are mapped over the fixed range [-1, 1].
        path (str): Destination file.
    """
    n_rows, n_cols = ts.shape
    fig, ax = plt.subplots(figsize=(n_cols/ratio, n_rows/ratio))
    ax.imshow(ts, aspect='auto', cmap='viridis', vmin=-1, vmax=1)
    ax.axis('off')
    fig.savefig(path, bbox_inches='tight', pad_inches=0)
    # Release the figure so that long loops do not accumulate open figures
    plt.close(fig)

In [None]:
import tqdm, os

# Folder that receives one image per window. makedirs creates missing parents
# and does not fail when the folder already exists, so the cell can be re-run.
img_dir = os.path.join(save_path, 'img')
os.makedirs(img_dir, exist_ok=True)

# Create and save images
for _, row in tqdm.tqdm(windows.iterrows(), total=len(windows)):
    plot_ts(row['window'], os.path.join(img_dir, f"{row['filename']}.png"))

# 5) Inference

In [None]:
from fastai.vision.all import *

# Limits of the predicted parameter
a_min, a_max = -1, 1

def rmse_a(inp, targ):
    """
    RMSE between predictions and targets, expressed as a percentage of the
    width of the [a_min, a_max] interval.
    """
    span = a_max - a_min
    return rmse(inp, targ)*100/span

In [None]:
# Get inputs for the model
# Element-wise string concatenation: one image path per row of `windows`
image_list = save_path + '/img/' + windows['filename'].values + '.png'
# Create dummy dataloader
# It is built from a single placeholder item and a constant label; it only provides
# the learner with its item transform (resize to 80x50) and batch size
dls = ImageDataLoaders.from_path_func('', [0], lambda x: '0', bs=16, item_tfms=Resize((80, 50), method='squish'))
# Load model
# 80 outputs, bounded to (-1, 1) through y_range
learn = vision_learner(dls, 'convnext_tiny_in22k', n_out=80, y_range=(-1,1), loss_func=MSELossFlat).to_fp16()
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a GPU
learn.load(model_path, device='cuda')
# Predict
test_dl = learn.dls.test_dl(image_list, device='cuda')
# The second returned value (targets) is discarded: the test set has no labels
preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Store the prediction vector of every window next to it (one array per row)
windows['pred'] = list(preds.numpy())

# 6) Yeo networks

In [None]:
# Names of the 80 ROIs (column 'Rois' of a ';'-separated file)
roi_table = pd.read_csv(dbs_path, sep=';')
dbs_names = roi_table.loc[:,'Rois'].values
# Network index associated with each region (first row of a header-less file)
yeo_table = pd.read_csv(yeo_path, header=None)
yeo_ids = yeo_table.iloc[0].values
# Network name for every index; index 0 has no network
yeo_names = np.array([None, 'Visual', 'Somatomotor', 'Dorsal attention', 'Ventral attention', 'Limbic', 'Frontoparietal', 'Default'])
# Keep cortical regions only (positions 0-30 and 49-79)
dbs_names_cortex = dbs_names[np.r_[0:31, 49:80]]

In [None]:
from collections import defaultdict

tasks_act = defaultdict(list)
# Rows holding the per-subject 'mean' window. Both conditions are combined in a
# single mask: chaining df[mask1][mask2] applies a mask that is indexed like
# the full frame to an already filtered frame, which makes pandas reindex it
# (and warn about it).
is_mean_window = windows['window_id'] == 'mean'
for c in np.unique(windows.cohort.values):
    # Predictions of the mean windows of this task, stacked as (subjects, ROIs)
    selection = np.stack(windows.loc[(windows['cohort'] == c) & is_mean_window, 'pred'].values)
    # Get the mean value for each ROI in a particular task
    activations = selection.mean(axis=0)
    # Filter non cortical regions
    activations_cortex = activations[np.r_[0:31, 49:80]]
    tasks_act[c].append(activations_cortex)

In [None]:
# Table with one column per task (rest excluded); rows (networks) are added below.
# NOTE(review): drop() raises KeyError if there is no 'Rest' cohort in tasks_act
comp = pd.DataFrame(columns=tasks_act.keys())
comp = comp.drop(columns=['Rest'])

for c in tasks_act.keys():
	if c != 'Rest':
		net_act = defaultdict(list)
		# Per-ROI difference between the task and rest (cortical ROIs only)
		activations_cortex = tasks_act[c][0]-tasks_act['Rest'][0]
		# Group the differences by network name.
		# Assumes yeo_ids has one entry per cortical ROI, in the same order; zip would
		# silently truncate otherwise - TODO confirm
		for net,act in zip(yeo_names[yeo_ids], activations_cortex):
			# NOTE(review): the predictions are bounded to (-1, 1), so this threshold
			# looks always satisfied - confirm intent
			if act>-20:
				net_act[net].append(act)
		# Each cell stores the pair [mean, max] of the differences of that network
		for net in net_act.keys():
			comp.loc[net, c] = [np.mean(net_act[net]), np.max(net_act[net])]

In [None]:
import seaborn as sns

data = {'x': [], 'y': [], 'network': []}

# Flatten the (network x task) table into long format, keeping only the first
# element of every cell
for task in comp.columns:
    for network, cell in comp[task].items():
        data['x'].append(task)
        data['y'].append(cell[0])
        data['network'].append(network)

# Long-format frame consumed by seaborn
df_plot = pd.DataFrame(data)

# One marker per (task, network), without jitter
plt.figure(figsize=(10, 6))
marker_style = dict(palette='deep', jitter=False, size=9, edgecolor='black', linewidth=0.5, alpha=0.7)
sns.stripplot(data=df_plot, x='x', y='y', hue='network', **marker_style)

# Labels, ticks, grid and legend
plt.xlabel('Tasks', fontsize=14)
plt.ylabel(r'Difference of mean bifurcation parameters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.grid(True, which='major', linestyle='--', alpha=0.6)
plt.legend(title='Networks', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, title_fontsize=14)
plt.tight_layout()
plt.savefig(save_path+'networks.png', dpi=800)

# 7) Task separation

In [None]:
import seaborn as sns
from itertools import combinations
from statannotations.Annotator import Annotator

def pad_dicts(tests):
    """
    Pads every sequence of a dictionary with NaN up to the longest length.

    Values are converted to float arrays first: NaN cannot be stored in an
    integer array, so padding a list of integers would otherwise fail (or
    produce garbage, depending on the numpy version).

    Args:
        tests (dict): Maps a label to a sequence of numbers.

    Returns:
        dict: Same keys, each mapped to a float array of equal length. An
        empty input gives an empty dictionary.
    """
    if not tests:
        return {}
    arrays = {key: np.asarray(lst, dtype=float) for key, lst in tests.items()}
    max_len = max(len(arr) for arr in arrays.values())
    return {key: np.pad(arr, (0, max_len - len(arr)), 'constant', constant_values=np.nan)
            for key, arr in arrays.items()}

def plot_comparison_across_labels(tests, method="BH", title='', x_label='', y_label='', fig_size=(12, 8), save_path='./'):
	"""
	Draws a violin plot with one violin per label, runs pairwise Mann-Whitney
	tests between all labels and shows the resulting significance stars in a
	table above the plot. The figure is written to `save_path`.

	Args:
		tests (dict): Maps each label to its list of values. Lists of different
			length are padded with NaN by pad_dicts.
		method (str): Multiple-comparison correction handed to statannotations.
		title (str): Figure title.
		x_label (str): Label of the x axis.
		y_label (str): Label of the y axis.
		fig_size (tuple): Figure size passed to plt.subplots.
		save_path (str): Output file of the figure (a file path, despite the './' default).
	"""

	# Prepare data
	tests = pad_dicts(tests)
	use_labels = list(tests.keys())
	# Every unordered pair of labels is compared
	pairs = list(combinations(use_labels, 2))
	df = pd.DataFrame(tests, columns=use_labels)

	# Set up figure and axes
	# Top axes holds the table, bottom axes holds the violins
	fig, (ax_table, ax_violin) = plt.subplots(2, 1, figsize=fig_size, gridspec_kw={'height_ratios': [1.2, 3]})
	# NOTE(review): this changes the seaborn context globally, not only for this figure
	sns.set_context("talk")

	# Perform statistical comparisons and annotate
	annotator = Annotator(ax=ax_violin, data=df, pairs=pairs, order=use_labels)
	annotator.configure(test='Mann-Whitney', text_format='star', verbose=True)
	annotator.configure(comparisons_correction=method, correction_format="replace")
	# Only the tests are run; the results are shown in the table below
	results = annotator.apply_test()

	# Prints
	annotator.print_pvalue_legend()
	for a in results.annotations:
		a.print_labels_and_content()

	# Create violin plot
	sns.violinplot(data=df, order=use_labels, ax=ax_violin, palette='deep', linewidth=1.2)
	ax_violin.set_xlabel(x_label, fontsize=14)
	ax_violin.set_ylabel(y_label, fontsize=14)
	ax_violin.set_xticklabels(ax_violin.get_xticklabels(), fontsize=12)

	# Grid and despine for cleaner look
	sns.despine()
	ax_violin.grid(True, which='major', linestyle='--', alpha=0.6)

	# Create a significance stars table
	# Rows are all labels but the first, columns all labels but the last
	stars_matrix = pd.DataFrame("", index=df.columns[1:], columns=df.columns[:-1])
	for result in results.annotations:
		c1, c2 = [str(struct["label"]) for struct in result.structs]
		stars_matrix.loc[c2, c1] = result.text  # Insert stars

	# Plot the table above the violin plot
	# The table is 0.7 wide and centred on the horizontal midpoint of the violin axes
	bbox_v = ax_violin.get_position()
	new_left = (bbox_v.x0 + bbox_v.x1) / 2 - 0.7 / 2
	table = pd.plotting.table(ax_table, stars_matrix, loc='center', cellLoc='center', fontsize=15, bbox=[new_left,0,0.7,1])
	ax_table.axis('off')

	# Title
	fig.suptitle(title, fontsize=18, y=0.93)

	# Save plot with high resolution
	# NOTE(review): the figure is never closed, so repeated calls keep their figures open
	plt.tight_layout()
	plt.savefig(save_path, dpi=800)

In [None]:
# For every cohort, average the prediction vector of each subject's 'mean'
# window into a single value, then compare the cohorts in a violin plot
mean_windows = windows[windows['window_id'] == 'mean']
result = {
    cohort: np.vstack(preds).mean(axis=1).tolist()
    for cohort, preds in mean_windows.groupby('cohort')['pred']
}
plot_comparison_across_labels(result, x_label='Cohorts', y_label=r'Mean bifurcation parameters $a$', title='', save_path=save_path+'tasks.png')

# 8) FC and FCD

In [None]:
import scipy.io

# Frequency vector stored by a previous step of the pipeline
W = np.load(root+'results/G/w.npy')
# Structural connectivity, rescaled so that its largest entry equals 0.2
sc_mat = scipy.io.loadmat(sc_file)
SC = sc_mat['SC_dbs80FULL']
SC = (0.2 * SC)/SC.max()

In [None]:
# Set parameters in the df
# Every row gets the same SC matrix (the lambda ignores its argument and returns SC)
windows['SC'] = windows['window'].apply(lambda x: SC)
# Coupling constant, identical for all rows
windows['G'] = 2.3
# Every row gets the same W vector (repeated once per row, then split into a list of rows)
windows['W'] = [i for i in np.repeat(W[None,:], len(windows), axis=0)]

In [None]:
from numba import jit

@jit(nopython=True)
def ode_hopf(t, vars, a=-0.02, w=1, G=0, C=0):
    '''
    Right-hand side of the coupled Hopf model.

    Args:
        t (any): Unused; kept so the signature looks like a standard ODE function.
        vars (np.array): State vector, x values of all nodes followed by their y values.
        a (np.array): Bifurcation parameter of each node.
        w (np.array): Frequency parameter of each node.
        G (float): Coupling constant.
        C (np.array): Structural connectivity matrix.

    Returns:
        dvars (np.array): Derivatives of x followed by derivatives of y.
    '''
    n = len(a)
    x = vars[:n].flatten()
    y = vars[n:].flatten()

    # Squared amplitude shared by both equations
    radius_sq = x**2 + y**2
    # Row sums of C weight the "self" part of the diffusive coupling
    strength = C.sum(axis=1)

    coupling_x = np.dot(C,x) - strength * x
    coupling_y = np.dot(C,y) - strength * y

    dxdt = a*x - w*y - x*radius_sq + G*coupling_x
    dydt = a*y + w*x - y*radius_sq + G*coupling_y

    return np.concatenate((dxdt, dydt), axis=0)

def initialize_hopf(n_samples, nodes, seed, a_range, w_range, g, SC):
    '''
    Builds the per-sample parameter arrays of the Hopf model.

    The global numpy random generator is re-seeded with `seed`; `a` is drawn
    first and `w` second, both uniformly inside their range.

    Args:
        n_samples (int): Number of parameter sets.
        nodes (int): Number of nodes per set.
        seed (int): Seed of the numpy random generator.
        a_range, w_range (sequence): (low, high) limits for a and w.
        g (float): Coupling constant, repeated for every sample.
        SC (np.array): Connectivity matrix, repeated for every sample.

    Returns:
        A, W (np.array): Arrays of shape (n_samples, nodes).
        G (np.array): Array of shape (n_samples,).
        C (np.array): SC stacked n_samples times along a new first axis.
    '''
    np.random.seed(seed=seed)
    shape = (n_samples, nodes)
    A = np.random.uniform(low=a_range[0], high=a_range[1], size=shape)
    W = np.random.uniform(low=w_range[0], high=w_range[1], size=shape)
    C = np.tile(SC[None,:], (n_samples,1,1))
    G = np.repeat(g, n_samples)
    return A, W, G, C

@jit(nopython=True)
def numba_noise(size):
    '''
    Draws `size` standard-normal samples one at a time, because numba does not
    support the 'size' argument of np.random.normal().
    '''
    samples = np.empty(size, dtype=np.float64)
    for idx in range(size):
        samples[idx] = np.random.normal()
    return samples

@jit(nopython=True)
def integrate_hopf_euler_maruyama(A, W, C, G, TR, t_use, t_max, init_min=-1, init_max=1, dt=0.5, sigma=0.01):
    '''
    Integrates the Hopf model using the Euler-Maruyama method for each provided subject.

    Args:
        A, W, C, G: Hopf parameters for each subject (first axis = subject).
        TR (float): Repetition time of the dataset.
        t_use (int): Number of timesteps to return.
        t_max (float): Last timestep.
        init_min (float): Minimum value for variable initialization (also the lower clamp).
        init_max (float): Maximum value for variable initialization (also the upper clamp).
        dt (float): Distance between timesteps.
        sigma (float): Controls the amount of noise.

    Returns:
        x_solution (np.array): Resulting x time series for each subject: the last
        t_use*TR/dt integration steps, taking one sample every TR/dt steps.

    Raises:
        ValueError: If NaN appears in the solution of a subject.
    '''
    n_samples, nodes = A.shape

    # Sample time
    t = np.arange(0.0, t_max*TR, dt)

    # Initialize array to store the results
    x_solution = np.empty((n_samples, nodes, len(t)))

    for n in range(n_samples):
        # Initial conditions for x and y for each node
        x0 = np.random.uniform(init_min, init_max, size=nodes)
        y0 = np.random.uniform(init_min, init_max, size=nodes)
        vars = np.concatenate((x0, y0), axis=0)

        a = A[n]
        w = W[n]
        c = C[n]
        g = G[n]

        # The loop below starts at step 1, so step 0 must be written explicitly.
        # Otherwise it keeps whatever np.empty left in memory, and the NaN check
        # further down could fail on that garbage.
        if len(t) > 0:
            x_solution[n, :, 0] = vars[:nodes]

        # Euler-Maruyama integration
        for i in range(1, len(t)):
            # Time (ode_hopf ignores it)
            t_span = (t[i - 1], t[i])
            # Derivatives
            d_vars = ode_hopf(t_span, vars, a, w, g, c)

            # Deterministic step plus scaled Gaussian noise
            vars += d_vars*dt + np.sqrt(dt)*sigma*numba_noise(size=2*nodes)

            # Clamping
            vars[vars > init_max] = init_max
            vars[vars < init_min] = init_min

            # Only save values for x
            x_solution[n, :, i] = vars[:nodes]

        if np.isnan(x_solution[n]).any():
            print(f'NaN found! n={n}')
            # A bare `raise` has no active exception to re-raise here
            raise ValueError('NaN found while integrating the Hopf model')

        if (n+1) % 100 == 0:
            print(f'Generated {n+1}/{n_samples} samples')

    # Delete initial unwanted timesteps
    x_solution = x_solution[:,:,-int((t_use*TR/dt)):]
    # Fix sampling rate
    x_solution = x_solution[:,:,::int(TR/dt)]

    return x_solution

In [None]:
# Simulate one time series per row: the predicted vector plays the role of `a`,
# and 100 extra timesteps are integrated before the t_use that are kept
X = integrate_hopf_euler_maruyama(
    np.stack(windows['pred'].values),
    np.stack(windows['W'].values),
    np.stack(windows['SC'].values),
    np.stack(windows['G'].values),
    TR,
    t_use,
    t_use+100,
)
windows['X'] = list(X)

In [None]:
# Drop unnecessary columns
# Rebinding instead of dropping in place, and ignoring missing columns, keeps
# this cell safe to run more than once
windows = windows.drop(columns=['SC', 'G', 'W'], errors='ignore')

In [None]:
from scipy import signal
from numba import jit

def demean(x, dim=0):
    """
    Removes the mean of `x` along one axis.

    The previous body called `np.mtlib.tile`, which does not exist in numpy,
    so every call raised AttributeError. The mean is now subtracted through
    broadcasting, which also works for arrays of any rank.

    Args:
        x (array-like): Input data.
        dim (int): Axis along which the mean is computed and removed.

    Returns:
        np.ndarray: Array with the same shape as `x` and zero mean along `dim`.
    """
    x = np.asarray(x)
    return x - np.mean(x, axis=dim, keepdims=True)

@jit(nopython=True)
def adif(a, b):
    """
    Distance between two angles: |a - b|, or 2*pi - |a - b| when that
    difference is larger than pi.
    """
    gap = np.abs(a - b)
    if gap > np.pi:
        return 2 * np.pi - gap
    return gap

@jit(nopython=True)
def numba_PIM(phases, N, Tmax, dFC, PhIntMatr, discardOffset=10):
    """
    Fills PhIntMatr with one symmetric N x N matrix per retained time point;
    entry (i, j) is the cosine of the angular distance between the phases of
    nodes i and j at that time point.

    Time points run from discardOffset to Tmax - discardOffset (1-based), so
    `discardOffset` samples are skipped at the end and discardOffset - 1 at
    the beginning. `dFC` is a scratch buffer that is overwritten on every step.
    """
    for t in range(discardOffset, Tmax - discardOffset + 1):
        sample = t - 1  # 1-based time point -> 0-based column
        for i in range(N):
            for j in range(i+1):
                coherence = np.cos(adif(phases[i, sample], phases[j, sample]))
                dFC[i, j] = coherence
                dFC[j, i] = coherence
        PhIntMatr[t - discardOffset] = dFC
    return PhIntMatr

def PhaseInteractionMatrix(ts, discardOffset=10):
    """
    Computes the Phase-Interaction Matrix of a (regions x time) signal.

    The instantaneous phase of every region is the angle of the Hilbert
    analytic signal of its mean-removed time series; numba_PIM then turns the
    phases into one N x N matrix per retained time point.

    `discardOffset` is now forwarded to numba_PIM. Before, the output buffer
    was sized with the caller's value while the loop always used the default
    of 10, so any other value produced a size mismatch. The mean is removed
    inline for all regions at once instead of through a per-row helper.

    Args:
        ts (np.ndarray): 2-D array with shape (N, Tmax).
        discardOffset (int): Number of samples ignored at the edges.

    Returns:
        np.ndarray: Array of shape (Tmax - 2*discardOffset + 1, N, N), or
        np.array([nan]) if `ts` contains NaN.
    """
    if np.isnan(ts).any():
        print('############ Warning!!! PhaseInteractionMatrix.from_fMRI: NAN found ############')
        return np.array([np.nan])

    (N, Tmax) = ts.shape
    npattmax = Tmax - (2 * discardOffset - 1)  # number of retained time points

    # Instantaneous phase of every region (Hilbert transform along time)
    centred = ts - np.mean(ts, axis=1, keepdims=True)
    phases = np.angle(signal.hilbert(centred, axis=1))

    # Scratch matrix and output buffer filled by the compiled kernel
    dFC = np.empty((N, N))
    PhIntMatr = np.empty((npattmax, N, N))
    return numba_PIM(phases, N, Tmax, dFC, PhIntMatr, discardOffset)

def tril_indices_column(N, k=0):
    """
    Indices of the lower triangle of an N x N matrix, listed column by column
    (the order Matlab uses), instead of numpy's row-by-row order.

    Args:
        N (int): Matrix size.
        k (int): Diagonal offset; k=0 includes the main diagonal, k=-1 excludes it.

    Returns:
        tuple: (row indices, column indices).
    """
    # The upper triangle walked row by row is the transposed lower triangle
    # walked column by column, so swapping the two index arrays does the job
    upper_rows, upper_cols = np.triu_indices(N, k=-k)
    return (upper_cols, upper_rows)

@jit(nopython=True)
def numba_phFCD(phIntMatr_upTri, size_kk3):
    """
    Cosine similarity between every pair of 3-step sums of the input rows.

    For each start t, rows t, t+1 and t+2 are summed; the output holds the
    cosine similarity of every pair (t, t2) with t < t2, in row-major order.

    The summed vectors and their norms are computed once per start and reused.
    Before, they were recomputed for every pair. The values are the same.

    Args:
        phIntMatr_upTri (np.ndarray): 2-D array, one flattened matrix per time point.
        size_kk3 (int): Length of the output vector.

    Returns:
        np.ndarray: Vector of length size_kk3 with the pairwise similarities.
    """
    npattmax = phIntMatr_upTri.shape[0]
    phfcd = np.zeros((size_kk3))
    n_starts = npattmax - 2
    if n_starts <= 0:
        return phfcd

    # 3-step sums and their norms, one per start
    win_sum = np.empty((n_starts, phIntMatr_upTri.shape[1]))
    win_norm = np.empty(n_starts)
    for t in range(n_starts):
        win_sum[t] = np.sum(phIntMatr_upTri[t:t + 3, :], axis=0)
        win_norm[t] = np.linalg.norm(win_sum[t])

    kk3 = 0
    for t in range(n_starts):
        for t2 in range(t + 1, n_starts):
            dot_product = np.dot(win_sum[t], win_sum[t2])
            phfcd[kk3] = dot_product / (win_norm[t] * win_norm[t2])
            kk3 += 1
    return phfcd

def phFCD(ts, discardOffset=10):
    """
    Computes the phase-based FCD vector of a (regions x time) signal.

    The Phase-Interaction Matrix is computed first; the strict lower triangle
    of each of its matrices is flattened (column-major order) and the
    flattened matrices are compared pairwise by numba_phFCD.

    `discardOffset` is now forwarded to PhaseInteractionMatrix. Before, it was
    only used to size the buffers here, so a non-default value made the sizes
    disagree with the matrix actually computed.

    Args:
        ts (np.ndarray): 2-D array with shape (N, Tmax).
        discardOffset (int): Number of samples ignored at the edges.

    Returns:
        np.ndarray: FCD vector, or np.array([nan]) if NaN was found.
    """
    phIntMatr = PhaseInteractionMatrix(ts, discardOffset)
    if np.isnan(phIntMatr).any():
        print('############ Warning!!! phFCD.from_fMRI: NAN found ############')
        return np.array([np.nan])

    (N, Tmax) = ts.shape
    npattmax = Tmax - (2 * discardOffset - 1)  # number of retained time points
    # Number of (t, t2) pairs; the product of two consecutive integers is always even
    size_kk3 = (npattmax - 3) * (npattmax - 2) // 2
    # Indices of the strict lower triangle
    Isubdiag = tril_indices_column(N, k=-1)
    # Flatten the lower triangle of every matrix in a single indexing step
    phIntMatr_upTri = np.ascontiguousarray(phIntMatr[:, Isubdiag[0], Isubdiag[1]])
    return numba_phFCD(phIntMatr_upTri, size_kk3)

In [None]:
from scipy import stats

def get_correlation(a, b, k=0):
    """
    Pearson correlation between the upper triangles of two square matrices.

    Args:
        a, b (np.ndarray): Matrices of the same size.
        k (int): Diagonal offset of the triangle. The default (0) includes the
            main diagonal, as before. Use k=1 to leave it out, e.g. when both
            inputs are correlation matrices whose diagonal is always 1.

    Returns:
        float: Correlation coefficient of the two flattened triangles.
    """
    # Same index set for both matrices, computed once
    upper = np.triu_indices(a.shape[-1], k=k)
    return np.corrcoef(a[upper], b[upper])[0,1]

def get_ks_distance(a, b):
    """
    Two-sample Kolmogorov-Smirnov statistic between the flattened values of
    `a` and `b` (the p-value is discarded).
    """
    outcome = stats.ks_2samp(a.flatten(), b.flatten())
    return outcome.statistic

In [None]:
def apply_corr(x):
    """
    Absolute correlation between the empirical and the simulated functional
    connectivity of one row.

    Both FC matrices are the row-wise correlation matrices of x['window'] and
    x['X']; NaN entries are replaced by the mean of the remaining entries
    before the two matrices are compared.
    """
    matrices = []
    for column in ('window', 'X'):
        fc = np.corrcoef(x[column])
        fc[np.isnan(fc)] = np.nanmean(fc)
        matrices.append(fc)
    fc_emp, fc_sim = matrices
    return np.abs(get_correlation(fc_emp, fc_sim))

# Compare every empirical window with its simulated counterpart
def _fcd_ks_distance(row):
    """KS distance between the phFCD of the empirical and the simulated series of a row."""
    return get_ks_distance(phFCD(row['window']), phFCD(row['X']))

windows['FC_correlation'] = windows.apply(apply_corr, axis=1)
windows['phFCD_distance'] = windows.apply(_fcd_ks_distance, axis=1)

In [None]:
# FC correlation of the 'mean' windows, grouped by cohort, shown as violins
mean_rows = windows[windows['window_id'] == 'mean']
FC = {cohort: list(values) for cohort, values in mean_rows.groupby('cohort')['FC_correlation']}
plot_comparison_across_labels(FC, x_label='Cohorts', y_label='FC correlation', title='', save_path=save_path+'FC.png')

In [None]:
# Summary statistics of the FC correlation, one row per cohort
fc_stats = [(np.mean(values), np.std(values), np.min(values), np.max(values)) for values in FC.values()]
pd.DataFrame(fc_stats, columns=['mean', 'std', 'min', 'max'], index=FC.keys())

In [None]:
# Create and plot violin plot for FCD
phFCD = windows[windows['window_id'] == 'mean'].groupby('cohort')['phFCD_distance'].apply(list).to_dict()
plot_comparison_across_labels(phFCD, x_label='Cohorts', y_label='KS distance FCD', title='', save_path=save_path+'phFCD.png')

In [None]:
# Explore FCD values
pd.DataFrame([(np.mean(phFCD[k]), np.std(phFCD[k]), np.min(phFCD[k]), np.max(phFCD[k])) for k in phFCD.keys()], columns=['mean', 'std', 'min', 'max'], index=phFCD.keys())