In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error
import tqdm
import scipy
from utils import rdms, load_params, read_subjects_data
import warnings
import importlib
from joblib import Parallel, delayed
from copy import deepcopy
# Silence pandas PerformanceWarning for the whole kernel session.
warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)

In [3]:
# Load the experiment definitions and both subject samples through the project's load_params helpers.
lottery_objs = load_params.load_lotteries()
set_dicts = load_params.load_set_dicts()
behavior_results = load_params.load_behavior_results()
# Set objects are built from the behavioral results; later cells use their
# .overlapping_with(...) (fold construction) and .decoy_effect (regression target).
set_objs = load_params.load_sets(behavior_results)
# Two samples of subjects with pre-defined-ROI RDMs: the first sample and the replication sample.
first_subjects_roi, replication_subjects_roi = load_params.load_samples()

In [4]:
# Outer CV folds ("leave-one-lottery-out"): for every set, hold out all sets that overlap
# with it and train on the remaining ones. Several held-out sets can yield the same split,
# so each distinct (train, test) split is recorded once, in order of first appearance.
# `folds` keeps positional indices into set_objs; `folds_sets` keeps the set objects.
folds = []
folds_sets = []
for set_out in set_objs:
    held_out_mask = [candidate.overlapping_with(set_out) for candidate in set_objs]
    split = (
        [j for j, is_test in enumerate(held_out_mask) if not is_test],
        [j for j, is_test in enumerate(held_out_mask) if is_test],
    )
    if split in folds:
        continue
    folds.append(split)
    folds_sets.append((
        [s for s, is_test in zip(set_objs, held_out_mask) if not is_test],
        [s for s, is_test in zip(set_objs, held_out_mask) if is_test],
    ))

In [5]:
# Nested cross-validation folds. For outer fold k, that fold's training sets are split again
# with the same overlap rule. The indices are positions *within* that fold's training list
# (train_model applies them to X.iloc[train_ind]), not positions in set_objs.
# Duplicate splits are kept once, in order of first appearance.
all_inner_folds = {}
for outer_i, outer_fold_sets in enumerate(folds_sets):
    outer_train_sets = outer_fold_sets[0]  # (train_sets, test_sets) -> training part
    inner_folds = []
    for inner_set_out in outer_train_sets:
        inner_train_ind = []
        inner_test_ind = []
        for pos, candidate in enumerate(outer_train_sets):
            target = inner_test_ind if candidate.overlapping_with(inner_set_out) else inner_train_ind
            target.append(pos)
        split = (inner_train_ind, inner_test_ind)
        if split not in inner_folds:
            inner_folds.append(split)
    all_inner_folds[outer_i] = inner_folds

In [6]:
def get_avg_rdms(rois, subjects, set_objs):
    '''Average the subjects' per-set RDM features for each ROI.

    For every ROI, each subject's RDM is normalised with ``rdms.normalize_RDM`` and
    turned into per-set features with ``rdms.get_set_RDMs_obj``. The per-subject frames
    are stacked and averaged across subjects, grouping rows on the first index level.

    Parameters
    ----------
    rois : iterable
        Keys into ``subject.RDM``.
    subjects : iterable
        Subject objects exposing an ``RDM`` mapping (roi -> RDM).
    set_objs : list
        Set objects forwarded to ``rdms.get_set_RDMs_obj``.

    Returns
    -------
    dict
        roi -> DataFrame of subject-averaged per-set features.
    '''
    avg_rdms = {}
    for roi in rois:
        subjects_set_rdms = []
        for subject in subjects:
            # Defensive copy in case normalize_RDM modifies its input in place.
            # The per-subject mean/std statistics are not needed here and are discarded.
            subject_norm_rdm, _, _ = rdms.normalize_RDM(subject.RDM[roi].copy(), return_stats=True)
            subjects_set_rdms.append(rdms.get_set_RDMs_obj(subject_norm_rdm, set_objs, roi))
        # Rows sharing the same first-level index label are averaged across subjects.
        avg_rdms[roi] = pd.concat(subjects_set_rdms, axis=0).groupby(level=0).mean()
    return avg_rdms

In [8]:
def create_X_from_subjects(subjects, sets=None):
    '''Build the feature matrix from subject-averaged per-set RDM features of every ROI.

    Parameters
    ----------
    subjects : sequence
        Subject objects exposing an ``RDM`` mapping. ROI names are taken from the
        first subject.
    sets : list, optional
        Set objects forwarded to ``get_avg_rdms``. Defaults to the notebook-level
        ``set_objs``, which was the implicit behaviour before this parameter existed.

    Returns
    -------
    pd.DataFrame
        The per-ROI averaged RDM frames concatenated column-wise, in ROI order.
    '''
    if sets is None:
        sets = set_objs
    rois = list(subjects[0].RDM.keys())
    if len(rois) < 100:
        # Heuristic: fewer than 100 ROIs means the pre-defined ROI set, whose last key
        # is whole_brain; it is excluded from the features.
        rois = rois[:-1]
    # Deliberately not named `rdms`, so the imported `rdms` module is not shadowed here.
    roi_avg_rdms = get_avg_rdms(rois, subjects, sets)
    return pd.concat(roi_avg_rdms.values(), axis=1)

### Pre-defined ROIs: results and permutation tests

In [9]:
# Pre-defined ROI names without the final key (whole_brain), which is not used as a feature.
predef_rois = [*first_subjects_roi[0].RDM.keys()][:-1]

In [10]:
# Feature matrices (subject-averaged RDM features of the pre-defined ROIs) for both samples.
first_X_roi, replication_X_roi = map(create_X_from_subjects, (first_subjects_roi, replication_subjects_roi))

In [11]:
# Regression target: the decoy effect of every set, one row per set in set_objs order.
y = pd.DataFrame([s.decoy_effect for s in set_objs])

In [12]:
def _select_alpha(X_train, y_train, regularizations, fold_splits, max_iter):
    '''Return the alpha with the lowest mean RMSE over `fold_splits` (first one on ties).'''
    inner_cv_rmse = np.zeros(len(regularizations))
    for reg_i, regularization in enumerate(regularizations):
        fold_rmses = []
        for inner_train_ind, inner_test_ind in fold_splits:
            lasso = Lasso(alpha=regularization, max_iter=max_iter)
            lasso.fit(X_train.iloc[inner_train_ind], y_train[inner_train_ind])
            pred = lasso.predict(X_train.iloc[inner_test_ind])
            fold_rmses.append(np.sqrt(mean_squared_error(y_train[inner_test_ind], pred)))
        inner_cv_rmse[reg_i] = sum(fold_rmses) / len(fold_splits)
    return regularizations[np.argmin(inner_cv_rmse)]


def train_model(X, y, regularizations, folds, return_models=False,
                inner_folds=None, max_iter=1_000_000):
    '''Nested cross-validated Lasso regression.

    For each outer fold, the Lasso ``alpha`` is picked from `regularizations` by the
    lowest mean RMSE over that fold's inner folds. A model with that alpha is then refit
    on the whole outer-training split and scored on the outer-test split (RMSE and
    Spearman correlation).

    Parameters
    ----------
    X : pd.DataFrame
        Feature matrix; rows are selected positionally (``.iloc``) by the fold indices.
    y : np.ndarray
        Target vector, fancy-indexed by the same positional indices
        (e.g. ``y.values.flatten()``).
    regularizations : sequence of float
        Candidate Lasso ``alpha`` values.
    folds : list of (train_ind, test_ind)
        Outer folds as lists of positional indices.
    return_models : bool
        If True, also return per-fold models, RMSEs, predictions and correlations.
    inner_folds : dict, optional
        Maps outer-fold position to a list of (inner_train, inner_test) index lists.
        These are positions *within* that outer fold's training split. Defaults to the
        notebook-level ``all_inner_folds``, which was built for the global ``folds``.
    max_iter : int
        Maximum coordinate-descent iterations for every Lasso fit. Inner and outer fits
        now share one limit.

    Returns
    -------
    float or tuple
        Mean outer-fold RMSE. With ``return_models=True``, the tuple
        ``(mean_cv_rmse, models, cv_rmse, preds, corrs)``. ``preds`` has shape
        ``(n_folds, len(y))`` and is NaN wherever a sample was not in that fold's test split.
        A fold whose Spearman correlation is NaN (e.g. constant predictions) gets 0 in
        ``corrs``.
    '''
    if inner_folds is None:
        inner_folds = all_inner_folds
    k_folds = len(folds)
    cv_rmse = np.zeros(k_folds)
    corrs = np.zeros(k_folds)
    if return_models:
        models = np.zeros(k_folds, dtype=object)
        preds = np.full((k_folds, len(y)), np.nan)
    # Scoped, so calling this function no longer changes the caller's warning filters.
    # It also applies inside joblib worker processes.
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=scipy.stats.ConstantInputWarning)
        for fold_i, (train_ind, test_ind) in enumerate(folds):
            X_train, y_train = X.iloc[train_ind], y[train_ind]
            X_test, y_test = X.iloc[test_ind], y[test_ind]
            best_reg = _select_alpha(X_train, y_train, regularizations, inner_folds[fold_i], max_iter)
            lasso = Lasso(alpha=best_reg, max_iter=max_iter)
            lasso.fit(X_train, y_train)
            pred = lasso.predict(X_test)
            cv_rmse[fold_i] = np.sqrt(mean_squared_error(y_test, pred))
            corr = scipy.stats.spearmanr(y_test, pred)[0]
            corrs[fold_i] = 0 if np.isnan(corr) else corr
            if return_models:
                models[fold_i] = lasso
                preds[fold_i, test_ind] = pred.flatten()
    mean_cv_rmse = np.mean(cv_rmse)
    if return_models:
        return mean_cv_rmse, models, cv_rmse, preds, corrs
    return mean_cv_rmse

In [13]:
# Candidate Lasso regularization strengths: 25 values log-spaced from 1e-3 to 1.
alphas = np.logspace(start=-3, stop=0, num=25)

In [14]:
# Baseline features read from the attributes CSV (first column is the index).
# Rows must line up one-to-one, in order, with set_objs, because train_model selects
# rows positionally (.iloc) with indices into set_objs.
attributes_X = pd.read_csv('../../data/attributes.csv', index_col=0)
assert len(attributes_X) == len(set_objs), 'attributes.csv must have exactly one row per set'
# Relabel rows 1..n; the length is derived from the data instead of being hardcoded.
attributes_X = attributes_X.set_index(np.arange(1, len(attributes_X) + 1))

In [15]:
# Nested-CV Lasso on the attribute features (non-neural baseline).
(attributes_result, attributes_models, attributes_rmses,
 attributes_preds, attributes_corrs) = train_model(
    attributes_X, y.to_numpy().flatten(), alphas, folds, return_models=True)

In [16]:
# Attribute baseline: mean outer-fold RMSE and mean per-fold Spearman correlation (NaN folds counted as 0).
print(f'RMSE: {attributes_result:.4f}, mean correlation: {np.mean(attributes_corrs):.4f}')

RMSE: 0.0764, mean correlation: 0.3142


In [17]:
# Nested-CV Lasso on the pre-defined-ROI RDM features of the first sample.
(first_roi_result, first_roi_models, first_roi_rmses,
 first_roi_preds, first_roi_corrs) = train_model(
    first_X_roi, y.to_numpy().flatten(), alphas, folds, return_models=True)

In [18]:
# First sample, pre-defined ROIs: mean outer-fold RMSE and mean per-fold Spearman correlation.
print(f'RMSE: {first_roi_result:.4f}, mean correlation: {np.mean(first_roi_corrs):.4f}')

RMSE: 0.0656, mean correlation: 0.4730


In [19]:
# Nested-CV Lasso on the pre-defined-ROI RDM features of the replication sample.
(replication_roi_result, replication_roi_models, replication_roi_rmses,
 replication_roi_preds, replication_roi_corrs) = train_model(
    replication_X_roi, y.to_numpy().flatten(), alphas, folds, return_models=True)

In [20]:
# Replication sample, pre-defined ROIs: mean outer-fold RMSE and mean per-fold Spearman correlation.
print(f'RMSE: {replication_roi_result:.4f}, mean correlation: {np.mean(replication_roi_corrs):.4f}')

RMSE: 0.0659, mean correlation: 0.5152


#### Permutation test (shuffled decoy effects)

In [26]:
# Null distribution for the label-permutation test. Each row of permutation_ys is an
# independent random reordering of the decoy effects in y. The generator is seeded, so
# the permutations (and any p-value computed from them) are reproducible on re-run.
PERMUTATION_SEED = 0
permutations = 10_000
permutation_rmse = np.zeros(permutations)
y_flat = y.values.flatten()
permutation_rng = np.random.default_rng(PERMUTATION_SEED)
permutation_ys = np.zeros((permutations, len(y)))
for perm_i in tqdm.tqdm(range(permutations), total=permutations):
    permutation_ys[perm_i, :] = permutation_rng.permutation(y_flat)

100%|██████████| 10000/10000 [00:00<00:00, 13969.73it/s]


In [27]:
# attributes_permutation_results = Parallel(n_jobs=-1)(  delayed(train_model)(attributes_X, perm_y, alphas, folds) 
#                                             for perm_y in tqdm.tqdm(permutation_ys))

In [28]:
# replication_permutation_results = Parallel(n_jobs=-1)(  delayed(train_model)(replication_X_roi, perm_y, alphas, folds) 
#                                             for perm_y in tqdm.tqdm(permutation_ys))

### Baseline performance: 8 random Schaefer parcels

In [50]:
# Schaefer-parcellation samples. ROI names come from the first subject's RDM keys, and the
# feature matrices of both samples are read from previously saved CSVs.
first_subjects_schaefer, replication_subjects_schaefer = load_params.load_samples(roi_type='schaefer')
schaefer_rois = [*first_subjects_schaefer[0].RDM.keys()]
first_X_schaefer, replication_X_schaefer = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('../../results/first_results/cv/first_X_schaefer.csv',
                     '../../results/replication_results/cv/replication_X_schaefer.csv')
)

In [51]:
def extract_roi_columns(rois, X1, X2=None):
    '''Select the pairwise-distance feature columns that belong to `rois`.

    Every ROI contributes three columns: ``<roi>_Target_Decoy``,
    ``<roi>_Target_Competitor`` and ``<roi>_Competitor_Decoy``. The selection is
    returned in lexicographically sorted column order.

    Returns the sub-frame of X1, or an ``(X1_sub, X2_sub)`` tuple when X2 is given.
    Missing columns raise a KeyError from ``.loc``.
    '''
    pair_suffixes = ('_Target_Decoy', '_Target_Competitor', '_Competitor_Decoy')
    wanted = sorted(roi + suffix for suffix in pair_suffixes for roi in rois)
    if X2 is None:
        return X1.loc[:, wanted]
    return X1.loc[:, wanted], X2.loc[:, wanted]

In [None]:
# Baseline null: draw random subsets of Schaefer parcels, as many as there are pre-defined
# ROIs, and keep the matching feature columns for both samples. The generator is seeded,
# so the sampled subsets are reproducible on re-run.
ROI_SAMPLING_SEED = 1
permutations = 10_000
n_rois = len(predef_rois)
permutation_rmse = np.zeros(permutations)
permutation_first_X = np.zeros(permutations, dtype=object)
permutation_replication_X = np.zeros(permutations, dtype=object)
roi_rng = np.random.default_rng(ROI_SAMPLING_SEED)
for perm_i in tqdm.tqdm(range(permutations), total=permutations):
    # Without replacement: a subset never contains the same parcel twice.
    random_rois = roi_rng.choice(schaefer_rois, n_rois, replace=False)
    first_perm_X, replication_perm_X = extract_roi_columns(random_rois, first_X_schaefer, replication_X_schaefer)
    permutation_first_X[perm_i] = first_perm_X
    permutation_replication_X[perm_i] = replication_perm_X

100%|██████████| 10000/10000 [00:07<00:00, 1267.19it/s]


In [None]:
# Ignore scipy's ConstantInputWarning (emitted by spearmanr when a fold's predictions are
# constant) for the rest of this kernel session. A single filter rule is sufficient.
warnings.simplefilter(action='ignore', category=scipy.stats.ConstantInputWarning)

In [None]:
# Random-parcel baseline, first sample: nested-CV RMSE for every sampled ROI subset.
# Commented out because it is expensive (one full nested CV per subset). y is flattened
# to 1-D, matching the other train_model calls in this notebook.
# first_permutation_results = Parallel(n_jobs=-1)(delayed(train_model)(first_perm_X, y.values.flatten(), alphas, folds)
#                                                 for first_perm_X in tqdm.tqdm(permutation_first_X))

In [None]:
# Random-parcel baseline, replication sample: nested-CV RMSE for every sampled ROI subset.
# Commented out because it is expensive (one full nested CV per subset). y is flattened
# to 1-D, matching the other train_model calls in this notebook.
# NOTE: the y-permutation cell for the replication sample assigns the same name
# `replication_permutation_results`; if both cells are enabled, this one overwrites it.
# replication_permutation_results = Parallel(n_jobs=-1)(delayed(train_model)(replication_perm_X, y.values.flatten(), alphas, folds)
#                                                       for replication_perm_X in tqdm.tqdm(permutation_replication_X))

100%|██████████| 10000/10000 [2:25:56<00:00,  1.14it/s] 
