In [10]:
import os
import os.path as op
import warnings

import mne
import numpy as np
import pandas as pd
from autoreject import AutoReject
from pymatreader import read_mat

from BCOM_processing.SCRIPTS.functions import Epoching

In [None]:
# Data root. Override it with the BCOM_ROOT environment variable when the
# volume is mounted somewhere else; it defaults to the original /Volumes/BCOM.
root = os.environ.get("BCOM_ROOT", "/Volumes/BCOM")

output_path = op.join(root, "ciprian_project/data_analyzed/evoked/data")
raws_path = op.join(root, "BCOM/DATA_RAW")
cleaned_path = op.join(root, 'ciprian_project/data_analyzed/preprocessed')

# This channel has an odd position in the helmet coordinate space,
# so the intent is to exclude it from bad-channel interpolation.
bad_localization_channel = "MEG 173"

In [None]:
def interploate_no_bad_loc(raw, bad_loc_chan):
    """Interpolate every bad channel of ``raw`` except ``bad_loc_chan``.

    ``bad_loc_chan`` is kept out of interpolation because its sensor position
    is considered unreliable. If it was marked bad on entry, it is marked bad
    again afterwards. If it was not bad, it is left untouched. (The previous
    version always appended it to ``info["bads"]``, which wrongly marked a
    good channel as bad.)

    The (misspelled) function name is kept so existing callers keep working.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to repair. It is modified in place and
        ``raw.info["bads"]`` is rewritten.
    bad_loc_chan : str
        Name of the channel to exclude from interpolation.

    Returns
    -------
    raw
        The same object, for convenience and chaining.
    """
    original_bads = list(raw.info["bads"])
    was_bad = bad_loc_chan in original_bads
    raw.info["bads"] = [ch for ch in original_bads if ch != bad_loc_chan]
    if raw.info["bads"]:
        # With MNE's default reset_bads=True, the interpolated channels are
        # removed from info["bads"] afterwards.
        raw.interpolate_bads()
    # Restore the excluded channel's "bad" mark only if it had one.
    if was_bad and bad_loc_chan not in raw.info["bads"]:
        raw.info["bads"].append(bad_loc_chan)
    return raw

In [None]:
# Load the trigger-label table (semicolon-separated). The columns used below are
# 'produce_head', 'read_head' and 'syllable'; the CSV already holds the names.
triggers_pd = pd.read_csv(op.join(root, "BCOM/PROTOCOL/trigger_labels.csv"), sep=';')

# Turn each column into a plain Python list, in row order.
produce_triggers, read_triggers, syllables = (
    triggers_pd[column].to_list()
    for column in ('produce_head', 'read_head', 'syllable')
)

In [None]:
# Epoch window passed to mne.Epochs as tmin / tmax.
epoch_tmin, epoch_tmax = -0.4, 0.8
# Rejection criterion passed to mne.Epochs(reject=...).
reject_dict = {'mag': 5e-12}

# Subject folders: entries of the preprocessed directory whose name starts with 'B'.
subjects = sorted(entry for entry in os.listdir(cleaned_path) if entry[0] == 'B')
# Last two characters of each subject ID (e.g. 'BCOM_08' -> '08').
numbers = [subject[-2:] for subject in subjects]

conditions = ['OVERT', 'COVERT']
cleanings = ['WITHOUT_BADS', 'WITH_BADS']

# Root folder under which the epoch files are written.
evoked_output = os.path.join(output_path, 'DATA')

In [None]:
# For every subject, run and syllable trigger:
#   - epoch the read and produce events,
#   - clean the epochs with AutoReject,
#   - save them split by phase (COVERT / OVERT) and by cleaning level.
# WITH_BADS keeps every event. WITHOUT_BADS drops the bad syllables reported by
# Epoching.get_bad_syll.
RUNS = range(2, 5)  # run folders processed per subject
AUTOREJECT_SEED = 42  # fixed seed so AutoReject results are reproducible
MEG_PICK_KWARGS = dict(meg=True, eeg=False, stim=False, eog=False, ecg=False, misc=False)


def _find_mat_path(sub, run):
    """Return the folder holding the .mat data for ``sub`` / ``run``.

    Uses the first subdirectory, in sorted order so the choice is
    deterministic, of ``<raws_path>/<sub>/MEG/<sub>/BCom``, plus the run
    number. Raises FileNotFoundError when there is no subdirectory, instead of
    silently reusing a path left over from an earlier iteration.
    """
    bcom_dir = op.join(raws_path, sub, "MEG", sub, "BCom")
    for entry in sorted(os.listdir(bcom_dir)):
        if op.isdir(op.join(bcom_dir, entry)):
            # BCOM_08 uses run folder run + 1 (original note: "something special about this one").
            run_folder = run + 1 if sub == 'BCOM_08' else run
            return op.join(bcom_dir, entry, str(run_folder))
    raise FileNotFoundError(f"No session subdirectory found in {bcom_dir}")


def _split_by_phase(events, epo):
    """Split ``events`` into (covert, overt) arrays using ``epo.is_overt``."""
    covert = np.array([ev for ev in events if not epo.is_overt(ev)])
    overt = np.array([ev for ev in events if epo.is_overt(ev)])
    return covert, overt


def _autoreject_or_original(epochs):
    """Run AutoReject on ``epochs``; if it fails, warn and return ``epochs`` unchanged."""
    # Picks must index the channels of *these epochs*. The epochs are already
    # restricted to MEG channels, so indices computed on the raw's info may
    # point at the wrong channels or past the end.
    ar_picks = mne.pick_types(epochs.info, **MEG_PICK_KWARGS)
    ar = AutoReject(verbose=True, picks=ar_picks, n_jobs=3, random_state=AUTOREJECT_SEED)
    try:
        return ar.fit_transform(epochs)
    except Exception as err:  # best-effort fallback, but never hide the failure
        warnings.warn(f"AutoReject failed ({err!r}); keeping the uncleaned epochs.")
        return epochs


for sub in subjects:
    for run in RUNS:
        run_dir = op.join(cleaned_path, sub, str(run))
        # Everything here is independent of the trigger, so it is loaded once
        # per run instead of once per trigger.
        cleaned_raw = mne.io.read_raw_fif(op.join(run_dir, "subject_cleaned_raw.fif"), preload=True)
        mat_path = _find_mat_path(sub, run)
        events = np.load(op.join(run_dir, 'resampled_events.npy'))
        picks = mne.pick_types(cleaned_raw.info, **MEG_PICK_KWARGS)

        for trigger_index, produce_trigger in enumerate(produce_triggers):
            read_trigger = produce_trigger - 100

            all_read_events = [ev for ev in events if int(ev[2]) == int(read_trigger)]
            all_produce_events = [ev for ev in events if int(ev[2]) == int(produce_trigger)]

            Epo = Epoching(mat_path, events)
            bad_idx = Epo.get_bad_syll(raws_path, sub, read_trigger, read_triggers, syllables, run)

            # Remove the same positions from both lists. This assumes the k-th
            # read event pairs with the k-th produce event (TODO confirm).
            # Indices are de-duplicated and popped in descending order, so one
            # removal never shifts the position of another.
            cleaned_read_events = list(all_read_events)
            cleaned_produce_events = list(all_produce_events)
            for idx in sorted(set(bad_idx), reverse=True):
                cleaned_read_events.pop(idx)
                cleaned_produce_events.pop(idx)

            # (cleaning, condition, events): read COVERT/OVERT, then produce
            # COVERT/OVERT, first for WITH_BADS, then for WITHOUT_BADS.
            event_sets = []
            for cleaning, read_evs, produce_evs in (
                ('WITH_BADS', all_read_events, all_produce_events),
                ('WITHOUT_BADS', cleaned_read_events, cleaned_produce_events),
            ):
                for covert_evs, overt_evs in (_split_by_phase(read_evs, Epo),
                                              _split_by_phase(produce_evs, Epo)):
                    event_sets.append((cleaning, 'COVERT', covert_evs))
                    event_sets.append((cleaning, 'OVERT', overt_evs))

            for cleaning, condition, evs in event_sets:
                if len(evs) == 0:
                    continue
                epochs_main = mne.Epochs(cleaned_raw, events=evs, reject=reject_dict, picks=picks,
                                         baseline=None, tmin=epoch_tmin, tmax=epoch_tmax, preload=True)
                if len(epochs_main) == 0:
                    continue

                epochs_clean = _autoreject_or_original(epochs_main)
                if len(epochs_clean) == 0:
                    continue

                syll_label = epochs_main.events[0][2]
                out_dir = op.join(evoked_output, cleaning, condition)
                os.makedirs(out_dir, exist_ok=True)
                fname = f"{sub}_{run}_{syllables[trigger_index]}_{syll_label}-epo.fif"
                # overwrite=True keeps "Restart & Run All" re-runnable.
                epochs_clean.save(op.join(out_dir, fname), overwrite=True)


