In [None]:
import os
import subprocess
import tempfile
import time

import cortex
import nibabel
import numpy as np
from cortex import mni  # NOTE(review): not used in the visible cells
from IPython.core.debugger import set_trace  # NOTE(review): leftover debugging import; not used in the visible cells

# FSL installation root, or None when the FSLDIR environment variable is unset.
# The resampling cell builds the MNI152 reference-brain path from it.
fsldir = os.environ.get('FSLDIR')

# Input directories; each holds one file per subject named after the subject ID.
# (The "tranforms" spelling of the first variable name is kept as-is because it is a notebook-level name.)
affine_tranforms_dir = "affine_transforms"
mni_transforms_dir = "mni_transforms"
masks_dir = "masks"

subjects = list("FGHIJKLMN")

# Per-subject inputs keyed by subject ID.
affine_transforms = {}  # subject -> array loaded from affine_transforms/<subject>.npy
mni_transforms = {}     # subject -> matrix loaded (as text) from mni_transforms/<subject>.txt
masks = {}              # subject -> voxel mask loaded from masks/<subject>.npy

for subject in subjects:
    npy_name = subject + ".npy"
    affine_transforms[subject] = np.load(os.path.join(affine_tranforms_dir, npy_name))
    mni_transforms[subject] = np.loadtxt(os.path.join(mni_transforms_dir, subject + ".txt"))
    masks[subject] = np.load(os.path.join(masks_dir, npy_name))

In [None]:
# Mask for the pycortex database subject 'MNI' with transform 'atlas' and mask type 'thin'.
# It is used to index MNI-space volumes, so presumably it is a boolean voxel mask (confirm).
mask_mni = cortex.db.get_mask('MNI', 'atlas', 'thin')
n_v_mni = np.sum(mask_mni)  # voxel count inside the mask (if boolean)
print(n_v_mni)

In [None]:
# Project every per-subject prediction map (<root_data_dir>/<feature>/<subject>_<pat>.npy)
# into MNI152 1mm space with FSL flirt, then save the voxels inside `mask_mni` to
# <root_save_dir>/<feature>/<subject>_<pat>.npy. Existing outputs are skipped, so re-runs
# only process what is missing.
root_data_dir = "predictions/"
root_save_dir = "predictions_mni/"

features = ["punct_final"]
all_feature_pairs = [
                     ("node_count_punct", "punct_final"), # (NC + PU) - (PU)
                     ("syntactic_surprisal_punct", "punct_final"), # (SS + PU) - (PU)
                     ("word_frequency_punct", "punct_final"), # (WF + PU) - (PU)
                     ("word_length_punct", "punct_final"), # (WL + PU) - (PU)
                     ("all_complexity_metrics_punct", "punct_final"), # (CM + PU) - (PU)
                     ("pos_dep_tags_all_complexity_metrics", "all_complexity_metrics_punct"), # since PD already contains PU, this tests (PD + CM + PU) - (CM + PU)
                     ("aggregated_contrege_comp_pos_dep_tags_all_complexity_metrics", "pos_dep_tags_all_complexity_metrics"), # (CC + PD + CM + PU) - (PD + CM + PU)
                     ("aggregated_contrege_incomp_pos_dep_tags_all_complexity_metrics", "pos_dep_tags_all_complexity_metrics"), # (CI + PD + CM + PU) - (PD + CM + PU)
                     ("aggregated_incontrege_pos_dep_tags_all_complexity_metrics", "pos_dep_tags_all_complexity_metrics"), # (INC + PD + CM + PU) - (PD + CM + PU)
                     ("aggregated_bert_PCA_dims_15_contrege_incomp_pos_dep_tags_all_complexity_metrics", "aggregated_contrege_incomp_pos_dep_tags_all_complexity_metrics") # (BERT + CI + PD + CM + PU) - (CI + PD + CM + PU)
                    ]

# Each pair contributes its difference map and the larger (first) model's map.
for p in all_feature_pairs:
    features.append("{}_diff_{}".format(p[0], p[1]))
    features.append(p[0])

print(features)

os.makedirs(root_save_dir, exist_ok=True)

all_pats = ["sig_group_corrected" ,"sig_bootstrap_group_corrected", "r2s"]


def _resample_to_mni(data3D, affine, mni_transform):
    """Resample a subject-space volume into MNI152 1mm space with FSL flirt.

    Args:
        data3D: Subject-space volume in the same axis order as the subject mask.
            It is transposed before being written as NIfTI, and the flirt output is
            transposed back. Presumably pycortex (z, y, x) vs NIfTI (x, y, z); confirm.
        affine: Voxel-to-world affine used for the temporary NIfTI image.
        mni_transform: Matrix passed to flirt as the ``-init`` transform.

    Returns:
        The resampled volume (transposed back), read with the dtype flirt wrote.

    Raises:
        RuntimeError: if FSLDIR is not set, so the reference brain cannot be located.
        subprocess.CalledProcessError: if flirt exits with a non-zero status.
    """
    if fsldir is None:
        raise RuntimeError("FSLDIR is not set; cannot locate the MNI152 reference brain for flirt.")
    ref_path = os.path.join(fsldir, "data", "standard", "MNI152_T1_1mm_brain.nii.gz")

    # A private temp dir avoids clobbering files in the CWD, supports concurrent runs,
    # and is cleaned up even if flirt or loading fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        in_path = os.path.join(tmp_dir, "temp.nii")
        mat_path = os.path.join(tmp_dir, "temp.mat")
        out_path = os.path.join(tmp_dir, "temp_out.nii.gz")

        nibabel.save(nibabel.Nifti1Image(data3D.T, affine), in_path)
        np.savetxt(mat_path, mni_transform, "%0.10f")

        # check=True: fail here with flirt's status instead of later on a missing output file.
        subprocess.run(["flirt",
                        "-in", in_path,
                        "-ref", ref_path,
                        "-applyxfm", "-init", mat_path,
                        "-out", out_path],
                       check=True)

        # get_data() was removed in nibabel 5; np.asanyarray(img.dataobj) is its documented
        # replacement (same scaled values and dtype). np.array forces the data into memory
        # before the temp dir is deleted.
        return np.array(np.asanyarray(nibabel.load(out_path).dataobj)).T


for pat in all_pats:
    for feature in features:
        print(feature)
        # Trailing "" keeps the trailing separator these directory strings always had.
        data_dir = os.path.join(root_data_dir, feature, "")
        save_dir = os.path.join(root_save_dir, feature, "")
        os.makedirs(save_dir, exist_ok=True)
        mni_masked = dict()
        for subject in subjects:
            print(subject)
            file_name = "{}_{}.npy".format(subject, pat)
            save_path = os.path.join(save_dir, file_name)
            data_path = os.path.join(data_dir, file_name)
            if os.path.exists(save_path):
                continue  # already converted
            if not os.path.exists(data_path):
                print("Skipping {}".format(subject))
                continue
            data = np.load(data_path)
            print(data.shape)

            # Scatter the masked-voxel vector back into the subject's 3D volume.
            mask = masks[subject]
            data3D = np.zeros(mask.shape)
            data3D[mask] = data

            mni_vol = _resample_to_mni(data3D, affine_transforms[subject], mni_transforms[subject])
            mni_masked[subject] = mni_vol[mask_mni]

            # save MNI output
            np.save(save_path, mni_masked[subject])