In [None]:
import numpy as np
import matplotlib
import pathlib
import mne
import matplotlib.pyplot as plt

In [None]:
# Uncomment the line below to open plots in a separate Qt window
# (useful for interactive inspection).
# matplotlib.use('Qt5Agg')

In [None]:
def read_fif(name, data_dir='OpenMIIR-RawEEG_v1'):
  """
   Load one subject's raw recording with MNE.

   name: subject identifier used in the file name, e.g. 'P01'
         (the file read is '<data_dir>/<name>-raw.fif')
   data_dir: directory that holds the '-raw.fif' files; defaults to the
             'OpenMIIR-RawEEG_v1' folder relative to the working directory
   return: the object returned by mne.io.read_raw_fif, with preload=True
  """
  # Build the path with pathlib so the separator is right on every platform
  # and the data folder can be overridden by the caller.
  fname = pathlib.Path(data_dir) / (name + '-raw.fif')
  raw = mne.io.read_raw_fif(fname, preload=True)
  return raw

In [None]:
# we write our own read fif function by using mne
raw = read_fif('P01')
raw

## Step 1: data pre-processing

### Learn our datasets

#### We will use P01 subject to show how we do the data pre-processing part

In [None]:
# check raw data plot 
fig = raw.plot()
plt.show()


In [None]:
# want to know how many events do we have
events = mne.find_events(raw) 

In [None]:
# show events graph
# we want to know each events happen times
fig = mne.viz.plot_events(events)

### Visualize the sensor locations

In [None]:
# we want to know which channel being used
# it can show which channels are bad channels too
fig = raw.plot_sensors(show_names=True)

### Filter

In [None]:
# The data was then filtered with a bandpass keeping a frequency range between 0.5 and 30 Hz. 
# This also removed any slow signal drift in the EEG
filt_raw = raw.copy().filter(l_freq=0.5, h_freq=30)

In [None]:
# plot after filter
fig = filt_raw.plot(events=events)

### Repairing artifacts with ICA

Reference:
Repairing artifacts with ICA:
https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html

In [None]:
# Fit an ICA with 20 components on the band-pass filtered recording.
# random_state is fixed, presumably so that re-running the cell yields the
# same decomposition (the component indices excluded later depend on it).
ica = mne.preprocessing.ICA(n_components=20, max_iter='auto', random_state=97)
ica.fit(filt_raw)
# Show the fitted ICA object as the cell output.
ica

In [None]:
# Now we can examine the ICs to see what they captured. 
# plot_sources will show the time series of the ICs.
filt_raw.load_data()
# This can help us determine which ICs we need to drop
fig = ica.plot_sources(filt_raw, show_scrollbars=False)

EEG Independent Component Labeling:
https://labeling.ucsd.edu/tutorial/labels

In [None]:
# We can also visualize the scalp field distribution of each component using plot_components.
fig = ica.plot_components()

In [None]:
# we can plot an overlay of the original signal against 
# the reconstructed signal with the artifactual ICs excluded

# eye blinks
fig1 = ica.plot_overlay(filt_raw, exclude=[0], picks='eeg')
# heartbeats
fig2 = ica.plot_overlay(filt_raw, exclude=[1], picks='eeg')

In [None]:
ica.exclude = [0, 1]  # indices chosen based on various plots above

In [None]:
# Score every ICA component against the recording's EOG channel(s) and show
# the scores as a bar plot.
# NOTE(review): eog_indices is not used in the cells shown in this notebook;
# ica.exclude is set by hand to [0, 1] instead - confirm the two selections agree.
eog_indices, eog_scores = ica.find_bads_eog(raw)
fig = ica.plot_scores(eog_scores)

In [None]:
# Compare the original (unfiltered) recording before and after ICA artifact removal.

# ica.apply() modifies the Raw object in place, so work on a copy and leave
# `raw` untouched:
reconst_raw = raw.copy()
ica.apply(reconst_raw)

raw.plot()
reconst_raw.plot()
# The copy is only needed for the two plots above; delete it afterwards.
del reconst_raw

### Evoked responses: epoching and averaging

In [None]:
# Cut the filtered continuous recording into epochs around every event found
# earlier with mne.find_events; preload=True loads the epoch data up front.
epochs = mne.Epochs(filt_raw, events, preload=True)
# Apply the fitted ICA to the epochs, excluding the components listed in
# ica.exclude, and rebind `epochs` to the result.
epochs = ica.apply(epochs, exclude=ica.exclude)
# Show the epochs summary as the cell output.
epochs

In [None]:
# plot the epochs graph 
fig = epochs.plot(events=events)

### Figure out which channel is good to use

In [None]:
# let’s look at our evoked responses for some conditions we care about. 
# pick any events we want ex. events 11
events_11 = epochs['11'].average()
# l_vis = epochs[2].average()

In [None]:
# we can check each channel graph for this events
# This can help us pick the best channel we want to do training
fig = events_11.plot(spatial_colors=True)

In [None]:
# Scalp topographies can also be obtained non-interactively with the plot_topomap method
fig = events_11.plot_joint()

We think one channel is not enough, so we may need to combine all 61 channels. However, each subject may have different bad channels, and we want to use the same channels for every subject.
Therefore, we use MNE to determine the bad channels of all subjects. Here are the results:
| Subjects      | Bad Channels |
| ----------- | ----------- |
| P01      | P8, P10, T8      |
| P04   | T8        |
| P06      | Iz, FT4       |
| P07   | None        |
| P09      | None       |
| P11   | T7, T8        |
| P12      | C3, PO3       |
| P13   | Iz        |
| P14   | T7, F7        |

According to the table above, we drop the following bad channels: [C3, F7, FT4, Iz, P8, P10, PO3, T7, T8]

In [None]:
# Channels excluded for every subject: the channels flagged as bad for at
# least one subject (see the table above), plus the EXG1-EXG6 and 'STI 014'
# channels (presumably non-EEG channels - verify against the recording info).
bad_channel_list = [
    'C3', 'F7', 'FT4', 'Iz', 'P8', 'P10', 'PO3', 'T7', 'T8',
    'EXG1', 'EXG2', 'EXG3', 'EXG4', 'EXG5', 'EXG6',
    'STI 014',
]
all_channel_list = raw.ch_names
# Keep every remaining channel, in the order it appears in the recording.
good_channel = list(filter(lambda ch: ch not in bad_channel_list, all_channel_list))
print('Channels length: ' + str(len(good_channel)))
good_channel

In [None]:
# after we pick the channel
# check more specific info for this channel during events <11>
# for example, we pick channel Pz
epochs['11'].plot_image(picks=['Pz'])

## Step 2: Get the Data We Need

#### We will use P01 subject to show how we get data for each subject

### First, we write some helper functions to process the data

In [None]:
def drop_useless_events(events):
  """
    Keep only the event ids we want to predict.

    events: iterable of event ids given as strings (iterating a dict such as
            epochs.event_id yields its keys)
    1. drop non-music events, i.e. ids whose integer value is >= 1000
       (ex: 1000, 1111, 2001)
    2. drop condition 4, i.e. ids whose last character is "4"
    return: list of the remaining ids, in their original order
  """
  # pass 1: keep ids below 1000
  music_events = []
  for event in events:
    if int(event) < 1000:
      music_events.append(event)

  # pass 2: discard ids that end with "4"
  kept = []
  for event in music_events:
    if event.endswith("4"):
      continue
    kept.append(event)
  return kept

In [None]:
def avg_events(epochs, channels, events):
  """
   Build one averaged time series per event.

   For every event id in `events`:
     1. average all epochs of that event (epochs[event].average()),
     2. keep only the channels named in `channels`,
     3. average the remaining data over its first axis
        (presumably the channel axis - verify against MNE's get_data()).

   epochs: epochs object that can be indexed by event id
   channels: list of channel names to keep
   events: iterable of event ids
   return: dict mapping each event id to its averaged data
  """
  source = epochs.copy()

  def _channel_mean(event):
    # average the repetitions first, then restrict to the wanted channels
    evoked = source[event].average()
    picked = evoked.copy().pick_channels(channels)
    return picked.get_data().mean(axis=0)

  return {event: _channel_mean(event) for event in events}

In [None]:
# which events do we need to predict?
event_id = drop_useless_events(epochs.event_id)

### Save result to events_dic_avg

In [None]:
%%capture 
# hide output
events_dic_avg = avg_events(epochs, good_channel, event_id)

Each event has 359 points. The result is a dict: key = event ID, value = array of 359 points.

In [None]:
# for example, event 11 with condition 2
# print first ten
print(events_dic_avg['112'][:10])

### Minimize the number of points
Reduce the number of points in each event from 359 to 60 by averaging consecutive groups of points.

In [None]:
def minimize_event(events_dic, n_points=60):
    """
    Down-sample every event's series to `n_points` values.

    Each series is split into `n_points` consecutive, nearly equal groups
    (np.array_split) and every group is replaced by its mean.

    events_dic: dict mapping event id -> sequence of points
    n_points: number of points to keep per event (default 60)
    return: dict mapping event id -> list of `n_points` group means
    raises: ValueError if a series has fewer than `n_points` points, because
            some groups would then be empty and have no mean
    """
    result = {}
    # Iterate over the argument, not the notebook-global `events_dic_avg`;
    # otherwise every caller gets the first subject's data back.
    for key, points in events_dic.items():
        points = np.asarray(points)
        if len(points) < n_points:
            raise ValueError(
                "event %r has %d points, cannot reduce to %d"
                % (key, len(points), n_points)
            )
        groups = np.array_split(points, n_points)
        # axis=0 matches the original sum(group) / len(group)
        result[key] = [group.mean(axis=0) for group in groups]
    return result

In [None]:
events_dic_avg_min = minimize_event(events_dic_avg)

Each event now has 60 points. The result is a dict: key = event ID, value = list of 60 points.

In [None]:
print(events_dic_avg_min['11'])
print(len(events_dic_avg_min['112']))

## Step 3: Get data for multiple subjects

Having only one subject's data is clearly not enough for training data. After we get the data of the first subject, we need to repeat the previous steps to get the data of several more subjects to help us complete the subsequent training data. 

According to the article "Toward Studying Music Cognition with Information Retrieval Techniques: Lessons Learned from the OpenMIIR Initiative", "one participant was excluded for the experiments described in this paper because of a considerable number of trials with movement artifacts due to coughing."\
We checked the data plot of each subject and concluded that subject P05 should be dropped from our experiments.

In [None]:
# P05_raw = read_fif('P05')
# P05_raw.plot()
# del P05_raw

We will have 3 training subjects and 6 test subjects (the cell below currently uses a smaller subset; the full lists are kept as comments).

In [None]:
# train_subjects = ['P01', 'P04', 'P06'] 
train_subjects = ['P01', 'P04'] 
# test_subjects = ['P07', 'P09', 'P11', 'P12', 'P13', 'P14']
test_subjects = ['P07']

We integrate the steps of getting the data into one function called whole_process.

In [None]:
def whole_process(name, chs, ica_exclude=None):
    """
    Run the full single-subject pipeline and return the reduced event data.

    Steps: read the recording, find its events, band-pass filter it
    (0.5-30 Hz), fit an ICA and remove the excluded components from the
    epochs, drop the events we do not predict, average each event over its
    epochs and over the channels in `chs`, then reduce every event's series
    with minimize_event.

    name: subject identifier passed to read_fif, e.g. 'P04'
    chs: list of channel names to average over
    ica_exclude: indices of the ICA components to remove; defaults to [0, 1]
                 NOTE(review): these defaults were chosen from another
                 subject's plots - confirm they suit each subject.
    return: whatever minimize_event returns for this subject's averaged events
    """
    if ica_exclude is None:
        ica_exclude = [0, 1]

    # 1. Get dataset
    raw = read_fif(name)

    # 2. Get events of the dataset (before the recording is filtered in place)
    events = mne.find_events(raw)

    # 3. Filter. `raw` is local to this function, so filter it in place
    #    instead of holding a second full copy of the recording in memory.
    filt_raw = raw.filter(l_freq=0.5, h_freq=30)

    # 4. Repairing artifacts with ICA
    ica = mne.preprocessing.ICA(n_components=20, max_iter='auto', random_state=97)
    ica.fit(filt_raw)
    ica.exclude = list(ica_exclude)

    # 5. Epoch the continuous data and remove the excluded ICA components
    epochs = mne.Epochs(filt_raw, events, preload=True)
    epochs = ica.apply(epochs, exclude=ica.exclude)

    # 6. Get the data we need
    # Drop useless non-music events like 1000, 1111 and events with condition 4
    event_id = drop_useless_events(epochs.event_id)
    # Average every remaining event over its epochs and over the channels in chs
    events_dic_avg = avg_events(epochs, chs, event_id)
    # Reduce the number of points of every event
    return minimize_event(events_dic_avg)

In [None]:
events_dic_avg_04 = whole_process('P04', good_channel)

In [None]:
events_dic_avg_06 = whole_process('P06',good_channel)

In [None]:
print(events_dic_avg_04['11'])

In [None]:
print(events_dic_avg_06['11'])

In [None]:
# %%capture 
# # hide output
# train_dict = {}
# test_dict = {}
# # first deal with train dict
# for t in train_subjects:
#   # we already have P01 data skip it
#   if t == 'P01':
#     train_dict[t] = events_dic_avg_min
#   else:
#     train_dict[t] = whole_process(t, good_channel)

# # then deal with test dict
# for s in test_subjects:
#   # we already have P01 data skip it
#   test_dict[s] = whole_process(s, good_channel)


In [None]:
# print(train_dict.keys())
# print(test_dict.keys())

In [None]:
# fig, ax = plt.subplots(figsize=(12, 6))
# x = np.arange(0, 60, 1)
# y = train_dict['P04']['11']
# z = test_dict['P07']['11']

# ax.plot(y, color='blue', label='Train P04 event 11')
# ax.plot(z, color='black', label='Test P07 event 11')

# plt.xlim([25, 50])
# plt.show()

## Delete after we finish part 2

In [None]:
# import json

# with open('train.json', 'w') as fp:
#     json.dump(train_dict, fp)

# with open('test.json', 'w') as fp:
#     json.dump(test_dict, fp)