In [None]:
import numpy as np
from scipy.io import loadmat
from scipy import stats
from IPython.core.debugger import set_trace
import pickle
import os
from tqdm import tqdm
import random
from joblib import Parallel, delayed

# Seed Python's `random` module and NumPy's legacy global RNG with the same value
# so any randomness used in this notebook is reproducible across runs.
SEED = 97
random.seed(SEED)
np.random.seed(SEED)

# Relative directory names. `root_save_dir` is the parent of the per-model /
# per-contrast result folders that the cells below read from and write to;
# `sub_data` is not referenced in the cells shown here.
root_save_dir = "predictions/"
sub_data = "sub_space_data/"

In [None]:
def FDR(vector, q, do_correction=False):
    """Benjamini-Hochberg (step-up) false-discovery-rate thresholding.

    Parameters
    ----------
    vector : np.ndarray
        p-values of any shape. They are flattened for the procedure and the
        result is reshaped back to ``vector.shape``. NaNs are never rejected
        but still count towards the number of tests N.
    q : float
        Target false-discovery rate (e.g. 0.05).
    do_correction : bool, optional
        If True, apply the Benjamini-Yekutieli correction: ``q`` is divided by
        ``C = sum_{i=1}^{N} 1/i``. Default False (plain BH, C = 1).

    Returns
    -------
    thresh_vector : np.ndarray of float
        Same shape as ``vector``; 1.0 where ``p <= thresh`` (rejected),
        0.0 elsewhere.
    thresh : float
        The largest sorted p-value ``p_(k)`` with ``p_(k) <= k / N * q / C``
        (k is the 1-based rank), or 0 if no p-value satisfies the criterion.
    """
    original_shape = vector.shape
    vector = vector.flatten()
    N = vector.shape[0]
    # np.sort (unlike builtin sorted) places NaNs deterministically at the end;
    # NaN <= x is always False, so NaNs never pass the criterion below.
    sorted_vector = np.sort(vector)
    if do_correction and N > 0:
        # Harmonic number H_N. Ranks start at 1 (starting at 0 would divide by zero).
        C = np.sum(1.0 / np.arange(1, N + 1))
    else:
        C = 1.0
    thresh = 0
    # BH critical values k/N * q / C for 1-based ranks k = 1..N; the smallest
    # p-value (rank 1) is included.
    critical = np.arange(1, N + 1) / N * q / C
    passing = np.nonzero(sorted_vector <= critical)[0]
    if passing.size:
        # Step-up rule: the threshold is the LARGEST-ranked p-value that lies
        # under its critical value; every p-value up to it is rejected.
        thresh = sorted_vector[passing[-1]]
    thresh_vector = vector <= thresh
    thresh_vector = thresh_vector.reshape(original_shape)
    thresh_vector = thresh_vector * 1.0
    print("FDR threshold is : {}, {} voxels rejected".format(thresh, thresh_vector.sum()))
    return thresh_vector, thresh

In [None]:
# Subjects whose per-subject result files are processed below.
all_subjects = list("FGHIJKLMN")

# Reduced (baseline) models shared by several contrasts; the bracketed labels
# match the abbreviations used in the per-pair comments below.
PU = "punct_final"                                                              # (PU)
CM_PU = "all_complexity_metrics_punct"                                          # (CM + PU)
PD_CM_PU = "pos_dep_tags_all_complexity_metrics"                                # (PD + CM + PU); PD already contains PU
CI_PD_CM_PU = "aggregated_contrege_incomp_pos_dep_tags_all_complexity_metrics"  # (CI + PD + CM + PU)

# Each entry is (full_model, reduced_model); the contrast is full - reduced, and
# its results are read from the "<full>_diff_<reduced>" folder.
all_feature_pairs = [
    ("node_count_punct", PU),                                                   # (NC + PU) - (PU)
    ("syntactic_surprisal_punct", PU),                                          # (SS + PU) - (PU)
    ("word_frequency_punct", PU),                                               # (WF + PU) - (PU)
    ("word_length_punct", PU),                                                  # (WL + PU) - (PU)
    (CM_PU, PU),                                                                # (CM + PU) - (PU)
    (PD_CM_PU, CM_PU),                                                          # (PD + CM + PU) - (CM + PU)
    ("aggregated_contrege_comp_pos_dep_tags_all_complexity_metrics", PD_CM_PU),  # (CC + PD + CM + PU) - (PD + CM + PU)
    (CI_PD_CM_PU, PD_CM_PU),                                                    # (CI + PD + CM + PU) - (PD + CM + PU)
    ("aggregated_incontrege_pos_dep_tags_all_complexity_metrics", PD_CM_PU),     # (INC + PD + CM + PU) - (PD + CM + PU)
    ("aggregated_bert_PCA_dims_15_contrege_incomp_pos_dep_tags_all_complexity_metrics", CI_PD_CM_PU),  # (BERT + CI + PD + CM + PU) - (CI + PD + CM + PU)
]

In [None]:
# Gather every uncorrected significance array (treated as p-values by FDR) into
# one flat list, in a fixed order: for each subject, the "punct_final" array
# first, then one array per contrast in `all_feature_pairs`. The cell that splits
# the group-corrected result back apart relies on exactly this order.
all_uncorrected_sig = []
contrast_dirs = ["{}_diff_{}".format(full, reduced) for full, reduced in all_feature_pairs]
for sub in all_subjects:
    print(sub)
    all_uncorrected_sig.append(
        np.load(root_save_dir + "punct_final" + "/{}_sig.npy".format(sub)))
    # Contrast results use the bootstrap p-value files.
    all_uncorrected_sig.extend(
        np.load(root_save_dir + contrast_dir + "/{}_sig_boot.npy".format(sub))
        for contrast_dir in contrast_dirs)

In [None]:
# Target false-discovery rate for the pooled correction.
q = 0.05
# Pool all subjects' and contrasts' p-values into one vector so a single FDR
# threshold is applied across all of them (the group-level correction).
all_corrected_sig, _ = FDR(vector=np.hstack(all_uncorrected_sig), q=q)

In [None]:
# Undo the pooling: split the group-corrected 0/1 vector back into one piece per
# array in `all_uncorrected_sig` (built per subject: "punct_final" first, then
# each pair in `all_feature_pairs`) and save each piece next to its source.
piece_sizes = [sig.shape[0] for sig in all_uncorrected_sig]
n_expected = len(all_subjects) * (1 + len(all_feature_pairs))
# Guard against stale kernel state (e.g. subject/contrast lists edited after the
# loading or FDR cell ran): a mismatch would otherwise silently mis-align slices
# and save them under the wrong subject/contrast.
assert len(piece_sizes) == n_expected, (
    "expected {} uncorrected arrays, found {}; re-run the loading cell".format(
        n_expected, len(piece_sizes)))
assert sum(piece_sizes) == all_corrected_sig.shape[0], (
    "corrected vector has {} entries but the uncorrected arrays total {}; re-run the FDR cell".format(
        all_corrected_sig.shape[0], sum(piece_sizes)))
# Cut at the cumulative sizes (dropping the final total) -> one piece per input array, in order.
corrected_pieces = iter(np.split(all_corrected_sig, np.cumsum(piece_sizes)[:-1]))

for sub in all_subjects:
    print(sub)
    punct_corrected_sig = next(corrected_pieces)
    np.save(root_save_dir + "punct_final" + "/{}_sig_group_corrected".format(sub), punct_corrected_sig)
    print("punct_final: Num voxels rejected = {}".format(punct_corrected_sig.sum()))

    for feat in all_feature_pairs:
        contrast_dir = "{}_diff_{}".format(feat[0], feat[1])
        corrected_sig = next(corrected_pieces)
        np.save(root_save_dir + contrast_dir + "/{}_sig_bootstrap_group_corrected".format(sub), corrected_sig)
        print("{}: Num voxels rejected = {}".format(contrast_dir, corrected_sig.sum()))