In [5]:
import numpy as np
import mne
from sklearn.neighbors import NearestNeighbors

In [131]:
class DifferentLengthsError(Exception):
    """Raised when a trigger map and the events it should relabel differ in length.

    Attributes:
        trigger_map_len -- number of entries in the trigger map
        events_len -- number of events (annotation descriptions) found
        message -- human-readable description of the problem
    """

    def __init__(self, trigger_map_len, events_len, message="Trigger map and events have different lengths."):
        # Keep both counts so that __str__ can report them.
        self.trigger_map_len, self.events_len = trigger_map_len, events_len
        self.message = message
        super().__init__(message)

    def __str__(self):
        details = "Trigger map length: {}, Events length: {}".format(self.trigger_map_len, self.events_len)
        return " ".join((self.message, details))


def replace_trigger_names(raw, trigger_map=None):
    """Relabel every annotation of ``raw`` with the matching entry of ``trigger_map``.

    The i-th annotation description is replaced by ``trigger_map[i]``. ``raw`` is
    modified in place and also returned for convenience.

    :param raw: mne Raw (any object exposing ``annotations.description``).
    :param trigger_map: sequence of new descriptions, one per annotation. When
        None, ``raw`` is returned unchanged.
    :return: the same ``raw`` object.
    :raises DifferentLengthsError: if ``trigger_map`` and the annotations differ
        in length.
    """
    if trigger_map is None:
        return raw

    events = raw.annotations.description
    if len(events) != len(trigger_map):
        raise DifferentLengthsError(len(trigger_map), len(events))

    # MNE stores ``description`` as a fixed-width numpy str array, so assigning a
    # longer label element by element would silently truncate it. Replace the
    # whole array instead; numpy sizes the new dtype to the longest label.
    raw.annotations.description = np.array([str(name) for name in trigger_map], dtype=str)
    return raw


def create_epochs(
        raw,
        tmin=-.1,
        tmax=.6,
        events_to_select=None,  # response_event_dict
        new_events_dict=None,  # new_response_event_dict
        events_mapping=None,  # events_mapping
        reject=None,
        reject_by_annotation=False
):
    """Cut ``raw`` into preloaded epochs around annotation-derived events.

    :param raw: mne Raw whose annotations define the events.
    :param tmin: epoch start in seconds relative to each event.
    :param tmax: epoch end in seconds relative to each event.
    :param events_to_select: ``{description: id}`` passed to
        ``mne.events_from_annotations``; None selects every annotation.
    :param new_events_dict: ``{name: new_id}`` used as the epochs' ``event_id``.
        When None, the ids returned by ``mne.events_from_annotations`` are used.
    :param events_mapping: ``{name: [old ids]}``. Each list of old ids is merged
        into ``new_events_dict[name]``. None or empty means no merging.
    :param reject: peak-to-peak rejection dict forwarded to ``mne.Epochs``.
    :param reject_by_annotation: forwarded to ``mne.Epochs``.
    :return: mne Epochs (preloaded, no baseline correction).
    :raises ValueError: if ``events_mapping`` is given without ``new_events_dict``.
    """
    events_mapping = events_mapping or {}
    if events_mapping and new_events_dict is None:
        raise ValueError("events_mapping requires new_events_dict to provide the new ids.")

    # Select the requested annotation events.
    events, event_ids = mne.events_from_annotations(raw, event_id=events_to_select)

    # Collapse several original trigger codes into one condition code.
    for name, old_ids in events_mapping.items():
        events = mne.merge_events(
            events=events,
            ids=old_ids,
            new_id=new_events_dict[name],
            replace_events=True,
        )

    # Without an explicit relabelling, keep the names found in the annotations.
    event_id = new_events_dict if new_events_dict is not None else event_ids

    epochs = mne.Epochs(
        raw=raw,
        events=events,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        reject_by_annotation=reject_by_annotation,
        preload=True,
        reject=reject,
    )

    return epochs


def ocular_correction_ica(epochs, raw, veog=None, heog=None, reject=None):
    """
    Remove ocular components from ``epochs`` using ICA.

    An infomax ICA (36 components, ``random_state=97``) is fitted on epochs cut
    from a 1 Hz high-pass filtered copy of ``raw``. These epochs span -0.4 to
    0.8 s around every annotation event. ``ICA.find_bads_eog`` is then run with
    correlation thresholds of 0.9 against the VEOG channels and 0.8 against the
    HEOG channels. The components it matches are removed from a copy of
    ``epochs``.

    :param epochs: mne Epochs to clean; must contain the EOG channels named below.
    :param raw: mne Raw the ICA is trained on; it is not modified.
    :param veog: vertical EOG channel name(s); None/empty skips VEOG matching.
    :param heog: horizontal EOG channel name(s); None/empty skips HEOG matching.
    :param reject: peak-to-peak rejection dict for the ICA training epochs
        (default ``{'mag': 4e-12}``). Keys for channel types that are absent
        from ``raw`` are dropped.
    :return: cleaned copy of ``epochs``.
    """
    # Filtering needs loaded data; work on a copy so the caller's raw is untouched.
    filtered_raw = raw.copy().load_data().filter(l_freq=1.0, h_freq=None)

    if reject is None:
        reject = dict(mag=4e-12)
    # mne.Epochs refuses rejection keys for channel types missing from the data
    # (e.g. 'mag' on an EEG-only recording), so keep only the types present.
    present_types = set(filtered_raw.get_channel_types())
    reject = {ch_type: thr for ch_type, thr in reject.items() if ch_type in present_types} or None

    # Peak-to-peak amplitude rejection excludes very noisy epochs from the fit.
    epochs_ica = create_epochs(
        filtered_raw,
        tmin=-.4,
        tmax=.8,
        events_mapping={},  # no merging: every annotation event as-is
        reject=reject,
    )

    ica = mne.preprocessing.ICA(
        n_components=36,
        method='infomax',
        max_iter="auto",
        random_state=97
    )
    ica.fit(epochs_ica)

    # Find the ICs whose time course matches the VEOG / HEOG channels.
    exclude = []
    for eog_channels, threshold in ((veog, 0.9), (heog, 0.8)):
        if not eog_channels:
            continue  # nothing to compare against
        indices, _scores = ica.find_bads_eog(
            epochs,
            ch_name=eog_channels,
            threshold=threshold,
            measure='correlation'
        )
        exclude.extend(indices)

    # One component may match both VEOG and HEOG; list it only once.
    ica.exclude = sorted(set(exclude))

    reconstructed_epochs = epochs.copy()
    ica.apply(reconstructed_epochs)

    del filtered_raw, epochs_ica

    return reconstructed_epochs


def get_k_nearest_neighbors(target_ch_name, epochs, k=6):
    """
    Find the k channels closest to ``target_ch_name`` in 3D sensor space.

    Positions are read from ``epochs.info['chs'][i]['loc'][:3]``, and distances
    are Euclidean. The target channel is excluded by index, so other channels
    that share its exact position remain eligible. Channels without a finite
    position are never returned.

    :param target_ch_name: String
        Name of the target channel.
    :param epochs: mne Epochs
        Epochs whose info attribute holds the channel names and positions.
    :param k: int
        Number of neighbors to return. Fewer are returned when there are not
        enough other channels with valid positions.
    :return:
        indices: ndarray of shape (n_neighbors)
            Indices of the nearest channels, nearest first.
        neighbor_ch_names: ndarray of shape (n_neighbors)
            Names of the nearest channels, in the same order.
    :raises ValueError: if the target channel is missing or has no finite position.
    """
    info = epochs.info
    ch_names = list(info['ch_names'])
    try:
        target_idx = ch_names.index(target_ch_name)
    except ValueError:
        raise ValueError(f"Channel {target_ch_name!r} not found in epochs.info['ch_names'].") from None

    positions = np.array(
        [np.asarray(info['chs'][i]['loc'][:3], dtype=float) for i in range(len(ch_names))]
    ).reshape(len(ch_names), 3)
    target_position = positions[target_idx]
    if not np.all(np.isfinite(target_position)):
        raise ValueError(f"Channel {target_ch_name!r} has no finite 3D position.")

    distances = np.linalg.norm(positions - target_position, axis=1)
    # Exclude channels without a usable position, and the target itself.
    distances[~np.isfinite(distances)] = np.inf
    distances[target_idx] = np.inf

    order = np.argsort(distances, kind='stable')
    order = order[np.isfinite(distances[order])][:max(int(k), 0)]

    return order, np.array([ch_names[i] for i in order], dtype=str)


def find_bad_trails(epochs, reject_threshold=100e-6, flat_threshold=1e-6,
                    neighbor_threshold=30e-6, n_neighbors=6):
    """
    Mark channels as bad per trial (epoch).

    A channel is marked as bad for a trial when any of these holds:
        (1) its peak-to-peak EEG amplitude exceeds ``reject_threshold`` (100 μV);
        (2) it is flat, i.e. its peak-to-peak EEG amplitude is below
            ``flat_threshold`` (1 μV);
        (3) its time-averaged amplitude differs by more than ``neighbor_threshold``
            (30 μV) from the time-averaged amplitude of every one of its
            ``n_neighbors`` (6) nearest neighbours.
    Criteria (1) and (2) are evaluated with ``drop_bad`` on a copy, so ``epochs``
    itself is not modified.

    :param epochs: mne Epochs (preloaded)
        Epochs to find bad channels per trial.
    :param reject_threshold: float, volts
    :param flat_threshold: float, volts
    :param neighbor_threshold: float, volts
    :param n_neighbors: int
    :return: drop_log: tuple of tuples
        Same layout as ``epochs.drop_log`` (one entry per original event). The
        entry of each inspected trial additionally lists its bad channel names.
    """
    epochs_copy = epochs.copy()
    epochs_copy.drop_bad(reject=dict(eeg=reject_threshold), flat=dict(eeg=flat_threshold))
    # Use mutable entries so criterion (3) can append without rebuilding tuples.
    drop_log = [list(entry) for entry in epochs_copy.drop_log]
    del epochs_copy

    ch_names = epochs.info['ch_names']
    # Neighbourhoods depend only on sensor geometry: compute them once, not per trial.
    neighbor_indices = [
        np.asarray(
            get_k_nearest_neighbors(target_ch_name=ch_name, epochs=epochs, k=n_neighbors)[0],
            dtype=int,
        )
        for ch_name in ch_names
    ]
    # (n_epochs, n_channels): mean amplitude of every channel in every trial.
    channel_means = epochs.get_data(copy=True).mean(axis=-1)

    # drop_log is indexed by original event; epochs.selection[i] is the drop_log
    # position of the i-th epoch (different from i once any event was dropped).
    for epoch_idx, log_idx in enumerate(epochs.selection):
        trial_means = channel_means[epoch_idx]
        for ch_idx, ch_name in enumerate(ch_names):
            neighbors = neighbor_indices[ch_idx]
            if neighbors.size == 0:
                continue  # no neighbours: criterion (3) cannot be evaluated
            mean_channel_data = trial_means[ch_idx]
            mean_neighbors_data = trial_means[neighbors]

            # channel differs by more than the threshold from all of its neighbours
            if (np.abs(mean_neighbors_data - mean_channel_data) > neighbor_threshold).all():
                print(f'BAD------ trail index {epoch_idx}, channel: {ch_name}')
                print(mean_channel_data)
                print(mean_neighbors_data)

                if ch_name not in drop_log[log_idx]:
                    drop_log[log_idx].append(ch_name)

    return tuple(tuple(entry) for entry in drop_log)

def calculate_percentage(tuple_of_tuples, element):
    """Return the fraction (0..1) of inner collections that contain ``element``.

    Despite its name, the result is a fraction rather than a value in percent.
    An empty outer collection yields 0, which avoids a division by zero.
    """
    if not len(tuple_of_tuples):
        return 0
    hits = [element in group for group in tuple_of_tuples]
    return sum(hits) / len(hits)

def find_global_bad_channels(epochs, drop_log, corr_threshold=.4, n_neighbors=6):
    '''
    Flag channels that are bad across the whole recording.

    A channel is globally bad when
        (1) its absolute correlation with each of its ``n_neighbors`` nearest
            neighbouring channels is below ``corr_threshold`` (.4). The
            correlation is computed over all epochs concatenated in time. A
            constant channel, whose correlation is undefined, counts as
            uncorrelated;
        (2) it is marked as bad for more than 20% of epochs. Only the
            ``drop_log`` entries of epochs present in ``epochs``
            (``epochs.selection``) count towards the fraction.
    Every globally bad channel is added to every ``drop_log`` entry.

    :param epochs: mne Epochs (preloaded)
    :param drop_log: tuple of tuples aligned with ``epochs.drop_log`` (e.g. the
        output of ``find_bad_trails``)
    :param corr_threshold: float
    :param n_neighbors: int
    :return: (drop_log, global_bad_channels_drop_log). The dict maps a channel
        name to its reasons ('LOW CORR NEIGH', 'BAD FOR MORE THAN 20%').
    '''
    # Count only real trials for criterion (2). Entries for events that never
    # became epochs (e.g. 'IGNORED') would otherwise dilute the fraction.
    if len(drop_log) == len(epochs.drop_log):
        trial_positions = list(epochs.selection)
    else:
        # NOTE(review): drop_log is not aligned with epochs; fall back to all entries.
        trial_positions = list(range(len(drop_log)))

    drop_log = [list(entry) for entry in drop_log]
    ch_names = epochs.info['ch_names']
    global_bad_channels_drop_log = {}

    def _mark_everywhere(ch_name):
        # Add a globally bad channel to every drop_log entry (once).
        for entry in drop_log:
            if ch_name not in entry:
                entry.append(ch_name)

    # (n_channels, n_epochs * n_times): all trials back to back for each channel.
    data = epochs.get_data(copy=True)
    n_epochs, n_channels, n_times = data.shape
    continuous = data.transpose(1, 0, 2).reshape(n_channels, n_epochs * n_times)

    # (1) Channels with an absolute correlation with the nearest six neighboring channels that fell below .4
    if continuous.shape[1] >= 2:
        for ch_idx, ch_name in enumerate(ch_names):
            neighbors = np.asarray(
                get_k_nearest_neighbors(target_ch_name=ch_name, epochs=epochs, k=n_neighbors)[0],
                dtype=int,
            )
            if neighbors.size == 0:
                continue
            # The last row of the correlation matrix is this channel vs. each neighbour.
            with np.errstate(invalid='ignore', divide='ignore'):
                corr = np.corrcoef(continuous[neighbors], continuous[ch_idx])[-1, :-1]
            corr = np.nan_to_num(corr, nan=0.0)  # constant signal -> treat as uncorrelated

            if (np.abs(corr) < corr_threshold).all():
                print(f'BAD------ global channel: {ch_name}, low corr')
                print(corr)
                global_bad_channels_drop_log[ch_name] = ['LOW CORR NEIGH']
                _mark_everywhere(ch_name)

    # (2) Channels that were marked as bad for more than 20% of epochs
    n_trials = len(trial_positions)
    if n_trials:
        for ch_name in ch_names:
            percentage = sum(ch_name in drop_log[i] for i in trial_positions) / n_trials
            if percentage > 0.2:
                print(f'BAD------ global channel: {ch_name}, percentage: {percentage}')
                global_bad_channels_drop_log.setdefault(ch_name, []).append('BAD FOR MORE THAN 20%')
                _mark_everywhere(ch_name)

    return tuple(tuple(entry) for entry in drop_log), global_bad_channels_drop_log

def get_bad_trials(drop_log, ch_names=None, max_bad_fraction=0.1):
    """
    Find trials in which more than ``max_bad_fraction`` (10%) of channels are bad.

    Such trials are meant to be rejected entirely.

    :param drop_log: tuple of tuples
        Per-event drop log (e.g. from ``find_bad_trails``). Entries may mix
        channel names with other reasons such as 'IGNORED'.
    :param ch_names: sequence of str or None
        All channel names. Only these names count as bad channels. When None or
        empty, no fraction can be computed and an empty list is returned.
    :param max_bad_fraction: float
        A trial is selected when its fraction of bad channels is strictly above
        this value.
    :return: trials_to_drop: list of int
        Positions in ``drop_log`` of the trials to reject. Callers working with
        mne Epochs should keep only positions present in ``epochs.selection``.
    """
    if ch_names is None or len(ch_names) == 0:
        return []

    channel_set = set(ch_names)
    n_channels = len(channel_set)
    trials_to_drop = []
    for position, entry in enumerate(drop_log):
        # Ignore non-channel reasons and count each channel once.
        n_bad = len(channel_set.intersection(entry))
        if n_bad / n_channels > max_bad_fraction:
            trials_to_drop.append(position)
    return trials_to_drop

def interpolate_bad_channels(epochs, drop_log, global_bad_channels_drop_log):
    """
    Placeholder for interpolating bad channels with spherical splines.

    Not implemented yet: ``epochs`` and ``drop_log`` are returned untouched, and
    ``global_bad_channels_drop_log`` is ignored.

    :param epochs: epochs whose bad channels should eventually be interpolated
    :param drop_log: per-trial drop log
    :param global_bad_channels_drop_log: mapping of globally bad channels
    :return: tuple (epochs, drop_log)
    """
    # TODO: interpolate the bad channels and update drop_log accordingly.
    unchanged = (epochs, drop_log)
    return unchanged
    
# def pre_process_eeg(input_fname, trigger_map=None, parameters=None):
#     raw = mne.io.read_raw_bdf(input_fname, eog=['EX7', 'EX8'])
# 
#     # 1. re-reference: to mastoids
#     raw.set_eeg_reference(ref_channels=['M1', 'M2'])
# 
#     # 2. 4-th order Butterworth filters
#     raw_filtered = raw.copy().filter(
#         l_freq=.1,
#         h_freq=30.0,
#         n_jobs=10,
#         method='iir',
#         iir_params=None,
#     )
# 
#     # 3. Notch filter at 50 Hz
#     raw_filtered = raw_filtered.filter(
#         freqs=np.arange(50, 251, 50),
#         n_jobs=10,
#         method='iir',
#         iir_params=None,
#     )
# 
#     # 4. segmentation -400 to 800 ms around the response
#     epochs = create_epochs(raw_filtered, tmin=-.4, tmax=.8)
# 
#     # 5. ocular artifact correction with ICA
#     refined_epochs = ocular_correction_ica(epochs, raw, heog=['EX3', 'EX4'], veog=['EX1', 'EX2'])
# 
#     # 6. Trial-wise Bad Channels Identification
#     drop_log = find_bad_trails(epochs) 
#
#     # 7. Global Bad Channel Identification - <.4 corr with 6 neigh. and channels marked as bad for more than 20% trials
#     drop_log, global_bad_channels_drop_log = find_global_bad_channels(epochs, drop_log)
#
#     # 8. Drop bad trials
#     get_bad_trials(drop_log) -> if more than 10% of channels were marked bad for an epoch, the entire epoch was rejected todo
#    
#     # 9. Interpolate bad channels (and thus update drop log)
#     interpolate_bad_channels(epochs, ) -> changed epochs and changes updated_drop_log todo
# 

In [123]:
# Annotation descriptions to select (presumably the response events — confirm),
# each mapped to a temporary integer code for mne.events_from_annotations.
response_event_dict = {
    'Stimulus/RE*ex*1_n*1_c_1*R*FB': 10003,
    'Stimulus/RE*ex*1_n*1_c_1*R*FG': 10004,
    'Stimulus/RE*ex*1_n*1_c_2*R': 10005,
    'Stimulus/RE*ex*1_n*2_c_1*R': 10006,
    'Stimulus/RE*ex*2_n*1_c_1*R': 10007,
    'Stimulus/RE*ex*2_n*2_c_1*R*FB': 10008,
    'Stimulus/RE*ex*2_n*2_c_1*R*FG': 10009,
    'Stimulus/RE*ex*2_n*2_c_2*R': 10010,
}

# Final condition names and the event codes they get in the epochs.
new_response_event_dict = {"correct_response": 0, "error_response": 1}

# Temporary codes merged into each final condition; the keys must also be keys
# of new_response_event_dict (create_epochs looks them up there).
events_mapping = {
    'correct_response': [10003, 10004, 10008, 10009],
    'error_response': [10005, 10006, 10007, 10010],
}

In [124]:
# Load the BrainVision recording through its .vhdr header file.
raw = mne.io.read_raw_brainvision('data/GNG_AA0303-64 el.vhdr')

Extracting parameters from data/GNG_AA0303-64 el.vhdr...
Setting channel info structure...


In [125]:
# Epochs from -100 ms to 600 ms around the selected events, with the temporary
# trigger codes merged into the correct/error conditions defined above.
# Amplitude rejection and annotation-based rejection are disabled
# (create_epochs defaults: reject=None, reject_by_annotation=False).
epochs = create_epochs(
    raw,
    tmin=-.1,
    tmax=.6,
    events_to_select=response_event_dict,
    new_events_dict=new_response_event_dict,
    events_mapping=events_mapping,
)

Used Annotations descriptions: ['Stimulus/RE*ex*1_n*1_c_1*R*FB', 'Stimulus/RE*ex*1_n*1_c_1*R*FG', 'Stimulus/RE*ex*1_n*1_c_2*R', 'Stimulus/RE*ex*1_n*2_c_1*R', 'Stimulus/RE*ex*2_n*1_c_1*R', 'Stimulus/RE*ex*2_n*2_c_1*R*FB', 'Stimulus/RE*ex*2_n*2_c_1*R*FG', 'Stimulus/RE*ex*2_n*2_c_2*R']
Not setting metadata
311 matching events found
No baseline correction applied
0 projection items activated
Loading data for 311 events and 181 original time points ...
1 bad epochs dropped


In [126]:
# """
#     Channels that meet the following conditions are marked as bad for the trial:
#         (1) Channels with a peak-to-peak voltage difference of more than 100 μV over the duration of the epoch;
#         (2) Channels that were flat;
#         (3) Channels with more than a 30 μV difference from each of the nearest six neighbors;
#     :param epochs: epochs with global bad channels already set (channels with < .4 correlation with their neighbors)
#     :return: epochs with channels marked as bad
#     """
# 
# epochs_copy = epochs.copy()
# drop_log = None
# 
# # channels with a voltage difference of 100 μV through the duration of the epoch
# reject_criteria = dict(eeg=100e-6)
# # flat channels (less than 1 µV of peak-to-peak difference)
# flat_criteria = dict(eeg=1e-6)
# 
# epochs_copy.drop_bad(reject=reject_criteria, flat=flat_criteria)
# drop_log = epochs_copy.drop_log # drop_log - because we will only reject the epochs in which more than 10% of the channels are bad
# 
# for idx, _ in enumerate(epochs_copy):
#     epoch = epochs[idx]
#     epoch_data = epoch.get_data(copy=True)
#     for ch_name, ch_idx in zip(epoch.info['ch_names'], np.arange(0, len(epochs_copy.info['ch_names']))):
#         mean_channel_data = np.array(np.mean(epoch_data[0,ch_idx,:]))    
# 
#         ch_neighbors_indices, ch_neighbors_names = get_k_nearest_neighbors(
#             target_ch_name = ch_name, 
#             epochs = epoch, 
#             k=6
#         )
# 
#         mean_neighbors_data = np.array([np.mean(epoch_data[0, ch_neighbor_index, :]) for ch_neighbor_index in ch_neighbors_indices])
# 
#         # # if channels has more than a 30 μV difference with the nearest six neighbors
#         if (abs(mean_neighbors_data - mean_channel_data) > 30e-6).all():
#             print(f'BAD------ trail index {idx}, channel: {ch_name}')
#             print(mean_channel_data)
#             print(mean_neighbors_data)
# 
#             new_drop_log_item = drop_log[idx] + (ch_name, ) if ch_name not in drop_log[idx] else drop_log[idx] 
#             drop_log = tuple(new_drop_log_item if i == idx else item for i, item in enumerate(drop_log))

    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['F8']
    Rejecting  epoch based on EEG : ['P9']
    Rejecting  epoch based on EEG : ['F4', 'PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['PO4']
    Rejecting  epoch based on EEG : ['F4']
    Rejecting  epoch based on EEG : ['PO7']
    Rejecting  epoch based on EEG :

In [128]:
# epochs_copy = epochs.copy()
# epochs_data = epochs_copy.get_data(copy=True)
# print(epochs_data.shape)
# concatenated_epochs_data = np.concatenate(epochs_data, axis=1)
# 
# # deliberately corrupt channel 0 for testing
# concatenated_epochs_data[0] = np.random.normal(0, 5, concatenated_epochs_data[0].shape)
# 
# global_bad_channels_drop_log = {}
# 
# # (1) Channels with an absolute correlation with the nearest six neighboring channels that fell below .4
# for ch_name, ch_idx in zip(epochs.info['ch_names'], np.arange(0, len(epochs.info['ch_names']))):
#     channel_data = concatenated_epochs_data[ch_idx]
# 
#     ch_neighbors_indices, ch_neighbors_names = get_k_nearest_neighbors(
#         target_ch_name = ch_name,
#         epochs = epochs.copy(),
#         k=6
#     )
#     channel_neighbors_data = np.array([concatenated_epochs_data[ch_neighbor_index]
#                                        for ch_neighbor_index in ch_neighbors_indices])
#     channels_corr = np.tril(np.corrcoef(channel_neighbors_data, channel_data), k=-1)
# 
#     if (abs(channels_corr[-1][:-1]) < .4).all():
#         print(f'BAD------ global channel: {ch_name}, low corr')
#         print(channels_corr[-1][:-1])
# 
#         # mark channel as globally bad
#         global_bad_channels_drop_log[ch_name] = ['LOW CORR NEIGH']
#         # update drop_log
#         drop_log = tuple(drop_log[i] + (ch_name,)
#                          if ch_name not in drop_log[i] else drop_log[i] for i, item in enumerate(drop_log))
# 
# # (2) Channels that were marked as bad for more than 20% of epochs   
# for ch_name, ch_idx in zip(epochs_copy.info['ch_names'], np.arange(0, len(epochs_copy.info['ch_names']))):
#     percentage = calculate_percentage(drop_log, ch_name)
# 
#     if percentage > 0.2:
#         print(f'BAD------ global channel: {ch_name}, percentage: {percentage}')
# 
#         if ch_name in global_bad_channels_drop_log:
#             global_bad_channels_drop_log[ch_name].append('BAD FOR MORE THAN 20%')
#         else:
#             global_bad_channels_drop_log[ch_name] = ['BAD FOR MORE THAN 20%']
# 
#         # update drop_log
#         drop_log = tuple(drop_log[i] + (ch_name,)
#                          if ch_name not in drop_log[i] else drop_log[i] for i, item in enumerate(drop_log))

(310, 64, 181)
BAD------ global channel: Fp1, low corr
[ 0.00038896 -0.00124657 -0.0016246   0.00173066  0.00396127 -0.00232104]
BAD------ global channel: Fp1, percentage: 1.0
