In [11]:
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr
from tqdm import tqdm
import seaborn as sns
from itertools import combinations
from src.mri import gen_mask
from src.rsa import fit_and_predict
import nibabel as nib
import warnings
from joblib import Parallel, delayed

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import LeaveOneGroupOut

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Run switches ---------------------------------------------------------
subj = 'sub-01'        # subject label embedded in every input/output file name
process = 'fMRI_RSA'   # names the figure and interim-output subdirectories
decoding = True        # True: pairwise decoding; False: correlation distance
calculate_rdm = True   # True: recompute the neural RDM; False: reload the saved CSV

# --- Locations ------------------------------------------------------------
# NOTE(review): machine-specific absolute path; must be edited to run elsewhere.
top_path = '/Users/emcmaho7/Dropbox/projects/SI_EEG/SIEEG_analysis'
data_dir = f'{top_path}/data'
figure_path = f'{top_path}/reports/figures/{process}'
out_path = f'{top_path}/data/interim/{process}'
# Glob pattern: the `*` leaves the `desc-` label open so one pattern can match
# several beta files for this subject.
preproc_files = f'{data_dir}/raw/fmri_betas/{subj}_space-T1w_desc-*-fracridge-all-data.nii.gz'

# Make sure both output locations exist before anything tries to write there.
for output_dir in (figure_path, out_path):
    Path(output_dir).mkdir(parents=True, exist_ok=True)

In [3]:
# Stack the test annotations on top of the train annotations with a fresh
# 0..n-1 index, then order the rows alphabetically by video name.
test_videos = pd.read_csv(f'{data_dir}/raw/annotations/test.csv')
train_videos = pd.read_csv(f'{data_dir}/raw/annotations/train.csv')
df = pd.concat([test_videos, train_videos], ignore_index=True).sort_values(by='video_name')

# After sorting, the surviving index labels are each row's position in the
# unsorted (test-then-train) stack, i.e. the permutation that sorts by name.
sort_idx = df.index.to_numpy()
videos = df['video_name'].to_numpy()

feature_rdms = pd.read_csv(f'{data_dir}/interim/FeatureRDMs/feature_rdms.csv')

In [4]:
# ROI labels. They are substituted into the localizer mask file names and later
# serve as the ordered categories (and panel order) for the RSA results.
rois = [
    'EVC', 'MT', 'EBA',
    'LOC', 'FFA', 'PPA',
    'pSTS', 'face-pSTS', 'aSTS',
]

# Feature names in display order; used as ordered categories and x tick labels.
feature_order = [
    'alexnet', 'moten',
    'indoor', 'expanse', 'object_directedness',
    'agent_distance', 'facingness',
    'joint_action', 'communication',
    'valence', 'arousal',
]

## Pairwise decoding

In [9]:
# Pick the CSV the neural RDM is written to / read from, and for the decoding
# analysis also build the cross-validation objects.
if not decoding:
    out_name = f'{out_path}/{subj}_correlation-distance.csv'
else:
    n_groups = 5
    # Group labels 0..n_groups-1, repeated twice back to back.
    groups = np.tile(np.arange(n_groups), 2)
    logo = LeaveOneGroupOut()
    pipe = Pipeline([('scale', StandardScaler()), ('lr', LogisticRegression())])
    # NOTE(review): in the cells visible here only `n_groups` is passed on (to
    # fit_and_predict); `groups`, `logo` and `pipe` are not referenced again.
    out_name = f'{out_path}/{subj}_pairwise-decoding.csv'

print(out_name)

/Users/emcmaho7/Dropbox/projects/SI_EEG/SIEEG_analysis/data/interim/fMRI_RSA/sub-01_pairwise-decoding.csv


In [12]:
# Build the neural RDM (one row per ROI and video pair) and cache it as CSV,
# or reload the cached CSV when `calculate_rdm` is False.
if calculate_rdm:
    beta_files = sorted(glob(preproc_files))
    if not beta_files:
        # Without this check an empty match only surfaces later as np.hstack's
        # generic "need at least one array to concatenate" error.
        raise FileNotFoundError(f'No beta files found matching {preproc_files}')

    betas = []
    for file in beta_files:
        beta_img = nib.load(file)
        # Flatten every axis before the last two into a single leading axis.
        arr = beta_img.get_fdata().reshape((-1, beta_img.shape[-2], beta_img.shape[-1]))
        if arr.shape[-1] > 10:
            # More than 10 entries on the last axis: average consecutive pairs,
            # halving that axis.
            # NOTE(review): presumably this equalises the number of repeats
            # across files -- confirm the meaning of the threshold.
            arr = arr.reshape(arr.shape[0], arr.shape[1], arr.shape[2] // 2, 2).mean(axis=3)
        betas.append(arr)
    # For 3-D inputs np.hstack joins along axis 1; that axis is then permuted
    # with sort_idx so it lines up with `videos`.
    # NOTE(review): this assumes the sorted file order matches the
    # test-then-train stacking that sort_idx was derived from -- confirm.
    betas = np.hstack(betas)[:, sort_idx, :]

    reliability_mask = np.load(f'{data_dir}/raw/reliability_mask/{subj}_space-T1w_desc-test-fracridge_reliability-mask.npy')

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Mean of empty slice")

        # Every unordered pair of indices along the video axis.
        nCk = list(combinations(range(betas.shape[1]), 2))
        results = []
        for roi in tqdm(rois, desc='ROIs'):
            files = glob(f'{data_dir}/raw/localizers/{subj}/{subj}_task-*_space-T1w_roi-{roi}_hemi-*_roi-mask.nii.gz')
            # gen_mask comes from src.mri; its result indexes the first axis of
            # `betas` (presumably ROI voxels that also pass the reliability mask).
            mask = gen_mask(files, reliability_mask)
            betas_masked = betas[mask, ...]
            if decoding:
                # One fit_and_predict call per video pair, spread over 4 workers.
                result_for_t = Parallel(n_jobs=4)(
                    delayed(fit_and_predict)(betas_masked[:, video1, :].squeeze().T,
                                            betas_masked[:, video2, :].squeeze().T,
                                            n_groups) for video1, video2 in tqdm(nCk, total=len(nCk), desc='Pairwise decoding')
                )
                for accuracy, (video1, video2) in zip(result_for_t, nCk):
                    results.append([roi, videos[video1], videos[video2], accuracy])
            else:
                # Average over the last axis (ignoring NaNs), then take the
                # correlation distance between videos. Reuses the masked array
                # instead of indexing `betas` with the mask a second time.
                rdm = pdist(np.nanmean(betas_masked, axis=-1).T, metric='correlation')
                # pdist's condensed output follows the same (i < j) pair order
                # as combinations(range(n), 2), so the two can be zipped.
                for (video1, video2), distance in zip(nCk, rdm):
                    results.append([roi, videos[video1], videos[video2], distance])
    # The value column is called 'distance' in both modes; in decoding mode it
    # holds whatever fit_and_predict returned for the pair.
    results = pd.DataFrame(results, columns=['roi', 'video1', 'video2', 'distance'])
    results.to_csv(out_name, index=False)
else:
    results = pd.read_csv(out_name)

ROIs:   0%|          | 0/9 [00:00<?, ?it/s]


TypeError: pqdm() missing 2 required positional arguments: 'function' and 'n_jobs'

## Correlate Feature and Neural RDMs

In [None]:
# Spearman-correlate every feature RDM with every ROI's neural RDM.
feature_group = feature_rdms.groupby('feature')
neural_group = results.groupby('roi')
# NOTE(review): the two `distance` columns are compared positionally, so this
# assumes each feature group lists video pairs in the same order as `results`
# does within an ROI -- confirm.
rsa = pd.DataFrame(
    [
        [feature, roi_name, spearmanr(feature_rdm.distance, roi_rdm.distance)[0]]
        for feature, feature_rdm in tqdm(feature_group)
        for roi_name, roi_rdm in neural_group
    ],
    columns=['feature', 'roi', 'Spearman rho'],
)

# Ordered categoricals fix the feature and ROI order for grouping and plotting.
for column, categories in (('feature', feature_order), ('roi', rois)):
    cat_type = pd.CategoricalDtype(categories=categories, ordered=True)
    rsa[column] = rsa[column].astype(cat_type)

# Output name follows the analysis mode; the DataFrame index is written too.
rsa_file = (f'{out_path}/{subj}_rsa-decoding.csv' if decoding
            else f'{out_path}/{subj}_rsa-correlation.csv')
rsa.to_csv(rsa_file)
rsa.head()

In [None]:
def feature2color(key=None):
    """Look up the RGBA plotting color assigned to a feature.

    Parameters
    ----------
    key : str, optional
        Feature name. When omitted (None), the whole mapping is returned.

    Returns
    -------
    numpy.ndarray or dict
        The 4-element RGBA array for `key`, or a dict mapping every feature
        name to its own RGBA array when `key` is None.

    Raises
    ------
    KeyError
        If `key` is not one of the known feature names.
    """
    # Features that share a color are listed together; insertion order below
    # determines the key order of the returned dict.
    color_families = (
        (('alexnet', 'moten'), (0.5, 0.5, 0.5, 1)),
        (('indoor', 'expanse', 'object_directedness'), (0.95703125, 0.86328125, 0.25, 0.8)),
        (('agent_distance', 'facingness'), (0.51953125, 0.34375, 0.953125, 0.8)),
        (('joint_action', 'communication'), (0.44921875, 0.8203125, 0.87109375, 0.8)),
        (('valence', 'arousal'), (0.8515625, 0.32421875, 0.35546875, 0.8)),
    )
    # A separate array is built per feature so entries never alias each other.
    colors = {
        name: np.array(rgba)
        for names, rgba in color_families
        for name in names
    }
    return colors if key is None else colors[key]

In [None]:
# One bar-plot panel per ROI on a 3x3 grid, with shared axes across panels.
feature_group = rsa.groupby('roi')
fig, axes = plt.subplots(3, 3, sharey=True, sharex=True)
axes = axes.flatten()
rho_values = rsa['Spearman rho']
ymin, ymax = rho_values.min(), rho_values.max()

# ROIs that land in the left column / bottom row of the grid; only those
# panels carry a y label / x label and tick labels.
left_column = ('EVC', 'LOC', 'pSTS')
bottom_row = ('pSTS', 'face-pSTS', 'aSTS')

for ax, (roi, feature_df) in zip(axes, feature_group):
    sns.barplot(x='feature', y='Spearman rho',
                data=feature_df, ax=ax, color='gray')

    ax.set_ylabel('Spearman rho' if roi in left_column else '')

    if roi not in bottom_row:
        ax.set_xlabel('')
        ax.tick_params(axis='x', which='both', length=0)
    else:
        ax.set_xlabel('Feature')
        ax.set_xticklabels(feature_order, rotation=90, ha='center')

    # Recolor the bars; assumes the patches are drawn in feature_order.
    for bar, feature in zip(ax.patches, feature_order):
        bar.set_color(feature2color(feature))

    ax.set_ylim([ymin, ymax])
    # Zero reference line across the current x range.
    x_left, x_right = ax.get_xlim()
    ax.hlines(y=0, xmin=x_left, xmax=x_right,
              colors='gray', linestyles='solid', zorder=1)
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_title(roi)

fig.tight_layout()
figure_file = (f'{figure_path}/{subj}_rsa-decoding.png' if decoding
               else f'{figure_path}/{subj}_rsa-correlation.png')
fig.savefig(figure_file)