
   ### Do temp gen fix and img (8)
    ## - train on img_rand test on fix_rand
    ## - train on img_rand test on fix_det
    ## - train on img_det test on fix_rand
    ## - train on img_det test on fix_det
    
    ## - train on fix_rand test on img_rand
    ## - train on fix_rand test on img_det
    ## - train on fix_det test on img_rand
    ## - train on fix_det test on img_det

In [5]:
import time
import numpy as np
import os
from tqdm.notebook import tqdm
import pickle
from sklearn.discriminant_analysis import _cov
import scipy
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score

def load_data(file):
    """Unpickle and return the object stored at ``file``.

    The path is printed before the file is opened.

    NOTE: unpickling can execute arbitrary code, so this must only be
    pointed at files from a trusted source.
    """
    print('loading file: ' + file)
    with open(file, 'rb') as handle:
        return pickle.load(handle)

def dump_data(data, filename):
    """Pickle ``data`` to ``filename`` using the highest pickle protocol.

    The path is printed before the file is opened; an existing file at
    that path is overwritten (mode ``'wb'``).
    """
    print('writing file: ' + filename)
    with open(filename, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
def subsample_data(dat, sub_factor):
    """Keep every ``sub_factor``-th sample along the last axis.

    ``dat['eeg']`` is indexed as a 3-D array on its last axis and
    ``dat['time']`` as a 1-D array with the same sample indices. The
    dictionary is modified in place and also returned.
    """
    n_samples = dat['eeg'].shape[-1]
    # Sample positions 0, sub_factor, 2 * sub_factor, ...
    keep = np.arange(0, n_samples, sub_factor)

    dat['eeg'] = dat['eeg'][:, :, keep]
    dat['time'] = dat['time'][keep]
    return dat
    

def get_pseudotrials(eeg_dat, tr_num):
    """
    Build pseudotrials by averaging random groups of trials.

    Every slice ``eeg_dat[i]`` of the leading dimension is processed
    independently: its trials (axis 1 of the slice) are shuffled with
    ``np.random.permutation``, truncated to ``l * k`` trials and averaged in
    groups of ``k`` trials, which yields ``l`` pseudotrials per condition.

    Parameters:
        eeg_dat (numpy.ndarray): 5-D array whose axis 2 indexes trials,
            e.g. shape (2, 4, 150, 64, 90).
        tr_num (int): Minimum number of pseudotrials wanted. Values below 1
            are treated as 1.

    Returns:
        pst (numpy.ndarray): Pseudotrials with shape
            ``eeg_dat.shape[:2] + (l,) + eeg_dat.shape[3:]`` where
            ``l = n_trials // k``. Note that ``l >= tr_num``; it is not
            necessarily equal to ``tr_num``.
        k (int): Number of trials averaged into each pseudotrial.

    Raises:
        ValueError: If ``tr_num`` is larger than the number of trials
            available (previously this surfaced as a ZeroDivisionError).
    """
    n_trials = eeg_dat.shape[2]
    wanted = max(int(tr_num), 1)
    if wanted > n_trials:
        raise ValueError(
            f"cannot build {wanted} pseudotrials from {n_trials} trials"
        )

    # Largest group size k that still leaves at least `wanted` groups. The
    # value depends only on the trial count, so it is identical for every
    # slice and is computed once.
    k = n_trials // wanted
    l = n_trials // k

    results = []
    for block in eeg_dat:
        n_cond, _, dim_a, dim_b = block.shape

        # Shuffle the trial order, then drop the remainder that does not
        # fill a complete set of l * k trials.
        shuffled = block[:, np.random.permutation(n_trials), :, :]
        trimmed = shuffled[:, :l * k, :, :]

        # Axis 1 (length k) is the averaging axis: pseudotrial b is the mean
        # of shuffled trials b, b + l, b + 2l, ...
        grouped = trimmed.reshape(n_cond, k, l, dim_a, dim_b)
        results.append(grouped.mean(axis=1))

    # Re-assemble the leading dimension.
    pst = np.stack(results, axis=0)

    return pst, k


def average_across_points(dat, window_size=10):
    """Downsample by averaging non-overlapping windows along the last axis.

    ``dat['eeg']`` (3-D) and ``dat['time']`` (1-D) are cut to the largest
    multiple of ``window_size`` samples and each consecutive window is
    replaced by its mean. Trailing samples that do not fill a window are
    discarded. The dictionary is updated in place and returned.
    """
    eeg = dat['eeg']
    dim0, dim1 = eeg.shape[0], eeg.shape[1]

    n_windows = eeg.shape[-1] // window_size
    n_used = n_windows * window_size

    windowed_eeg = eeg[:, :, :n_used].reshape(dim0, dim1, n_windows, window_size)
    windowed_time = dat['time'][:n_used].reshape(n_windows, window_size)

    dat['eeg'] = windowed_eeg.mean(axis=-1)
    dat['time'] = windowed_time.mean(axis=-1)
    return dat

def select_partition(data, cond):
    """Return the EEG rows and ids belonging to one condition.

    Labels 1-4 are selected by default; when ``cond == "rand"`` the labels
    11-14 are selected instead. Returns ``(eeg, ids)`` filtered with the
    same boolean mask built from ``data["ids"]``.
    """
    # The "rand" condition uses the base labels shifted by 10.
    offset = 10 if cond == "rand" else 0
    wanted = [label + offset for label in (1, 2, 3, 4)]

    keep = np.isin(data["ids"], wanted)
    return data["eeg"][keep], data["ids"][keep]

def random_eeg_pick(eeg, ids, n_trials=None):
    """Draw the same number of random trials for every label in ``ids``.

    Parameters:
        eeg (numpy.ndarray): 3-D array whose first axis is aligned with
            ``ids``.
        ids (numpy.ndarray): 1-D label array, one entry per row of ``eeg``.
        n_trials (int, optional): Number of trials to draw per label,
            without replacement. Defaults to the module-level ``trial_lim``
            so that existing two-argument calls behave as before.

    Returns:
        numpy.ndarray: Float array of shape
            ``(n_labels, n_trials, eeg.shape[1], eeg.shape[2])`` with labels
            ordered as returned by ``np.unique(ids)``.

    Raises:
        ValueError: If any label has fewer than ``n_trials`` trials.
    """
    if n_trials is None:
        # Fall back to the notebook-wide setting used by existing callers.
        n_trials = trial_lim

    labels = np.unique(ids)
    eeg_svm = np.full((len(labels), n_trials, eeg.shape[1], eeg.shape[2]), np.nan)

    for idx, label in enumerate(labels):
        label_trials = eeg[ids == label]
        available = len(label_trials)
        if available < n_trials:
            raise ValueError(
                f"label {label!r} has only {available} trials, "
                f"cannot draw {n_trials} without replacement"
            )

        # Sample trial positions without replacement.
        chosen = np.random.choice(np.arange(0, available), n_trials, replace=False)
        eeg_svm[idx, :, :, :] = label_trials[chosen, :, :]

    return eeg_svm
    

In [6]:
### Load Data


# ---- Analysis settings ----
# NOTE(review): subsample_factor is not referenced in the visible code; the
# averaging calls below hard-code window_size=10 instead.
subsample_factor = 10
# Fraction of the pseudotrials that forms one test fold (see n_test below).
testsize = 0.2
# Passed to get_pseudotrials as the requested number of pseudotrials.
trial_num = 12
# Number of random re-draws (trial selection + pseudotrial creation).
img_nperms = 25
# Trials drawn per label; read as a global by random_eeg_pick.
trial_lim = 150

# Maps "<train>_<test>" condition names to their temporal-generalization matrix.
decoding_data = {}

# (train, test) pairs; each entry is "<dataset>_<condition>".
pairs = [("fix_det", "img_det"),
         ("fix_rand", "img_rand"),
         ("img_det", "fix_det"),
         ("img_rand", "fix_rand")]

# NOTE(review): `sub` (subject number used in the file names) is not defined
# in the visible code -- presumably set in an earlier cell; confirm.
for pair_train, pair_test in pairs:
    print(f"Training on: {pair_train}, Testing on: {pair_test}")
    cond_name = f"{pair_train}_{pair_test}"
    # "<dataset>_<condition>" -> dataset part selects the file, condition part
    # selects the label set inside select_partition.
    train_out = pair_train.split("_")[0]
    train_in = pair_train.split("_")[1]
    
    test_out = pair_test.split("_")[0]
    test_in = pair_test.split("_")[1]
    
    # Only cross-dataset pairs are processed; this holds for every pair listed above.
    if train_out != test_out:
        dat_name = f"/projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_epoched/eeg_things_{sub:04d}_{train_out}.pickle"
        dat_train = load_data(dat_name)
        
        ### Downsample in time: average non-overlapping windows of 10 samples
        dat_train = average_across_points(dat_train, window_size=10)

        ### Button press mask: drop the entries flagged in button_press_mask
        bt_press = dat_train["button_press_mask"]
        dat_train["eeg"] = dat_train["eeg"][~bt_press]
        dat_train["ids"] = dat_train["ids"][~bt_press]
        dat_train["block_num"] = dat_train["block_num"][~bt_press]
        
        dat_name = f"/projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_epoched/eeg_things_{sub:04d}_{test_out}.pickle"
        dat_test = load_data(dat_name)
        
        ### Downsample in time: average non-overlapping windows of 10 samples
        dat_test = average_across_points(dat_test, window_size=10)

        ### Button press mask: drop the entries flagged in button_press_mask
        bt_press = dat_test["button_press_mask"]
        dat_test["eeg"] = dat_test["eeg"][~bt_press]
        dat_test["ids"] = dat_test["ids"][~bt_press]
        dat_test["block_num"] = dat_test["block_num"][~bt_press]        
        
        ### Get vars
        # Four conditions, matching the four labels kept by select_partition.
        n_conditions = len(range(4))
        # Sensor and time counts are taken from the training set only.
        # NOTE(review): assumes the test set has the same sensor/time layout -- confirm.
        n_sensors = dat_train["eeg"].shape[1]
        n_time = dat_train["eeg"].shape[-1]

        ### Temporal-generalization matrix: (cA, cB, train time, test time).
        ### Starts as NaN; only cB > cA entries are ever written below.
        TG = np.full((n_conditions, n_conditions, n_time, n_time), np.nan)        
        train_eeg, train_ids = select_partition(dat_train, train_in)
        test_eeg, test_ids = select_partition(dat_test, test_in)
          
        for p in tqdm(range(img_nperms)):
            # Fresh random draw of trial_lim trials per label for each set.
            eeg_svm_train = random_eeg_pick(train_eeg, train_ids)
            eeg_svm_test = random_eeg_pick(test_eeg, test_ids)

            # Stack as (train/test, condition, trial, sensor, time) so both sets
            # go through pseudotrial creation in one call.
            eeg_general = np.stack((eeg_svm_train, eeg_svm_test), axis=0)
            pstrials, binsize = get_pseudotrials(eeg_general, trial_num)
           
            n_pstrials = pstrials.shape[2]
            # NOTE(review): n_test becomes 0 when n_pstrials * testsize < 1, which
            # makes the division below raise ZeroDivisionError.
            n_test = int(n_pstrials * testsize)
            ps_ixs = np.arange(n_pstrials)
            cvs = int(n_pstrials / n_test)

            for cv in range(cvs):
                print('cv: {}, out of: {}'.format(cv+1, cvs))

                # indices of the current test fold; the remaining indices are used for training
                test_ix = np.arange(n_test) + (cv * n_test)
                train_ix = np.delete(ps_ixs.copy(), test_ix)

                # training pseudotrials come from the train set (slice 0),
                # test pseudotrials from the test set (slice 1)
                ps_train = pstrials[0]
                ps_test = pstrials[1]
                ps_train = ps_train[:,train_ix,:,:]
                ps_test = ps_test[:,test_ix,:,:]

                # per-condition sensor covariance, estimated on the training pseudotrials only
                sigma_ = np.empty((n_conditions, n_sensors, n_sensors))
                for c in range(n_conditions):
                    # compute sigma for each time point (sklearn's private _cov helper
                    # with shrinkage='auto'), then average across time
                    sigma_[c] = np.mean([_cov(ps_train[c, :, :, t], shrinkage='auto')
                                        for t in range(n_time)], axis=0)
                
                # average across conditions
                sigma = sigma_.mean(axis=0)  
                # whitening matrix sigma^(-1/2); note that despite its name,
                # sigma_inv holds the inverse square root, not the inverse
                sigma_inv = scipy.linalg.fractional_matrix_power(sigma, -0.5)

                # apply the whitening matrix along the sensor axis of both sets
                ps_train = (ps_train.swapaxes(2, 3) @ sigma_inv).swapaxes(2, 3)
                ps_test = (ps_test.swapaxes(2, 3) @ sigma_inv).swapaxes(2, 3)

                # pairwise decoding: condition cA vs every later condition cB, for each training time point
                for cA in range(n_conditions):
                    #print('decoding image ' + str(cA))
                    for cB in range(cA+1, n_conditions):
                        for t in range(n_time):
                            # retrieve the patterns from pseudotrials that correspond to cA and cB at time pt t
                            train_x = np.array((ps_train[cA,:,:,t], ps_train[cB,:,:,t]))
                            # concatenate them: all cA rows first, then all cB rows
                            train_x = np.reshape(train_x,(len(train_ix)*2, n_sensors))
                            # do the same with the test set, but here we take all time points 
                            test_x = np.array((ps_test[cA], ps_test[cB]))
                            test_x = np.reshape(test_x,(len(test_ix)*2, n_sensors, n_time))
                            # configure labels: 1 for cA and 2 for cB
                            train_y = np.array([1] * len(train_ix) + [2] * len(train_ix))
                            test_y = np.array([1] * len(test_ix) + [2] * len(test_ix))

                            # instantiate a classifier 
                            classifier = LinearSVC(dual=True,
                                                    penalty = 'l2',
                                                    loss = 'hinge',
                                                    C = .5,
                                                    multi_class = 'ovr',
                                                    fit_intercept = True,
                                                    max_iter = 10000)
                            # train it on time point t
                            classifier.fit(train_x, train_y)
                            # evaluate the time-t classifier on every test time point tt
                            for tt in range(n_time):
                                pred_y = classifier.predict(test_x[:,:,tt])
                                acc_score = accuracy_score(test_y,pred_y)
                                # accumulate the accuracy in the temporal-generalization matrix;
                                # nansum treats the initial NaN as 0 on the first write
                                TG[cA,cB,t,tt] = np.nansum(np.array((TG[cA,cB,t,tt],acc_score)))
        # Turn the accumulated sums into a mean over permutations and folds
        # (uses the cvs value left over from the last permutation).
        TG = TG / (img_nperms * cvs)
        decoding_data[cond_name] = TG
        
# Written once, after all pairs have been processed.
dump_data(decoding_data, "/projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_decoding/eeg_TG_between_att_{:04d}.pickle".format(sub))

Training on: fix_det, Testing on: img_det
loading file: /projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_epoched/eeg_things_0005_fix.pickle
loading file: /projects/crunchie/boyanova/EEG_Things/eeg_experiment/eeg_epoched/eeg_things_0005_img.pickle


  0%|          | 0/25 [00:00<?, ?it/s]

cv: 1, out of: 6
cv: 2, out of: 6
cv: 3, out of: 6
cv: 4, out of: 6
cv: 5, out of: 6
cv: 6, out of: 6
cv: 1, out of: 6


KeyboardInterrupt: 