In [94]:
import os
from pathlib import Path

import nibabel as nib
import numpy as np
import seaborn as sns
from nilearn import plotting, datasets, surface
from scipy import stats
from statsmodels.stats.multitest import multipletests

import src.custom_plotting as cp

In [95]:
def correct(ps_, rs_, p_crit=5e-2):
    """Apply Benjamini-Hochberg FDR correction and report the weakest surviving effect.

    Parameters
    ----------
    ps_ : array-like of float
        Uncorrected p-values, one per voxel (1-D).
    rs_ : array-like of float
        Per-voxel statistic aligned element-wise with ``ps_``.
    p_crit : float, optional
        FDR level, passed to ``multipletests`` as ``alpha``.

    Returns
    -------
    sig_bool : numpy.ndarray of bool
        True where the voxel survives FDR correction.
    threshold : float
        Minimum of ``rs_`` over the surviving voxels. Returns ``np.inf`` when no
        voxel survives. The previous version raised ``ValueError`` on the empty
        reduction, aborting whole-subject loops.
    """
    sig_bool, _, _, _ = multipletests(ps_, alpha=p_crit, method='fdr_bh')
    rs_ = np.asarray(rs_)
    # `initial` makes the reduction well-defined on an empty selection.
    threshold = rs_[sig_bool].min(initial=np.inf)
    return sig_bool, threshold


def filter_r(rs, ps):
    """Keep only the FDR-significant entries of a per-voxel accuracy map.

    The inputs are not modified; the previous version overwrote NaNs in the
    caller's arrays.

    Parameters
    ----------
    rs : array-like of float
        Per-voxel prediction accuracy (presumably correlation values; confirm).
    ps : array-like of float
        Uncorrected p-values aligned element-wise with ``rs``.

    Returns
    -------
    rs : numpy.ndarray
        Copy of ``rs``. Voxels with a NaN accuracy or p-value, and voxels that
        fail FDR correction (via ``correct``), are set to 0.
    rs_mask : numpy.ndarray
        1.0 where the returned ``rs`` is nonzero and 0.0 elsewhere, in the same
        dtype as ``rs``.
    threshold : float
        Smallest accuracy among significant voxels, as reported by ``correct``.
    """
    rs = np.array(rs)  # copy: never write into the caller's arrays
    ps = np.array(ps)

    # Neutralise unusable voxels. multipletests cannot rank NaN p-values, so a
    # NaN on either side gets accuracy 0 and p = 1 (never significant).
    invalid = np.isnan(rs) | np.isnan(ps)
    rs[invalid] = 0
    ps[invalid] = 1

    # FDR-correct and zero out everything that does not survive.
    sig_bool, threshold = correct(ps, rs)
    rs[~sig_bool] = 0.

    rs_mask = (rs != 0.).astype(rs.dtype)
    return rs, rs_mask, threshold

In [96]:
# Analysis name; also names the figure output directory.
process = 'PlotOverlap'

# Paths. Every directory hangs off one project root, so the notebook can run
# from another checkout by setting SIFMRI_ROOT. The default is the original location.
project_dir = os.environ.get('SIFMRI_ROOT',
                             '/Users/emcmaho7/Dropbox/projects/SI_fmri/SIfMRI_analysis')
loc_dir = f'{project_dir}/data/raw/localizer_stats'
stat_dir = f'{project_dir}/data/interim/VoxelPermutation'
mask_dir = f'{project_dir}/data/interim/Reliability'
figure_dir_top = f'{project_dir}/reports/figures/{process}'

# Surface plotting: fsaverage meshes and the colormap used for the conjunction maps.
fsaverage = datasets.fetch_surf_fsaverage()
cmap = sns.color_palette('magma', as_cmap=True)

# Localizer ROIs. These three lists are parallel: entry k of each describes one ROI.
# 'FBOS' appears twice because the Faces and Bodies contrasts are different
# sub-bricks of the same AFNI stats file.
tasks = ['tom', 'SIpSTS', 'FBOS', 'FBOS', 'biomotion']
afni_labels = ['Belief-Photo_GLT#0_Tstat', 'Interact-Non_GLT#0_Tstat',
               'Faces-Obj_GLT#0_Tstat', 'Bodies-Obj_GLT#0_Tstat',
               'Bio-Trans_GLT#0_Tstat']
titles = ['ToM', 'SIpSTS', 'Faces', 'Bodies', 'Biomotion']
assert len(tasks) == len(afni_labels) == len(titles), 'ROI lists must be parallel'

# Subjects
sids = ['sub-01', 'sub-02', 'sub-03', 'sub-04']

# Feature sets: the model-file identifier and the short name used for output folders.
features = ['JointactionCommunicationCooperationDominanceIntimacyValenceArousal',
            'IndoorExpanseTransitivity',
            'AgentdistanceFacingness']
feature_names = ['social', 'scene_object', 'social_primitive']
assert len(features) == len(feature_names), 'feature lists must be parallel'

In [None]:
# Localizer cutoff expressed as a one-sided z value for p < .001.
# NOTE(review): applied to AFNI t-statistics, so this is a normal approximation
# to the t threshold. Confirm that is intended.
loc_threshold = stats.norm.isf(0.001)


def _load_localizer(sid, task, afni_label):
    """Return (stat volume, BRIK image) for one localizer contrast of one subject.

    The volume is the sub-brick of the AFNI stats BRIK labelled ``afni_label``.
    """
    brik = nib.load(f'{loc_dir}/{sid}.{task}.results/stats.{sid}.{task}+tlrc.BRIK')
    labels = brik.header.get_volume_labels()
    loc = np.array(brik.dataobj)[..., labels.index(afni_label)]
    return loc, brik


def _surface_texture(img):
    """Project a volume onto both fsaverage pial surfaces with nearest-neighbour sampling."""
    return {hemi: surface.vol_to_surf(img, fsaverage[f'pial_{hemi}'],
                                      interpolation='nearest')
            for hemi in ('left', 'right')}


for sid in sids:
    # Subject-level inputs, shared by every task and feature. Load them once.
    mask_im = nib.load(f'{mask_dir}/{sid}_set-test_stat-rho_statmap.nii.gz')
    mask = np.load(f'{mask_dir}/{sid}_set-test_reliability-mask.npy')
    # Split-half reliability per voxel. Near-zero values become NaN so the
    # normalisation below does not divide by ~0; those voxels end up NaN.
    reliability = np.array(mask_im.dataobj).flatten()
    reliability[np.isclose(reliability, 0.)] = np.nan

    # FDR-filtered prediction accuracy per feature. It does not depend on the
    # localizer, so compute it once per subject rather than once per task.
    filtered = {}
    for feature in features:
        base = f'{stat_dir}/{sid}_model-full_predict-{feature}_control-conv2_pca-False'
        rs, _, threshold = filter_r(np.load(f'{base}_rs.npy'), np.load(f'{base}_ps.npy'))
        filtered[feature] = (cp.mkNifti(rs, mask, mask_im, nii=False), threshold)

    for task, afni_label, title in zip(tasks, afni_labels, titles):
        # Localizer map and its surface projection are feature-independent.
        loc, brik = _load_localizer(sid, task, afni_label)
        roi_texture = _surface_texture(nib.Nifti1Image(loc, brik.affine, brik.header))

        # Binary ROI: voxels whose statistic reaches the cutoff. NaN voxels
        # compare False and are excluded; the earlier `!= 0` test let them in.
        in_roi = (loc >= loc_threshold).flatten()
        # The flattened ROI mask indexes the flat accuracy map, so both must
        # describe the same voxel grid.
        assert in_roi.size == reliability.size, \
            f'{sid} {task}: localizer and reliability grids differ'

        for feature, feature_name in zip(features, feature_names):
            figure_dir = f'{figure_dir_top}/{sid}/{feature_name}'
            Path(figure_dir).mkdir(parents=True, exist_ok=True)

            # ROI map (a surface view, hence `view-surf`), written into every feature folder.
            cp.plot_surface_stats(fsaverage, roi_texture,
                                  cmap=sns.color_palette('icefire', as_cmap=True),
                                  modes=['lateral', 'ventral'],
                                  output_file=f'{figure_dir}/task-{title}_view-surf_roimap.png',
                                  threshold=loc_threshold,
                                  title=title)

            # Copy the cached accuracy map: each task masks it differently.
            rs_all, threshold = filtered[feature]
            rs = np.array(rs_all)
            rs[~in_roi] = 0.
            # Normalise by split-half reliability.
            rs /= reliability
            rs_img = nib.Nifti1Image(rs.reshape(mask_im.shape), affine=mask_im.affine)

            # NOTE(review): `threshold` is the smallest FDR-significant raw
            # accuracy, but it is applied to reliability-normalised values.
            # Confirm the units are intended.
            plotting.plot_stat_map(rs_img, threshold=threshold,
                                   cmap=cmap,
                                   vmax=1.,
                                   output_file=f'{figure_dir}/task-{title}_view-vol_conjunction.png')
            cp.plot_surface_stats(fsaverage, _surface_texture(rs_img),
                                  cmap=cmap,
                                  modes=['lateral', 'ventral'],
                                  output_file=f'{figure_dir}/task-{title}_view-surf_conjunction.png',
                                  threshold=threshold,
                                  vmax=1.,
                                  title=title)

  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
