In [None]:
import os

# Base folder of the project. Every path below is built by plain string
# concatenation, so `root` must either be empty or end with a path separator.
root = '' # change as needed
data_root = root + 'data/'
save_path = root + 'results/inference/'
model_path = root + 'model' # path to the model reported in the paper (downloaded from Zenodo), you can change it to your own model
sc_file = data_root + 'SC_dbs80HARDIFULL.mat'
dbs_path = root + 'ds80_labels.csv'
yeo_path = root + 'dbs802Yeo7.csv'
# Create the output folder (with any missing parents) in pure Python instead of
# shelling out to `mkdir -p`, which is not available on every platform.
os.makedirs(save_path, exist_ok=True)

# 1) Read

In [None]:
import mat73
import numpy as np
import pandas as pd

In [None]:
def get_ts_hcp_task(data_root):
    '''
    Load the HCP task (and rest) recordings from their .mat files.

    Args:
        data_root (str): Prefix of the .mat files. It is concatenated directly with
            each file name, so it must end with a path separator.

    Returns:
        time_series (dict): Maps each condition name ('memory', 'gambling', 'motor',
            'language', 'social', 'relational', 'emotion', 'rest') to whatever
            mat73.loadmat returns for the corresponding file.
    '''
    # Condition name -> file name, in the order the files are loaded
    files_by_task = {
        'memory': 'hcp1003_WM_LR_dbs80.mat',
        'gambling': 'hcp1003_GAMBLING_LR_dbs80.mat',
        'motor': 'hcp1003_MOTOR_LR_dbs80.mat',
        'language': 'hcp1003_LANGUAGE_LR_dbs80.mat',
        'social': 'hcp1003_SOCIAL_LR_dbs80.mat',
        'relational': 'hcp1003_RELATIONAL_LR_dbs80.mat',
        'emotion': 'hcp1003_EMOTION_LR_dbs80.mat',
        'rest': 'hcp1003_REST1_LR_dbs80.mat',
    }
    return {task: mat73.loadmat(data_root + filename) for task, filename in files_by_task.items()}

# Load every condition and keep only the 'dbs80ts' array of each subject entry.
# Entries of 'subject' that are not plain dicts are dropped.
data_ts = {
    task: [entry['dbs80ts'] for entry in content['subject'] if type(entry) is dict]
    for task, content in get_ts_hcp_task(data_root).items()
}

In [None]:
# Create time series dataframe: one row per (condition, recording).
# 'subject_id' is the position of the recording inside its condition's list.
# A comprehension is used so that no loop variable (in particular nothing named
# `id`, which would shadow the builtin) ends up in the notebook's global namespace.
rows = [
    {'cohort': cohort.capitalize(), 'bold': bold, 'subject_id': subject_id}
    for cohort, recordings in data_ts.items()
    for subject_id, bold in enumerate(recordings)
]
# Free the raw recordings; the loading cell must be re-run before this cell can be re-run
del data_ts
time_series = pd.DataFrame(rows)
time_series

# 2) Signal filtering

In [None]:
from scipy.signal import butter, detrend, filtfilt

def demean(x, dim=0):
    """Subtract from `x` its mean taken along axis `dim`.

    Args:
        x (array-like): Input data of any dimensionality >= 1.
        dim (int): Axis along which the mean is computed and removed.

    Returns:
        np.ndarray: Array with the same shape as `x` whose mean along `dim` is zero.
    """
    x = np.asarray(x)
    # keepdims makes the mean broadcast against x for any number of dimensions
    # (tiling the mean x.size times is only shape-compatible for 1-D input)
    return x - np.mean(x, axis=dim, keepdims=True)

def BandPassFilter(boldSignal, f_low, f_high, TR, k, removeStrongArtefacts=True):
    """Band-pass filter every row of a 2-D signal with the same Butterworth filter.

    Args:
        boldSignal (np.ndarray): 2-D array; each row (region) is filtered independently
            along the second axis.
        f_low (float): Lower edge of the pass band, in the frequency unit reciprocal to TR.
        f_high (float): Upper edge of the pass band, same unit as f_low.
        TR (float): Sampling period; the Nyquist frequency is 1/(2*TR).
        k (int): Order argument given to scipy.signal.butter.
        removeStrongArtefacts (bool): If True, clip each detrended row at +/- 3 standard
            deviations before filtering.

    Returns:
        np.ndarray: Array with the shape of boldSignal. A row that contains NaN is not
        filtered: its first sample is set to NaN and the remaining samples stay 0.
    """
    n_regions, _ = boldSignal.shape
    nyquist = 1./(2.*TR)
    # Band edges expressed as fractions of the Nyquist frequency, as butter expects
    band = [f_low/nyquist, f_high/nyquist]
    bfilt, afilt = butter(k, band, btype='band', analog=False)
    # padlen chosen to get the same result as in Matlab
    pad = 3*(max(len(bfilt), len(afilt))-1)

    signal_filt = np.zeros(boldSignal.shape)
    for seed in range(n_regions):
        region = boldSignal[seed, :]
        if np.isnan(region).any():
            # Problematic region: warn and mark it through a NaN in its first sample
            print(f'############ Warning!!! BandPassFilter: NAN found at region {seed} ############')
            signal_filt[seed,0] = np.nan
            continue

        # detrend and demean both return new arrays, so the clipping below never touches the input
        ts = demean(detrend(region))
        if removeStrongArtefacts:
            upper = 3.*np.std(ts)
            ts[ts > upper] = upper
            # The standard deviation is deliberately taken again, after the upper clipping
            lower = -3.*np.std(ts)
            ts[ts < lower] = lower

        # Forward-backward filtering
        signal_filt[seed,:] = filtfilt(bfilt, afilt, ts, padlen=pad)
    return signal_filt

def AmplitudeFilter(time_series):
    """Divide `time_series` by its largest absolute value, so the result lies in [-1, 1].

    The peak is taken over the whole array. A NaN anywhere in the input makes the
    peak NaN and therefore the whole result NaN.
    """
    peak = np.abs(time_series).max()
    return time_series / peak

In [None]:
# Band-pass filter settings
# NOTE(review): confirm that TR = 2.0 matches the acquisition TR of these recordings
TR = 2.0         # sampling period given to BandPassFilter (Nyquist frequency = 1/(2*TR))
k = 2            # order argument of the Butterworth design
f_low = 0.008    # lower edge of the pass band
f_high = 0.08    # upper edge of the pass band

# Band-pass filter every recording, then rescale each one by its peak absolute value.
# The 'bold' column is overwritten, so re-running this cell filters the data again.
bold_filtered = time_series['bold'].apply(BandPassFilter, args=(f_low, f_high, TR, k))
time_series.loc[:,'bold'] = bold_filtered.apply(AmplitudeFilter)

# 3) Windowing

In [None]:
# Window length, in time points (passed as `size` to create_windows)
t_use = 50
# Step between consecutive window starts; equal to t_use, so windows do not overlap
stride = 50

def create_windows(ts, size=1, stride=1):
    """Cut a 2-D array into windows along its second axis.

    Windows are anchored at the END of the series: the last window always finishes
    on the final time point, and leading samples that do not fill a whole window
    are dropped. Windows are returned in chronological order.

    Args:
        ts (np.ndarray): 2-D array; windows are taken along axis 1.
        size (int): Number of time points per window.
        stride (int): Distance between the starts of consecutive windows.

    Returns:
        np.ndarray: Array of shape (n_windows, ts.shape[0], size).

    Raises:
        ValueError: From np.stack when no window fits (fewer than `size` time points).
    """
    _, n_timepoints = ts.shape
    # Offsets are generated backwards from the last possible start, then reversed
    starts = range(n_timepoints - size, -1, -stride)[::-1]
    return np.stack([ts[:, start:start + size] for start in starts], axis=0)

# Slice every recording into windows and collect one record per window,
# plus one extra record holding the element-wise average of that recording's windows
all_windows = []
for _, row in time_series.iterrows():
    subject_windows = create_windows(row['bold'], t_use, stride)
    base = {'cohort': row['cohort'], 'subject_id': row['subject_id']}
    all_windows.extend(
        {**base, 'window_id': i, 'window': window}
        for i, window in enumerate(subject_windows)
    )
    # The averaged window is tagged with the string 'mean' instead of an integer id
    all_windows.append({**base, 'window_id': 'mean', 'window': subject_windows.mean(axis=0)})
del time_series

# Turn the collected records into a dataframe and free the intermediate list
windows = pd.DataFrame(all_windows)
del all_windows
windows

In [None]:
def get_filename(row):
    """Build the image file stem for one row of the windows table.

    Rows whose 'window_id' is 'mean' get '<cohort>_<subject_id>'; every other row
    gets '<cohort>_<subject_id>_<window_id>'.
    """
    stem = f"{row['cohort']}_{row['subject_id']}"
    if row['window_id'] == 'mean':
        return stem
    return f"{stem}_{row['window_id']}"

# Add filenames to the df: get_filename is called once per row (axis=1) and its
# return values become the new 'filename' column
windows['filename'] = windows.apply(get_filename, axis=1)

# 4) Time-series-to-image conversion

In [None]:
from matplotlib import pyplot as plt

# Figure size in inches is the array shape divided by this value.
# NOTE(review): presumably tuned so the saved PNG has the pixel size the model expects - confirm
ratio = 77
def plot_ts(ts, path):
    """Render a 2-D array as a borderless image and save it to `path`.

    Args:
        ts (np.ndarray): 2-D array; rows are drawn top to bottom, columns left to right.
        path (str): Destination file given to Figure.savefig.

    The colour scale is fixed to [-1, 1] (viridis), axes are hidden, and the file is
    written with a tight bounding box and no padding. The figure is always closed,
    also when drawing or saving fails, so figures do not pile up during a long export loop.
    """
    N, t = ts.shape
    fig = plt.figure(figsize=(t/ratio, N/ratio))
    try:
        # Work on this figure explicitly rather than on pyplot's "current figure"
        ax = fig.add_subplot(111)
        ax.imshow(ts, aspect='auto', cmap='viridis', vmin=-1, vmax=1)
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

In [None]:
import tqdm, os

# Folder that receives one PNG per row of `windows`.
# makedirs(exist_ok=True) also creates missing parents and does not fail when the
# folder already exists, so no separate existence check is needed.
img_dir = os.path.join(save_path, 'img')
os.makedirs(img_dir, exist_ok=True)

# Create and save images
for _, row in tqdm.tqdm(windows.iterrows(), total=len(windows)):
    plot_ts(row['window'], os.path.join(img_dir, f"{row['filename']}.png"))

# 5) Inference

In [None]:
from fastai.vision.all import *

# Bounds of the predicted parameter
a_min, a_max = -1, 1

def rmse_a(inp, targ):
    """RMSE between `inp` and `targ`, expressed as a percentage of the range a_max - a_min."""
    span = a_max - a_min
    return rmse(inp, targ) * 100 / span

In [None]:
# Get inputs for the model: one PNG path per row of `windows`, in row order
image_list = [os.path.join(save_path, 'img', f'{name}.png') for name in windows['filename']]
# Create dummy dataloader (only needed to build the learner; the real inputs go through test_dl below)
dls = ImageDataLoaders.from_path_func('', [0], lambda x: '0', bs=16, item_tfms=Resize((80, 50), method='squish'))
# Load model. loss_func has to be a loss instance, so MSELossFlat is called here
learn = vision_learner(dls, 'convnext_tiny_in22k', n_out=80, y_range=(-1,1), loss_func=MSELossFlat()).to_fp16()
# SECURITY NOTE(review): weights_only=False lets the checkpoint loader unpickle arbitrary
# objects; only load checkpoints that come from a trusted source
learn.load(model_path, device='cuda', weights_only=False)
# Predict
test_dl = learn.dls.test_dl(image_list, device='cuda')
preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Add predictions to the dataframe: each row of `preds` becomes one cell of the 'pred' column
windows['pred'] = list(preds.numpy())

# 6) Yeo networks

In [None]:
# Names of the 80 ROIs, taken from the 'Rois' column of the ';'-separated labels file
dbs_names = pd.read_csv(dbs_path, sep=';')['Rois'].values
# Network ids associated with the regions: first row of the header-less Yeo file
yeo_ids = pd.read_csv(yeo_path, header=None).iloc[0].values
# Network names, looked up by id; id 0 has no name
yeo_names = np.array([None, 'Visual', 'Somatomotor', 'Dorsal attention', 'Ventral attention', 'Limbic', 'Frontoparietal', 'Default'])
# Filter non cortical regions: keep positions 0-30 and 49-79
cortical_idx = np.r_[0:31, 49:80]
dbs_names_cortex = dbs_names[cortical_idx]

In [None]:
from collections import defaultdict

# Maps each task to a one-element list holding its mean cortical activation vector
tasks_act = defaultdict(list)
# Rows that hold the prediction for a recording's averaged window
is_mean_window = windows['window_id'] == 'mean'
for c in np.unique(windows.cohort.values):
    # Filter windows by task. Both conditions are combined into ONE mask over `windows`;
    # indexing an already-filtered frame with a full-length mask makes pandas reindex
    # the mask and emit a warning.
    selection = np.stack(windows.loc[(windows['cohort'] == c) & is_mean_window, 'pred'].values)
    # Get the mean value for each ROI in a particular task
    activations = selection.mean(axis=0)
    # Filter non cortical regions
    activations_cortex = activations[np.r_[0:31, 49:80]]
    tasks_act[c].append(activations_cortex)

In [None]:
# One column per task compared against rest; network rows are added on first assignment
comp = pd.DataFrame(columns=['Memory', 'Language', 'Gambling', 'Motor', 'Relational', 'Social', 'Emotion'])

for c in tasks_act.keys():
    if c == 'Rest':
        continue
    # Per-ROI difference between this task and the resting state
    activations_cortex = tasks_act[c][0] - tasks_act['Rest'][0]
    # Group the differences by the network each ROI is assigned to
    net_act = defaultdict(list)
    for net, act in zip(yeo_names[yeo_ids], activations_cortex):
        # NOTE(review): predictions are bounded by y_range=(-1, 1), so a difference cannot
        # reach -20; this test can only reject NaN. Presumably a leftover threshold - confirm
        if act > -20:
            net_act[net].append(act)
    # Each cell holds [mean, max] of the differences inside one network
    for net, acts in net_act.items():
        comp.loc[net, c] = [np.mean(acts), np.max(acts)]

In [None]:
import seaborn as sns

# Flatten `comp` into long format: one record per (task column, network row) cell.
# Only the first entry of each cell (the mean) is kept as the y value.
records = [
    {'x': col, 'y': values[0], 'network': index}
    for col in comp.columns
    for index, values in comp[col].items()
]

# Convert the records into a new df
df_plot = pd.DataFrame(records, columns=['x', 'y', 'network'])

# Create scatter plot: one marker per (task, network), tasks on the x axis,
# coloured by network; jitter is disabled so markers of a task share one vertical line
plt.figure(figsize=(10, 6))
sns.stripplot(x='x', y='y', hue='network', data=df_plot, palette='deep', jitter=False, size=9, edgecolor='black', linewidth=0.5, alpha=0.7)

# Plot formatting (all calls below act on pyplot's current figure/axes, i.e. the one created above)
plt.xlabel('Tasks', fontsize=14)
plt.ylabel(r'Difference of mean bifurcation parameters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.grid(True, which='major', linestyle='--', alpha=0.6)
# Legend is anchored outside the axes, to the right of the plot area
plt.legend(title='Networks', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, title_fontsize=14)
# tight_layout runs before saving so the outside legend is taken into account
plt.tight_layout()
plt.savefig(save_path+'networks.png', dpi=800)

# 7) Task separation

In [None]:
import seaborn as sns
from itertools import combinations
from statannotations.Annotator import Annotator

def pad_dicts(tests):
    """Ensure all sequences of a dictionary have the same length by padding with NaNs.

    Args:
        tests (dict): Maps each label to a 1-D sequence of numbers.

    Returns:
        dict: Same keys, in the same order; every value is a float array whose length
        is that of the longest input sequence, right-padded with NaN. An empty input
        gives an empty dict.
    """
    if not tests:
        return {}
    # NaN only exists for floating-point dtypes, so convert first; padding an
    # integer array with NaN is not possible.
    arrays = {key: np.asarray(values, dtype=float) for key, values in tests.items()}
    max_len = max(len(arr) for arr in arrays.values())
    return {key: np.pad(arr, (0, max_len - len(arr)), 'constant', constant_values=np.nan)
            for key, arr in arrays.items()}

def plot_comparison_across_labels(tests, method="BH", title='', x_label='', y_label='', fig_size=(12, 8), save_path='./'):
	"""Compare the value distributions of several labels and save a two-panel figure.

	The lower panel is a violin plot with one violin per label; the upper panel is a
	table with the significance stars of every pairwise comparison.

	Args:
		tests (dict): Maps each label to a sequence of values; sequences are padded
			to a common length with NaN by pad_dicts.
		method (str): Passed as `comparisons_correction` to the annotator.
		title (str): Figure title.
		x_label (str): x-axis label of the violin plot.
		y_label (str): y-axis label of the violin plot.
		fig_size (tuple): Figure size handed to plt.subplots.
		save_path (str): Passed as-is to plt.savefig.
			NOTE(review): the default './' is a directory, not a file name; the
			notebook always passes an explicit file path - confirm the default is unused.

	Side effects: changes the global seaborn context to "talk", prints the p-value
	legend and every comparison, and leaves the figure open.
	"""

	# Prepare data: equal-length columns (one per label) and every unordered pair of labels
	tests = pad_dicts(tests)
	use_labels = list(tests.keys())
	pairs = list(combinations(use_labels, 2))
	df = pd.DataFrame(tests, columns=use_labels)

	# Set up figure and axes: table panel on top, violin panel below (height ratio 1.2 : 3)
	fig, (ax_table, ax_violin) = plt.subplots(2, 1, figsize=fig_size, gridspec_kw={'height_ratios': [1.2, 3]})
	sns.set_context("talk")

	# Perform statistical comparisons and annotate
	# NOTE(review): only apply_test() is called; nothing here asks the annotator to draw
	# on ax_violin, so presumably it is used for the statistics only - confirm
	annotator = Annotator(ax=ax_violin, data=df, pairs=pairs, order=use_labels)
	annotator.configure(test='Mann-Whitney', text_format='star', verbose=True)
	annotator.configure(comparisons_correction=method, correction_format="replace")
	results = annotator.apply_test()

	# Prints
	annotator.print_pvalue_legend()
	for a in results.annotations:
		a.print_labels_and_content()

	# Create violin plot
	sns.violinplot(data=df, order=use_labels, ax=ax_violin, palette='deep', linewidth=1.2)
	ax_violin.set_xlabel(x_label, fontsize=14)
	ax_violin.set_ylabel(y_label, fontsize=14)
	ax_violin.set_xticklabels(ax_violin.get_xticklabels(), fontsize=12)

	# Grid and despine for cleaner look
	sns.despine()
	ax_violin.grid(True, which='major', linestyle='--', alpha=0.6)

	# Create a significance stars table: rows are all labels but the first, columns all
	# labels but the last; the cell (second label, first label) of each pair gets its stars
	stars_matrix = pd.DataFrame("", index=df.columns[1:], columns=df.columns[:-1])
	for result in results.annotations:
		c1, c2 = [str(struct["label"]) for struct in result.structs]
		stars_matrix.loc[c2, c1] = result.text  # Insert stars

	# Plot the table above the violin plot, 0.7 wide and centred on the horizontal
	# midpoint of the violin axes (position read before tight_layout is applied)
	bbox_v = ax_violin.get_position()
	new_left = (bbox_v.x0 + bbox_v.x1) / 2 - 0.7 / 2
	table = pd.plotting.table(ax_table, stars_matrix, loc='center', cellLoc='center', fontsize=15, bbox=[new_left,0,0.7,1])
	ax_table.axis('off')

	# Title
	fig.suptitle(title, fontsize=18, y=0.93)

	# Save plot with high resolution
	plt.tight_layout()
	plt.savefig(save_path, dpi=800)

In [None]:
# Per cohort, one value per recording: the mean over all entries of the prediction
# made for that recording's averaged window
mean_rows = windows[windows['window_id'] == 'mean']
result = {
    cohort: np.mean(np.vstack(preds), axis=1).tolist()
    for cohort, preds in mean_rows.groupby('cohort')['pred']
}
# Create and plot violin plot
plot_comparison_across_labels(result, x_label='Cohorts', y_label=r'Mean bifurcation parameters $a$', title='', save_path=save_path+'tasks.png')