In [None]:
import time
import pickle
import nibabel
import cortex
from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt
import numpy as np
import os
import warnings
import subprocess
import pandas as pd
import seaborn as sns

# Output directories for the generated figures. makedirs(exist_ok=True) creates any
# missing directory (parents included) and leaves existing ones untouched, without
# the race between a separate exists() check and mkdir().
for _fig_dir in ("figures",
                 "figures/r2plus_figures",
                 "figures/sig_figures",
                 "figures/roi_figures"):
    os.makedirs(_fig_dir, exist_ok=True)

In [None]:
file_pattern = "{base}_{view}_{surface}.png"
_combine = lambda a,b: ( lambda c: [c, c.update(b)][0] )(dict(a))
_tolists = lambda p: {k:[v] for k,v in p.items()}

def save_3d_views(data, root, base_name, list_views =['lateral'],list_surfaces = ['fiducial'], with_labels = False,
                  size=(1024*4, 768*4), trim=True):
    """Saves 3D views of `data` in and around `root` under multiple specifications. Needs to be run
       on a system with a display (will launch webgl viewer)
    data: a pycortex volume
    root: directory where things should be saved
    base_name: base name for images
    list_views: which views do you want? choices are: lateral, lateral_left, lateral_right,
               medial, front, back,top, bottom
    list_surfaces: what surfaces do you want? choices are inflated, flatmap, fiducial
    with_labels: show ROI labels?
    size: size of produced image (before trimming)
    trim: whether to trim
    returns filenames: a dict of the produced image paths
    """
    # Create root dir?
    if not os.path.exists(root):
        os.mkdir(root)

    # Create viewer
    if with_labels:
        labels_visible=('rois',)
    else:
        labels_visible=()
    handle = cortex.webgl.show(data,labels_visible=labels_visible)

    time.sleep(5.0)


    basic = dict()#radius=400)#projection=['orthographic'], #radius=260, visL=True, visR=True)

    views = dict(lateral=dict(altitude=80.5, azimuth=181, pivot=180.5, radius = 700),#, specularity = 0),
                 lateral_left=dict(altitude=90.5, azimuth=90.5, pivot=0.5),
                 lateral_right=dict(altitude=90.5, azimuth=270.5, pivot=0.5),
                 medial=dict(altitude=90.5, azimuth=0.5, pivot=180.5),
                 front=dict(altitude=90.5, azimuth=0, pivot=0),
                 back=dict(altitude=90.5, azimuth=181, pivot=0),
                 top=dict(altitude=0, azimuth=180, pivot=0),
                 bottom=dict(altitude=180, azimuth=0, pivot=0)
                )

    surfaces = dict(inflated=dict(unfold= 0.5) ,
                    flatmap=dict(unfold= 1) ,
                    fiducial=dict(unfold= 0)
                   )

    param_dict = dict(unfold = 'surface.{subject}.unfold',
                      altitude = 'camera.altitude',
                      azimuth = 'camera.azimuth',
                      pivot = 'surface.{subject}.pivot',
                      radius = 'camera.radius')
#                      specularity = 'surface.{subject}.specularity') # unknown parameter


    # Save views!
    filenames = dict([(key, dict()) for key in surfaces.keys()])

    for view in list_views:
        # copy proper parameters with new names
        vparams = dict([(param_dict[k], v) for k,v in views[view].items()])
        for surf in list_surfaces:
            # copy proper parameters with new names
            sparams = dict([(param_dict[k], v) for k,v in surfaces[surf].items()])
            # Combine basic, view, and surface parameters
            params = _combine(_combine(basic, vparams), sparams)

            # Set the view
            handle._set_view(**_tolists(params))
            time.sleep(5.5)

            # Save image, store filename
            filename = file_pattern.format(base=base_name, view=view, surface=surf)
            filenames[surf][view] = filename
            # filenames.append(filename)

            output_path = os.path.join(root, filename)
            print(output_path)
            handle.getImage(output_path, size)
            
            while not os.path.exists(output_path):
                pass
            time.sleep(0.1)

            # Trim edges?
            if trim:
                # Wait for browser to dump file
                while not os.path.exists(output_path):
                    pass

                time.sleep(0.5)
                subprocess.call(["convert", "-trim", output_path, output_path])

    # Close the window!
    try:
        handle.close()
    except:
        print('Could not close viewer')

    return filenames

In [None]:
# sqrt R2 figures
def sqrt_r2(vec):
    """Element-wise square root of `vec` after zeroing every entry that is not > 0.

    Negative values and NaNs (for which `> 0` is False) become 0 before the root is
    taken. The output keeps the shape of `vec`.
    """
    clipped = np.zeros_like(vec)
    np.copyto(clipped, vec, where=vec > 0)
    return np.sqrt(clipped)

predictions_dir = "predictions_mni/"

# NOTE: pickle.load can execute arbitrary code; only load trusted mask files.
with open("mni_mask.pkl", "rb") as mask_file:
    mask = pickle.load(mask_file)
vols = {}
all_subjects = ["F","G","H","I","J","K","L","M","N"]

all_features = ["punct_final",
                "node_count_punct",
                "syntactic_surprisal_punct",
                "word_length_punct",
                "word_frequency_punct",
                "all_effort_based_metrics_punct",
                "pos_dep_tags_all_effort_based_metrics", 
                "aggregated_contrege_comp_pos_dep_tags_all_effort_based_metrics",
                "aggregated_contrege_incomp_pos_dep_tags_all_effort_based_metrics",
                "aggregated_incontrege_pos_dep_tags_all_effort_based_metrics",
                "aggregated_bert_PCA_dims_15_contrege_incomp_pos_dep_tags_all_effort_based_metrics"]

# Group-average R^2 map per feature: mean over subjects of each subject's saved
# `{sub}_r2s.npy` array (all subjects' arrays must share one shape).
for feat in all_features:
    mean_r2 = None  # `None` sentinel instead of `x is 0` (identity test on an int literal)
    for sub in all_subjects:
        sub_r2 = np.load(predictions_dir + "{}/{}_r2s.npy".format(feat, sub))
        if mean_r2 is None:
            mean_r2 = np.zeros(sub_r2.shape)
        mean_r2 += sub_r2 / len(all_subjects)
    if mean_r2 is None:
        raise ValueError("all_subjects is empty; no R^2 maps to average")
    vols[feat] = mean_r2

v = {}
for feat in vols:
    # sqrt of the (negatives-clipped) mean R^2, colour scale fixed to [0, 0.15].
    v[feat] = cortex.Volume(sqrt_r2(vols[feat]), 'MNI','atlas',mask=mask, vmin = 0, vmax = 0.15, cmap="Reds")
    save_3d_views(v[feat], 'figures/r2plus_figures/',"r2plus_{}".format(feat), trim=True, list_surfaces = ['inflated'])

In [None]:
# sig corrected figures
predictions_dir = "predictions_mni/"

# NOTE: pickle.load can execute arbitrary code; only load trusted mask files.
with open("mni_mask.pkl", "rb") as mask_file:
    mask = pickle.load(mask_file)
vols = {}
all_subjects = ["F","G","H","I","J","K","L","M","N"]

all_features = ["punct_final",
                "node_count_punct_diff_punct_final",
                "syntactic_surprisal_punct_diff_punct_final",
                "word_length_punct_diff_punct_final",
                "word_frequency_punct_diff_punct_final",
                "all_effort_based_metrics_punct_diff_punct_final",
                "pos_dep_tags_all_effort_based_metrics_diff_all_effort_based_metrics_punct",
                "aggregated_contrege_comp_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics", 
                "aggregated_contrege_incomp_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics",
                "aggregated_incontrege_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics",
                "aggregated_bert_PCA_dims_15_contrege_incomp_pos_dep_tags_all_effort_based_metrics_diff_aggregated_contrege_incomp_pos_dep_tags_all_effort_based_metrics"]

# Significance-map file suffixes, in order of preference (first existing file wins).
sig_suffixes = ("sig_bootstrap_group_corrected", "sig_group_corrected")

# Per voxel, count how many subjects have a non-zero (significant) value.
for feat in all_features:
    n_sig_subjects = None
    for sub in all_subjects:
        candidates = [predictions_dir + "{}/{}_{}.npy".format(feat, sub, suffix)
                      for suffix in sig_suffixes]
        sig_path = next((path for path in candidates if os.path.exists(path)), None)
        if sig_path is None:
            # Previously: a missing first subject crashed on `0 .shape`, a missing
            # later subject was silently skipped. Now always skipped, but loudly.
            warnings.warn("No significance map for feature {!r}, subject {!r}; skipping".format(feat, sub))
            continue
        sig_map = np.load(sig_path)
        if n_sig_subjects is None:
            n_sig_subjects = np.zeros(sig_map.shape)
        n_sig_subjects[sig_map != 0] += 1
    if n_sig_subjects is None:
        raise FileNotFoundError("No significance maps found for feature {!r}".format(feat))
    vols[feat] = n_sig_subjects

v = {}
for feat in vols:
    # NOTE(review): vmax=5 saturates the colour map at 5 of the 9 subjects — confirm intended.
    v[feat] = cortex.Volume(vols[feat], 'MNI','atlas',mask=mask, vmin = 0, vmax = 5, cmap = "Greens")
    save_3d_views(v[feat], 'figures/sig_figures/',"sig_{}".format(feat), trim=True, list_surfaces = ['inflated'])

In [None]:
# blank plot of all ROIs
roi_file_name = 'LangParcels_n220_LH.hdr'
roi_dat = nibabel.load(roi_file_name)
# `get_data()` is deprecated (an error from nibabel 5.0); `np.asanyarray(img.dataobj)`
# is the documented replacement returning the same array.
roi_mat = np.asanyarray(roi_dat.dataobj).T
# Embed the parcel volume into a 91 x 109 x 91 grid, offset by 11/7/6 voxels on each
# side of the three axes (the slice assignment requires roi_mat to be 69 x 95 x 79).
# NOTE(review): presumably this is the MNI grid of 'atlas_2mm' — confirm.
mask_narrow_mni = np.zeros((91,109,91))
mask_narrow_mni[11:-11,7:-7,6:-6] = roi_mat
# flip: overwrite slices 0..44 of the last axis with slices 90..46, making the volume
# mirror-symmetric about the (untouched) middle slice 45.
mask_narrow_mni[:,:,:45] = mask_narrow_mni[:,:,46:][:,:,::-1]

fig_raw, ax_raw = plt.subplots()
ax_raw.hist(mask_narrow_mni[mask_narrow_mni>0], 100)
ax_raw.set_title("ROI label values before rounding")
print(np.unique(mask_narrow_mni))
# Snap values lying strictly within 0.5 of a label 1..6 onto that integer label.
for label in range(1,7):
    near_label = (mask_narrow_mni > label - 0.5) & (mask_narrow_mni < label + 0.5)
    mask_narrow_mni[near_label] = label
print(np.unique(mask_narrow_mni))
fig_rounded, ax_rounded = plt.subplots()
ax_rounded.hist(mask_narrow_mni[mask_narrow_mni>0], 100)
ax_rounded.set_title("ROI label values after rounding")

np.save("processed_fedorenko_masks_MNI", mask_narrow_mni)  # writes processed_fedorenko_masks_MNI.npy
vols_parcels_narrow = cortex.Volume(mask_narrow_mni,'MNI','atlas_2mm', vmin = 0, vmax = 6, cmap="gist_ncar_r")
save_3d_views(vols_parcels_narrow, 'figures/',"blank_plot_of_rois", trim=True, list_surfaces = ['inflated'], with_labels=True)

In [None]:
# ROI analysis
predictions_dir = 'predictions/'
all_subjects = ["F","G","H","I","J","K","L","M","N"]

# The .npy holds a pickled dict keyed by subject (saved with np.save), hence
# allow_pickle and `[()]` to unwrap the 0-d object array.
mask_sub = np.load('mask_sub_fedorenko_HP_narrow_mirrored.npy', allow_pickle=True, fix_imports=True)[()]

# Each subject's ROI label array restricted to the voxels selected by masks/{sub}.npy.
# Label i+1 corresponds to roi_names[i] below.
mask_num = {s: mask_sub[s][np.load("masks/{}.npy".format(s))] for s in all_subjects}

roi_names = ['PostTemp',
             'AntTemp',
             'AngG',
             'IFG',
             'MFG',
             'IFGorb']

# (feature directory, legend label) pairs. Keeping each label next to its feature
# fixes a previous mismatch where word_length was labelled "{WF, PU}" and
# word_frequency "{WL, PU}".
feature_labels = [
    ("node_count_punct_diff_punct_final", "{NC, PU} - {PU}"),
    ("syntactic_surprisal_punct_diff_punct_final", "{SS, PU} - {PU}"),
    ("word_length_punct_diff_punct_final", "{WL, PU} - {PU}"),
    ("word_frequency_punct_diff_punct_final", "{WF, PU} - {PU}"),
    ("all_effort_based_metrics_punct_diff_punct_final", "{EF, PU} - {PU}"),
    ("pos_dep_tags_all_effort_based_metrics_diff_all_effort_based_metrics_punct",
     "{PD, EF, PU} - {EF, PU}"),
    ("aggregated_contrege_comp_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics",
     "{CC, PD, EF, PU} - {PD, EF, PU}"),
    ("aggregated_contrege_incomp_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics",
     "{CI, PD, EF, PU} - {PD, EF, PU}"),
    ("aggregated_incontrege_pos_dep_tags_all_effort_based_metrics_diff_pos_dep_tags_all_effort_based_metrics",
     "{INC, PD, EF, PU} - {PD, EF, PU}"),
    ("aggregated_bert_PCA_dims_15_contrege_incomp_pos_dep_tags_all_effort_based_metrics_diff_aggregated_contrege_incomp_pos_dep_tags_all_effort_based_metrics",
     "{BERT, CI, PD, EF, PU} - {CI, PD, EF, PU}"),
]
all_features = [feature for feature, _ in feature_labels]
all_names = [label for _, label in feature_labels]

metric = "sig_bootstrap_group_corrected"

n_rois = len(roi_names)
NFS = len(all_features)
NS = len(all_subjects)
percent_sig = []

for sub in all_subjects:
    # Load each feature's significance map once per subject (not once per ROI).
    sub_metrics = []
    for feature in all_features:
        feature_metrics = np.load("{}{}/{}_{}.npy".format(predictions_dir, feature, sub, metric))
        # For multi-column maps only the first column is used.
        if feature_metrics.ndim > 1 and feature_metrics.shape[1] > 1:
            feature_metrics = feature_metrics[:, 0]
        sub_metrics.append(feature_metrics)

    for i, roi in enumerate(roi_names):
        roi_mask = np.where(mask_num[sub] == (i + 1))[0]
        n_vox = roi_mask.shape[0]
        if n_vox == 0:
            # Kept at 0% (as before), but no longer via a silent 0/0 -> nan -> 0.
            warnings.warn("ROI {!r} has no voxels for subject {!r}; reporting 0%".format(roi, sub))

        for label, feature_metrics in zip(all_names, sub_metrics):
            roi_values = feature_metrics[roi_mask]
            pct_sig = 100.0 * np.count_nonzero(roi_values) / n_vox if n_vox else 0.0
            percent_sig.append({
                "Test": label,
                "Feature Added": label.split(",")[0][1:],
                "Subject": sub,
                "ROI": roi,
                "% of significant voxels": pct_sig,
            })

percent_sig = pd.DataFrame(percent_sig)

fig, ax = plt.subplots(figsize=(15,8))
plt.rcParams.update({'font.size': 20})
# NOTE(review): `ci=` / `errcolor=` are deprecated in newer seaborn (0.12+/0.13+:
# `errorbar=("ci", 68)`, `err_kws=`) — update when the seaborn version is bumped.
sns.barplot(x="ROI", y="% of significant voxels", hue="Test", data=percent_sig, ax=ax, ci=68)
ax.legend(fontsize = 20, bbox_to_anchor=(0.5, 1.16), loc='center',frameon=False, ncol=2)
fig.savefig(os.path.join("figures/roi_figures", "combined.png"), bbox_inches='tight')
# Close (not just clear) each saved figure so open figures do not accumulate.
plt.close(fig)

for roi in roi_names:
    data_for_this_roi = percent_sig[percent_sig["ROI"] == roi]
    fig, ax = plt.subplots(figsize=(15,8))
    ax.set_title(roi)
    # Hollow bars with 68% CI error bars, overlaid with one red dot per row (subject).
    sns.barplot(x="Feature Added", y="% of significant voxels",
                data=data_for_this_roi, ci=68, ax=ax, linewidth=2.5, facecolor=(1, 1, 1, 0),
                errcolor=".2", edgecolor=".2", capsize=0.25)
    sns.scatterplot(x="Feature Added", y="% of significant voxels", data=data_for_this_roi, ax=ax, color="red")
    fig.savefig(os.path.join("figures/roi_figures", "{}.png".format(roi)), bbox_inches='tight')
    plt.close(fig)