# Realigning the epochs of two EEGLAB `Epochs` objects

@Author [@FranckPrts](https://github.com/FranckPrts).

Here we provide a Python-based solution to **realign the concurrent epochs contained in two epoched objects whose epochs no longer line up because concurrence was lost during preprocessing.**

After preprocessing in EEGLAB, the data were saved in `.set` format. When this data is loaded again later, the epochs are renumbered consecutively, so their original IDs are no longer visible.

Our main goal is to remove, from each EEG, the epochs that should have been removed during preprocessing because the concurrent epoch in the other EEG was rejected. At the end, we should have two epoched EEGs with the same number of epochs, where the epochs found at a given index were recorded concurrently.

As exemplified below: 

#IMAGE

Things that might be idiosyncratic:

1. Our EEG data was segmented into 1-second epochs (which helps with congruency between epochs and the time they were collected).
2. Preprocessing was done independently for each EEG recording.
3. Each dyad was preprocessed 2-3 times (iterations).
    - The IDs of the rejected epochs were noted in a separate file between each round.
    - The objects were saved at the end of each iteration and re-read afterwards.

Our issue arises from the fact that each time the data was saved after a preprocessing step, we lost track of the original epoch IDs.

In that process we see that an epoch that originally had the ID #6 can end up with the new ID #3. 
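To make this renumbering concrete, here is a minimal sketch that replays two rejection rounds on a list of original IDs. The helper name `apply_rejection_round`, the 1-based numbering, and the rejected IDs are illustrative assumptions, not values taken from our data.

```python
def apply_rejection_round(original_ids, rejected_current_ids):
    """Drop the epochs whose *current* 1-based ID is in `rejected_current_ids`.

    `original_ids[k]` holds the original ID of the epoch currently numbered k + 1.
    """
    rejected = set(rejected_current_ids)
    return [
        orig
        for current, orig in enumerate(original_ids, start=1)
        if current not in rejected
    ]


original_ids = list(range(1, 9))  # epochs 1..8 as recorded (hypothetical)
after_round_1 = apply_rejection_round(original_ids, [2, 4])
after_round_2 = apply_rejection_round(after_round_1, [1])

# The ID an epoch gets on re-reading is its position in the surviving list
new_id_of_epoch_6 = after_round_2.index(6) + 1
print(after_round_2, new_id_of_epoch_6)
```

In this hypothetical run, the epoch recorded as #6 is read back as #3 after the second round.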

To retrieve the original ID of each epoch, we have to work backward from the last iteration of preprocessing to the first. At each step, we store the previous ID of each epoch so that we can eventually recover the original IDs.

## Imports

In [1]:
# Package 
import mne
import numpy as np
import pandas as pd

# Custom functions
from utils import align_utils

# %matplotlib inline

We import two EEG streams that were preprocessed in MATLAB.

In [2]:
files_to_process = np.loadtxt("files_to_process.csv",
                 delimiter=",", dtype=str)

dyad = [x for x in files_to_process]
# Careful: each row of files_to_process is ordered as (dyad_nb, eeg_filepath_child, eeg_filepath_adult)
dy = dyad[0]
data_path = '../FINS-data/'

In [3]:
eeg1 = align_utils.EpochsEEGLAB_to_mneEpochsFIF('{}{}_{}_FP/{}'.format(data_path, dy[0], 'child', dy[1])) 
eeg2 = align_utils.EpochsEEGLAB_to_mneEpochsFIF('{}{}_{}_FP/{}'.format(data_path, dy[0], 'adult', dy[2]))

Extracting parameters from /Users/zoubou/Documents/Work/NYU/Brito-Lab/FINS-Codes/../FINS-data/220_child_FP/FINS_220_Child_FreePlay_xchan_rej3.set...
Not setting metadata
159 matching events found
No baseline correction applied
0 projection items activated
Ready.
Reading /var/folders/vv/stc9rswn5c95vxdzpx7z6qqr0000gn/T/tmpvagbk3f8tmp.fif ...
    Found the data of interest:
        t =       0.00 ...     998.00 ms
        0 CTF compensation matrices available
0 bad epochs dropped
Not setting metadata
159 matching events found
No baseline correction applied
0 projection items activated
Extracting parameters from /Users/zoubou/Documents/Work/NYU/Brito-Lab/FINS-Codes/../FINS-data/220_adult_FP/FINS_220_Adult_FreePlay_xchan_ica_rej3.set...
Not setting metadata
206 matching events found
No baseline correction applied
0 projection items activated
Ready.


  tmp = mne.io.read_epochs_eeglab(path)
  tmp.save(tmpdir+"tmp.fif", overwrite=True, verbose=None)
  return mne.read_epochs(tmpdir+"tmp.fif")
  tmp = mne.io.read_epochs_eeglab(path)
  tmp.save(tmpdir+"tmp.fif", overwrite=True, verbose=None)


Reading /var/folders/vv/stc9rswn5c95vxdzpx7z6qqr0000gn/T/tmpdxl72z3vtmp.fif ...
    Found the data of interest:
        t =       0.00 ...     998.00 ms
        0 CTF compensation matrices available
0 bad epochs dropped
Not setting metadata
206 matching events found
No baseline correction applied
0 projection items activated


  return mne.read_epochs(tmpdir+"tmp.fif")
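The path construction used in the loading cell can be factored into a small helper. This is only a sketch: `build_set_paths` is a hypothetical name, and the row layout `(dyad_nb, child_file, adult_file)` and the `<dyad>_<role>_FP/` folder pattern are taken from the cells and log output above.

```python
def build_set_paths(row, data_path='../FINS-data/'):
    """Return the child and adult `.set` paths for one row of files_to_process.

    Assumes `row` is ordered as (dyad_nb, eeg_filepath_child, eeg_filepath_adult).
    """
    dyad_nb, child_file, adult_file = row
    child_path = '{}{}_{}_FP/{}'.format(data_path, dyad_nb, 'child', child_file)
    adult_path = '{}{}_{}_FP/{}'.format(data_path, dyad_nb, 'adult', adult_file)
    return child_path, adult_path


# Row rebuilt from the file names shown in the log output
row = ('220',
       'FINS_220_Child_FreePlay_xchan_rej3.set',
       'FINS_220_Adult_FreePlay_xchan_ica_rej3.set')
child_path, adult_path = build_set_paths(row)
print(child_path)
print(adult_path)
```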


Let's see how many epochs we have per EEG file:

In [4]:
print('EEG-1 has {} epochs.'.format(eeg1.get_data().shape[0]))
print('EEG-2 has {} epochs.'.format(eeg2.get_data().shape[0]))

EEG-1 has 159 epochs.
EEG-2 has 206 epochs.


The two files should contain the same number of epochs, but they do not. Moreover, the epoch indexes (see the x-axis of the plots below) are continuous in each file, so they give no indication of which epochs were rejected:

In [5]:
# eeg1.plot()

In [6]:
# eeg2.plot()

### What's the plan now?

When a file in the EEGLAB format is loaded, its epochs are numbered consecutively. Suppose your preprocessed file contains the following epoch indexes:

`1, 2, 3, 4, 5, 6, 7`

And you know that the following epochs (in the original numbering) were rejected:

`3, 7, 8`

The loaded indexes show no gaps where those epochs used to be. We can reconstruct the original epoch indexes as follows (the index in the preprocessed file is given in parentheses):

`1(1), 2(2), NaN, 4(3), 5(4), 6(5), NaN, NaN, 9(6), 10(7)`


> **Careful: we have multiple rounds of rejection, so this method has to be repeated for each round.**
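For a single rejection round, the reconstruction above can be sketched as follows. The function name `reconstruct_single_round` is hypothetical, and `None` stands in for the `NaN` entries. The sketch assumes 1-based IDs and that the rejected IDs are given in the numbering in effect before that round.

```python
def reconstruct_single_round(n_remaining, rejected_ids):
    """Map each pre-rejection ID to its post-rejection ID (None if rejected)."""
    rejected = set(rejected_ids)
    n_before = n_remaining + len(rejected)  # epoch count before the round
    mapping = {}
    next_new_id = 1
    for old_id in range(1, n_before + 1):
        if old_id in rejected:
            mapping[old_id] = None
        else:
            mapping[old_id] = next_new_id
            next_new_id += 1
    return mapping


mapping = reconstruct_single_round(n_remaining=7, rejected_ids=[3, 7, 8])
print(', '.join(
    'NaN' if new is None else '{}({})'.format(old, new)
    for old, new in mapping.items()
))
```

With seven remaining epochs and rejections `3, 7, 8`, this prints the same sequence as the hand-written reconstruction.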

In [7]:
df1 = eeg1.to_data_frame()
df2 = eeg2.to_data_frame()

In [8]:
Eeg1epochsIDs = df1.epoch.unique()
Eeg2epochsIDs = df2.epoch.unique()
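A quick way to confirm that the loaded IDs carry no trace of the rejections is to test whether they form an unbroken run. The helper below is a hypothetical sketch, shown on plain sequences rather than on `Eeg1epochsIDs`.

```python
def is_consecutive(ids):
    """Return True if `ids` is an unbroken run of integers (e.g. 0, 1, 2, ...)."""
    ids = [int(i) for i in ids]
    return all(b - a == 1 for a, b in zip(ids, ids[1:]))


# Numbering as seen after re-reading a file with 159 epochs
continuous = is_consecutive(range(159))
# Numbering that would reveal where epochs were rejected
with_gaps = is_consecutive([0, 2, 4, 5, 6])
print(continuous, with_gaps)
```

Applied to `Eeg1epochsIDs` and `Eeg2epochsIDs`, we would expect the check to return `True`, given the continuous x-axis noted earlier.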

## Make an example

In [9]:
df = pd.DataFrame({'Letters': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'], 'Indexes': [0, 1, 2, 3, 4, 5, 6, 7]})

# pd.set_option('display.max_rows', len(state_dict))
df

| | Letters | Indexes |
|---|---|---|
| 0 | A | 0 |
| 1 | B | 1 |
| 2 | C | 2 |
| 3 | D | 3 |
| 4 | E | 4 |
| 5 | F | 5 |
| 6 | G | 6 |
| 7 | H | 7 |


In [10]:
# First round: remove two rows.
# List of the indexes to remove
rmed_1 = [1, 3]

# Boolean mask marking the rows to remove
mask = df['Indexes'].isin(rmed_1)

# Drop the rows flagged by the mask
df.drop(index=df[mask].index, inplace=True)

# Reset the indexes, mimicking what happens when the saved 'eeg' file is re-read for the next round
df.Indexes = list(range(len(df)))
df

| | Letters | Indexes |
|---|---|---|
| 0 | A | 0 |
| 2 | C | 1 |
| 4 | E | 2 |
| 5 | F | 3 |
| 6 | G | 4 |
| 7 | H | 5 |


In [11]:
# Second round: remove three rows and directly reset the indexes.
# List of the indexes to remove
rmed_2 = [2, 5, 0]

# Boolean mask marking the rows to remove
mask = df['Indexes'].isin(rmed_2)

# Drop the rows flagged by the mask
df.drop(index=df[mask].index, inplace=True)

df.Indexes = list(range(len(df)))

df

| | Letters | Indexes |
|---|---|---|
| 2 | C | 0 |
| 5 | F | 1 |
| 6 | G | 2 |


All right, we now have two lists containing the indexes that were removed, **expressed in the numbering in effect at the time of each rejection round**.

Keep in mind that the same index (say #4) can be rejected in several rounds, because #4 is reassigned to a different epoch each time the file is re-read.
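The point about reused indexes can be checked on the letter example. This sketch replays rejection rounds on a plain list and records which letter each rejected index referred to. `replay_rounds` is a hypothetical helper, not part of `align_utils`.

```python
def replay_rounds(items, rejected_per_round):
    """Replay rejection rounds; return the survivors and the items dropped per round."""
    dropped = []
    for rejected in rejected_per_round:
        rejected = set(rejected)
        dropped.append([item for idx, item in enumerate(items) if idx in rejected])
        # Survivors are renumbered from 0 by their position in the new list
        items = [item for idx, item in enumerate(items) if idx not in rejected]
    return items, dropped


letters = list('ABCDEFGH')

# The two rounds used in this notebook
survivors, dropped = replay_rounds(letters, [[1, 3], [2, 5, 0]])
print(survivors, dropped)

# Rejecting index 4 in two successive rounds removes two different letters
_, same_index_dropped = replay_rounds(letters, [[4], [4]])
print(same_index_dropped)
```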

In [12]:
final_idx = df['Indexes'].tolist()
print('Indexes that were rejected at the\n\t1st round: {}\n\t2nd round: {}'.format(rmed_1, rmed_2))
print('The indexes as they are after the last rejection round: {}'.format(final_idx))

Indexes that were rejected at the
	1st round: [1, 3]
	2nd round: [2, 5, 0]
The indexes as they are after the last rejection round: [0, 1, 2]


Define the rejected epochs as a `list` of `list`s, each inner list containing the IDs of the epochs rejected in one round of preprocessing. **The first inner `list` should contain the epoch IDs rejected in the first preprocessing round, and the last one those rejected in the last round.**

In [13]:
# Define the list of epochs that were rejected
rmed_list=[rmed_1, rmed_2]
rmed_list

[[1, 3], [2, 5, 0]]

In [14]:
# def create_initial_state_dict (final_idx=list):
#     state_dict = {}
#     for i in final_idx:
#         state_dict[i] = i
#     return state_dict

In [15]:
# def take_a_step_back (state_dict=dict, rm_idx=list, is_first_round=bool) :
    
#     # If this is the first round (i.e., the last rejection round)
#     if is_first_round: 
#         # Order the keys of the state_dict so we can iterate over them     
#         existing_states = []
#         for key in state_dict.keys():
#             existing_states.append(key)
#         existing_states.sort()

#         # Add placeholders in the state_dict for the new state
#         # added by introducing the removed states
#         for new_key in range(len(rm_idx)):
#             state_dict[existing_states[-1]+new_key+1] = existing_states[-1]+new_key+1

#         # Order the keys of the state_dict so we can iterate over them 
#         existing_states = []
#         for key in state_dict.keys():
#             existing_states.append(key)
#         existing_states.sort()
#     else:
#         existing_states = []
#         for key in state_dict.keys():
#             existing_states.append(key)
#         existing_states.sort()

#         # Find what is the highest values of a state that exist,
#         # we'll need that to define the values of the states we're
#         # adding correctly (e.g. if we have [0(0), 1(NaN), 2(1)], 
#         # we're adding the state #3 with value '2': [0(0), 1(NaN), 2(1), 3(2)])
#         tmp = [state_dict[i] for i in state_dict.keys() if state_dict[i] != 'NaN']
#         tmp.sort()
#         highest_state_val = tmp[-1]
#         del tmp

#         # Add placeholders in the state_dict for the new state
#         # added by introducing the removed states
#         for new_key in range(len(rm_idx)):
#             state_dict[existing_states[-1]+new_key+1] = highest_state_val+new_key+1

#         existing_states = []
#         for key in state_dict.keys():
#             existing_states.append(key)
#         existing_states.sort()

#     # Sort the states that were removed so we make sure we start by the 
#     # lowest idx to reiterate shifting idx correctly
#     rm_idx.sort()

#     # For each index that was removed,
#     for rmed in rm_idx:
#         # Check all existing idx, and if the index
#         # already exist, update it such that ...
#         for existing in existing_states:
#             # ... only the indexes that would be shifted by introducing 
#             # a NaN see their 'new' index reduced by 1.
#             # We subtract 1 because, from the perspective of the initial df,
#             # the index of a given state lost a rank when a state that
#             # preceded it was removed.
#             if existing > rmed:
#                 if state_dict[existing] == 'NaN':
#                     pass
#                 else:
#                     state_dict[existing] -= 1
#         state_dict[rmed] = 'NaN'

#     return state_dict

In [16]:
# def revert_to_original_idx (final_idx=list, rmed_list=list, verbose=True):
    
#     # First initialise the state dict containing the true idx as keys 
#     # and their corresponding epoch (here the letters)
#     state_dict = create_initial_state_dict(final_idx)

#     if verbose:
#         print('Initial state:')
#         for i in state_dict.keys():
#             print('\t',i, state_dict[i])

#     # Reverse "final_idx" so we can begin with the last round 
#     final_idx.reverse()

#     # Initialize the boolean "first_round_bool" to be true so the 
#     # function knows it has to take the special step it needs to 
#     # for 
#     first_round_bool = True
    
#     for rm_idx in rmed_list:

#         updated_state_dict = take_a_step_back(
#             state_dict=state_dict, 
#             rm_idx=rm_idx, 
#             is_first_round=first_round_bool)
        
#         if verbose:
#             print('Updated state:')
#             for i in updated_state_dict.keys():
#                 print('\t',i, updated_state_dict[i])
            
#         # Not the first round anymore
#         first_round_bool = False



In [17]:
updated_state_dict = align_utils.revert_to_original_idx(
    last_state = final_idx,
    removed_list  = rmed_list,
    verbose    = True
)

Initial state:
	 0 0
	 1 1
	 2 2
Updated state:
	 0 NaN
	 1 0
	 2 NaN
	 3 1
	 4 2
	 5 NaN
Updated state:
	 0 NaN
	 2 0
	 4 NaN
	 5 1
	 6 2
	 7 NaN
	 1 NaN
	 3 NaN
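For readers without access to `align_utils`, the backward walk can be sketched in a few lines of plain Python. This is not the implementation behind `align_utils.revert_to_original_idx`: `revert_ids` is a hypothetical stand-in that assumes 0-based IDs, uses `None` in place of `NaN`, and expects the rejection lists ordered from the first round to the last.

```python
def revert_ids(n_final, rejected_per_round):
    """Map every original epoch ID to its final ID (None if it was rejected)."""
    # traced[k]: ID, in the numbering currently being undone, of the epoch
    # that ends up with final ID k
    traced = list(range(n_final))
    n_epochs = n_final
    for rejected in reversed(rejected_per_round):  # last round first
        rejected = set(rejected)
        n_epochs += len(rejected)  # epoch count before this round
        # IDs (before this round) of the epochs that survived it, in order
        kept = [i for i in range(n_epochs) if i not in rejected]
        traced = [kept[i] for i in traced]
    mapping = {orig: None for orig in range(n_epochs)}
    for final_id, orig in enumerate(traced):
        mapping[orig] = final_id
    return mapping


mapping = revert_ids(n_final=3, rejected_per_round=[[1, 3], [2, 5, 0]])
print(mapping)
```

On the letter example, the sketch keeps original IDs 2, 5 and 6 and marks all the others as rejected.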


Let's verify that we correctly reconstructed the correspondence between the IDs and their original positions.

In [18]:
letters=['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']

tt = sorted(updated_state_dict.keys())

print('\tOriginal ID\tFinal state\tLetter',)
for i in tt:
    print('\t',i,'\t\t',updated_state_dict[i],'\t\t', letters[i])
    


	Original ID	Final state	Letter
	 0 		 NaN 		 A
	 1 		 NaN 		 B
	 2 		 0 		 C
	 3 		 NaN 		 D
	 4 		 NaN 		 E
	 5 		 1 		 F
	 6 		 2 		 G
	 7 		 NaN 		 H


We now have a dictionary whose `keys` are the original epoch IDs and whose `values` give the state of each epoch at the end of preprocessing.

The **state** can be either
- `NaN`, which indicates that the epoch was removed during preprocessing, **or**
- the ID that the epoch carries at the end of preprocessing, if it was kept.

Now that we have this for a subject EEG 

## Useful references

- To get comfortable with the MNE documentation, you should know that MNE is based on Python [Object Oriented Programming (OOP)](https://realpython.com/python3-object-oriented-programming/). These objects are defined from a Python `Class`.
    - You can get familiar with the OOP structure and its components, e.g. `methods` (functions associated with the object) and `attributes` (variables associated with the object), with [this tutorial](https://www.datacamp.com/tutorial/python-oop-tutorial).
    - In MNE, we find [`Raw` objects](https://mne.tools/stable/generated/mne.io.Raw.html) (continuous data) or [`Epochs` objects](https://mne.tools/stable/generated/mne.Epochs.html) (a collection of epochs).

You can find an introduction to the **Epochs data structure** in MNE [here](https://mne.tools/stable/auto_tutorials/epochs/10_epochs_overview.html).

### Extracting the epoch data

We're now going to extract the epoch data from the MNE `EpochsFIF` objects to apply the operation described above.
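Once an original-ID dictionary exists for each member of a dyad, the epochs to keep are those whose original ID survived in both recordings. The sketch below works on plain dictionaries. `common_epochs` and the two toy mappings are hypothetical, with `None` standing in for `NaN`.

```python
def common_epochs(mapping_1, mapping_2):
    """Return the shared original IDs and the final IDs to keep in each EEG."""
    shared = sorted(
        orig for orig in mapping_1
        if mapping_1[orig] is not None and mapping_2.get(orig) is not None
    )
    keep_1 = [mapping_1[orig] for orig in shared]
    keep_2 = [mapping_2[orig] for orig in shared]
    return shared, keep_1, keep_2


# Toy mappings: original ID -> final ID (None = rejected during preprocessing)
child = {0: None, 1: 0, 2: 1, 3: None, 4: 2, 5: 3}
adult = {0: 0, 1: 1, 2: None, 3: 2, 4: 3, 5: None}

shared, keep_child, keep_adult = common_epochs(child, adult)
print(shared, keep_child, keep_adult)
```

The lists `keep_child` and `keep_adult` could then be used to select the matching epochs from each object, for instance by integer indexing of the MNE epochs, assuming the final IDs match the positions of the epochs in the loaded files.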