In [None]:
import numpy as np
import pickle
import os 
from scipy.stats import pearsonr
import scipy.linalg
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import _cov

def load_data(file):
    """Read and return the object pickled in the file at path *file*.

    The path is printed before reading. Unpickling can execute arbitrary
    code, so only use this on trusted files.
    """
    print('loading file: ' + file)
    with open(file, 'rb') as stream:
        return pickle.load(stream)

def dump_data(data, filename):
    """Pickle *data* into *filename* using the highest available protocol.

    The destination path is printed first; an existing file is overwritten
    (the file is opened in ``'wb'`` mode).
    """
    print('writing file: ' + filename)
    with open(filename, 'wb') as sink:
        pickle.dump(data, sink, protocol=pickle.HIGHEST_PROTOCOL)
        
def subsample_data(dat, sub_factor):
    """Keep every ``sub_factor``-th sample along the last (time) axis.

    ``dat['eeg']`` may have any rank; samples are selected on its last axis.
    ``dat['time']`` is indexed with the same sample positions. ``dat`` is
    modified in place and also returned.

    Parameters
    ----------
    dat : dict
        Must contain ``'eeg'`` (array, time on the last axis) and ``'time'``
        (array indexable by a list of ints).
    sub_factor : int
        Step between kept samples (sample 0 is always kept when present).

    Returns
    -------
    dict
        The same ``dat`` object with ``'eeg'`` and ``'time'`` replaced.
    """
    sub_ix = list(range(0, dat['eeg'].shape[-1], sub_factor))
    # Fancy-index the last axis: copies like the original ``[:, :, sub_ix]``
    # but does not hard-code a 3-D array.
    dat['eeg'] = dat['eeg'][..., sub_ix]
    dat['time'] = dat['time'][sub_ix]

    return dat
    
def get_pseudotrials(eeg_dat, tr_num):
    """Average random groups of trials into pseudotrials.

    Trials (axis 1) are randomly permuted using NumPy's global RNG
    (``np.random.permutation``). They are then split into ``n_pseudo``
    pseudotrials of ``k`` trials each, where ``k`` is the largest bin size
    with ``n_pseudo = n_trials // k >= tr_num``. The
    ``n_trials - k * n_pseudo`` left-over trials are discarded.

    Parameters
    ----------
    eeg_dat : numpy.ndarray
        4-D array; axis 1 indexes trials. In this file it is
        ``(n_conditions, n_trials, n_sensors, n_time)``.
    tr_num : int
        Minimum number of pseudotrials to produce. It is converted with
        ``int()``; values below 1 behave like 1 (one pseudotrial averaging
        all trials).

    Returns
    -------
    pst : numpy.ndarray
        Shape ``(shape[0], n_pseudo, shape[2], shape[3])``.
    k : int
        Number of trials averaged into each pseudotrial.

    Raises
    ------
    ValueError
        If fewer than ``tr_num`` (or zero) trials are available.
    """
    shape = eeg_dat.shape
    n_trials = shape[1]
    n_pseudo_min = max(int(tr_num), 1)
    if n_pseudo_min > n_trials:
        raise ValueError(
            f"cannot form {n_pseudo_min} pseudotrials from {n_trials} trials"
        )

    # Largest k with n_trials // k >= n_pseudo_min. This is the closed form
    # of decrementing k from n_trials until enough pseudotrials fit.
    k = n_trials // n_pseudo_min
    n_pseudo = n_trials // k

    eeg_dat = eeg_dat[:, np.random.permutation(n_trials), :, :]
    eeg_dat = eeg_dat[:, :n_pseudo * k, :, :]

    # Permuted trial a * n_pseudo + i is averaged into pseudotrial i.
    pst = np.reshape(eeg_dat, (shape[0], k, n_pseudo, shape[2], shape[3]))
    pst = pst.mean(axis=1)

    return (pst, k)

def average_across_points(dat, window_size=10):
    """Downsample by averaging non-overlapping windows along the last axis.

    ``dat['eeg']`` may have any rank, with time on its last axis.
    ``dat['time']`` must be a 1-D array. Both are truncated to a whole number
    of windows; trailing samples that do not fill a window are dropped. Each
    window of ``window_size`` consecutive samples is then replaced by its
    mean. ``dat`` is modified in place and also returned.

    Parameters
    ----------
    dat : dict
        Must contain ``'eeg'`` and ``'time'`` numpy arrays.
    window_size : int, default 10
        Number of consecutive samples averaged into one output sample.

    Returns
    -------
    dict
        The same ``dat`` object with ``'eeg'`` and ``'time'`` replaced.

    Raises
    ------
    ValueError
        If ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    dshape = dat['eeg'].shape
    new_length = dshape[-1] // window_size
    n_used = new_length * window_size

    # Split the (truncated) last axis into (new_length, window_size) and
    # average each window; leading axes are kept whatever their number.
    eeg_windows = dat['eeg'][..., :n_used].reshape(*dshape[:-1], new_length, window_size)
    dat['eeg'] = eeg_windows.mean(axis=-1)
    dat['time'] = dat['time'][:n_used].reshape(new_length, window_size).mean(axis=-1)

    return dat

### Load Data
sub = 0
conditions_1 = ["fix", "img"]
conditions_2 = ["det", "rand"]
subsample_factor = 10   # window length (samples) for temporal averaging
trial_num = 12          # minimum number of pseudotrials per condition
img_nperms = 100        # number of random trial-subsampling permutations
trial_lim = 300         # trials drawn (without replacement) per condition per permutation
main_path = "/projects/crunchie/boyanova/EEG_Things/eeg_experiment/"
rdm_data = {}


def _corr_dissimilarity(pseudo_a, pseudo_b):
    """Mean (1 - Pearson r) over all pseudotrial pairs, per time point.

    pseudo_a: (n_a, n_sensors, n_time); pseudo_b: (n_b, n_sensors, n_time).
    Correlations are computed across sensors. Returns an array of shape
    (n_time,). Pairs involving a zero-variance pseudotrial give NaN and are
    skipped by the nanmean, as with np.corrcoef + np.nanmean.
    """
    def _zscore(x):
        # population z-score across sensors: (za . zb) / n_sensors == Pearson r
        x = x - x.mean(axis=1, keepdims=True)
        return x / np.sqrt((x ** 2).mean(axis=1, keepdims=True))

    with np.errstate(invalid="ignore", divide="ignore"):
        za = _zscore(pseudo_a)
        zb = _zscore(pseudo_b)
        corr = np.einsum("ist,jst->tij", za, zb) / pseudo_a.shape[1]
    return np.nanmean(1 - corr, axis=(1, 2))


for cond in conditions_1:
    # The same file feeds every conditions_2 selection, so load and
    # preprocess it once per `cond`.
    dat_name = os.path.join(main_path, f"eeg_epoched/eeg_things_{sub:04d}_{cond}.pickle")
    dat = load_data(dat_name)

    ### Temporal downsampling: average non-overlapping windows of subsample_factor samples
    dat = average_across_points(dat, window_size=subsample_factor)

    ### Button press mask: drop the flagged trials
    # Cast to bool so `~` is a logical NOT (on an int 0/1 array it would be bitwise).
    keep = ~np.asarray(dat["button_press_mask"], dtype=bool)
    eeg_all = dat["eeg"][keep]
    ids_all = dat["ids"][keep]

    for cond2 in conditions_2:
        cond_name = "{}_{}".format(cond, cond2)
        print(cond_name)

        ### Select condition
        image_labels = [1, 2, 3, 4]
        if cond2 == "rand":
            image_labels = [im + 10 for im in image_labels]

        # Trial pool of each label, in image_labels order (= RDM row/column order).
        # Computed once instead of on every permutation.
        trials_per_label = [eeg_all[ids_all == label] for label in image_labels]
        for label, trials in zip(image_labels, trials_per_label):
            if trials.shape[0] < trial_lim:
                raise ValueError(
                    f"{cond_name}: label {label} has {trials.shape[0]} trials, "
                    f"fewer than trial_lim={trial_lim}"
                )

        ### Get vars (eeg is indexed as trials x sensors x time below)
        n_conditions = len(image_labels)
        n_sensors = eeg_all.shape[1]
        n_time = eeg_all.shape[-1]

        # Running sum of per-permutation dissimilarities (only cA < cB is filled).
        rdm_sum = np.zeros((n_conditions, n_conditions, n_time))
        eeg_rdm = np.empty((n_conditions, trial_lim, n_sensors, n_time))

        for _ in tqdm(range(img_nperms)):
            ### Randomly pick trial_lim trials per condition
            for idx, trials in enumerate(trials_per_label):
                pick = np.random.choice(trials.shape[0], trial_lim, replace=False)
                eeg_rdm[idx] = trials[pick]

            # Create pseudotrials: (n_conditions, n_pseudo, n_sensors, n_time)
            pstrials, _binsize = get_pseudotrials(eeg_rdm, trial_num)

            # Whitening with multivariate noise normalization.
            # NOTE: _cov is a private scikit-learn helper and may change between versions.
            sigma_ = np.empty((n_conditions, n_sensors, n_sensors))
            for c in range(n_conditions):
                sigma_[c] = np.mean([_cov(pstrials[c, :, :, t], shrinkage='auto')
                                     for t in range(n_time)], axis=0)
            sigma = sigma_.mean(axis=0)
            # sigma is a shrinkage covariance (positive definite), so any imaginary
            # part returned by the matrix power is round-off.
            sigma_inv = np.real(scipy.linalg.fractional_matrix_power(sigma, -0.5))

            pstrials = (pstrials.swapaxes(2, 3) @ sigma_inv).swapaxes(2, 3)

            #################
            # RDM calculation: Pearson's correlation dissimilarity, averaged over
            # all pseudotrial pairs, accumulated across permutations
            for cA in range(n_conditions):
                for cB in range(cA + 1, n_conditions):
                    rdm_sum[cA, cB] += _corr_dissimilarity(pstrials[cA], pstrials[cB])

        # Average RDM across permutations. Only the upper triangle is computed;
        # the diagonal and lower triangle stay NaN.
        rdm = np.full((n_conditions, n_conditions, n_time), np.nan)
        upper = np.triu_indices(n_conditions, k=1)
        rdm[upper] = rdm_sum[upper] / img_nperms
        rdm_data[cond_name] = rdm

# Save the RDMs as a pickle file
dump_data(rdm_data, os.path.join(main_path, "eeg_rsa", f"eeg_rdms_{sub:04d}.pickle"))


fix_det
loading file: /projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_epoched/eeg_things_0000_fix.pickle


 20%|██        | 20/100 [01:29<05:58,  4.48s/it]