In [None]:
import numpy as np
import plotly.graph_objects as go
import mne
import configparser
from main_meg_qc import sanity_check
from initial_meg_qc import get_all_config_params, sanity_check, initial_processing


In [None]:
# Load settings.ini and one MEG recording, then crop a copy of it to a short window for quick inspection.
config = configparser.ConfigParser()
config.read('settings.ini')
# NOTE(review): absolute local path - only valid on the author's machine; consider a configurable data directory.
data_file='/Volumes/M2_DATA/MEG_QC_stuff/data/from openneuro/ds003483/sub-009/ses-1/meg/sub-009_ses-1_task-deduction_run-1_meg.fif'
#data_file='/Volumes/M2_DATA/MEG_QC_stuff/data/from lab/forrest_gump_meg/en04ns31_vp15/190524/vp15_block1-1.fif'
raw = mne.io.read_raw_fif(data_file)
# Crop a copy so that `raw` keeps the full recording.
raw_cropped = raw.copy()
# Crop window in seconds.
tmin_my_plot=100
tmax_my_plot=360
duration_my_plot=tmax_my_plot-tmin_my_plot
raw_cropped.crop(tmin=tmin_my_plot, tmax=tmax_my_plot)
# NOTE(review): the next cell reassigns both `raw` and `raw_cropped` via initial_processing(),
# so this manual crop only takes effect if that cell is skipped.
#raw_cropped.copy().pick_types(meg=False, stim=False,ecg=True).plot()


#raw_cropped.copy().pick_types(meg=False, stim=False,eog=True).plot()

# all_qc_params = get_all_config_params('settings.ini')
# dict_of_dfs_epoch, epochs_mg, channels, raw_filtered, raw_filtered_resampled, raw_cropped, raw, active_shielding_used = initial_processing(default_settings=all_qc_params['default'], filtering_settings=all_qc_params['Filtering'], epoching_params=all_qc_params['Epoching'], data_file=data_file)
# in forrest gump data epoching gives RuntimeError: Event time samples were not unique. Consider setting the `event_repeated` parameter.



In [None]:
# Run the project's initial processing with the parameters from settings.ini.
# This (re)defines raw, raw_cropped, channels, epochs_mg, ... used by the cells below,
# replacing the `raw` / `raw_cropped` objects created in the previous cell.
all_qc_params = get_all_config_params('settings.ini')
dict_of_dfs_epoch, epochs_mg, channels, raw_filtered, raw_filtered_resampled, raw_cropped, raw, active_shielding_used = initial_processing(default_settings=all_qc_params['default'], filtering_settings=all_qc_params['Filtering'], epoching_params=all_qc_params['Epoching'], data_file=data_file)

In [None]:
# Quick check of the type of the measurement-info object.
type(raw.info)


In [None]:
#do peak detection on ecg channel to find if channel is too noisy:
ecg_ch_data=raw_cropped.get_data(picks='ecg')[0]
#print(ecg_ch_data)

#peaks_ecg_channel_loc,  peaks_ecg_channel_mag = mne.preprocessing.peak_finder(ecg_ch[0], thresh=1e-4)

sfreq=raw_cropped.info['sfreq']
thresh_lvl = 1.2

# A peak must stand out by more than (peak-to-peak range of the channel) / thresh_lvl to be counted.
thresh=(max(ecg_ch_data) - min(ecg_ch_data)) / thresh_lvl
pos_peak_locs, pos_peak_magnitudes = mne.preprocessing.peak_finder(ecg_ch_data, extrema=1, thresh=thresh, verbose=False) #positive peaks
neg_peak_locs, neg_peak_magnitudes = mne.preprocessing.peak_finder(ecg_ch_data, extrema=-1, thresh=thresh, verbose=False) #negative peaks

# Peaks with peak_finder's default threshold, used below to look for long gaps between beats.
normal_pos_peak_locs, normal_pos_peak_magnitudes = mne.preprocessing.peak_finder(ecg_ch_data, extrema=1, verbose=False) #positive peaks

# Longest allowed interval between two consecutive beats: 60/35 s, i.e. a heart rate of 35 bpm.
max_pair_dist_sec=60/35

ind_break_start = np.where(np.diff(normal_pos_peak_locs)/sfreq>max_pair_dist_sec)

# Duration of the data actually analysed, derived from the samples. raw_cropped is (re)created by
# initial_processing(), so it is not necessarily tmax_my_plot - tmin_my_plot seconds long.
duration_crop = len(ecg_ch_data) / sfreq

# One time point per sample, so t[peak_locs] always lines up with ecg_ch_data.
t = raw_cropped.times
# fig = go.Figure()
# fig.add_trace(go.Scatter(x=t, y=ecg_ch_data, name='data'));
# fig.add_trace(go.Scatter(x=t[pos_peak_locs], y=pos_peak_magnitudes, mode='markers', name='+peak'));
# fig.add_trace(go.Scatter(x=t[neg_peak_locs], y=neg_peak_magnitudes, mode='markers', name='-peak'));

# for n in ind_break_start[0]:
#     fig.add_vline(x=t[normal_pos_peak_locs][n],
#               annotation_text='break', annotation_position="bottom right",line_width=0.6,annotation=dict(font_size=8))

# fig.update_layout(
#     title={
#     'text': "First magnetometer: peaks detected",
#     'y':0.85,
#     'x':0.5,
#     'xanchor': 'center',
#     'yanchor': 'top'},
#     xaxis_title="Time in seconds",
#     yaxis = dict(
#         showexponent = 'all',
#         exponentformat = 'e'))
# fig.show()

print(duration_crop)
all_peaks=np.concatenate((pos_peak_locs,neg_peak_locs),axis=None)

# Allow up to `allowed_peaks_per_min` unusual peaks per minute of data (value still to be DISCUSSED).
# The limit scales with the data length: rate * minutes. The previous expression 3/duration_crop*60
# divided by the duration instead, so for several minutes of data almost any peak triggered it.
# Caveat: this only catches regularly repeating noise; one short noisy stretch in otherwise clean
# data will stay under the limit. Consider also limiting the duration of noisy periods.
allowed_peaks_per_min = 3
max_allowed_peaks = allowed_peaks_per_min * duration_crop / 60
if len(all_peaks) > max_allowed_peaks:
    print('ECG channel is too noisy. \nUnusual peaks in ECG amplitudes detected over the set limit: '+str(len (all_peaks))+'. Peaks per minute: '+str(len(all_peaks)/duration_crop*60))



In [None]:
#run cell above as a func
# Same noisy-ECG check as the previous cell, via the project's helper functions in initial_meg_qc.

from initial_meg_qc import detect_noisy_ecg_eog, detect_extra_channels

# Names of the ECG and EOG channels present in raw_cropped (as returned by the project helper).
ECG_channel_name, EOG_channel_name=detect_extra_channels(raw_cropped)

bad_ecg = detect_noisy_ecg_eog(raw_cropped, picked_channels_ecg_or_eog = ECG_channel_name,  thresh_lvl=1.2, plotflag=False)


In [None]:
# Finding channels most affected by ecg artifacts:  either calc area under the curve or peak height.

class Mean_artif_peak_on_channel:
    """Mean artifact epoch of a single channel together with its most prominent peak.

    Parameters
    ----------
    channel : name of the channel.
    mean_artifact_epoch : averaged artifact epoch of this channel.
    peak_loc : sample index of the peak inside the averaged epoch.
    peak_magnitude : magnitude of that peak.
    artif_over_threshold : whether the peak has been flagged as exceeding the threshold (default False).
    """

    def __init__(self, channel, mean_artifact_epoch, peak_loc, peak_magnitude, artif_over_threshold=False):
        # Every constructor argument is stored unchanged under an attribute of the same name.
        (self.channel, self.mean_artifact_epoch, self.peak_loc,
         self.peak_magnitude, self.artif_over_threshold) = (
            channel, mean_artifact_epoch, peak_loc, peak_magnitude, artif_over_threshold)

    def __repr__(self):
        # Multi-line human-readable summary, terminated by a newline.
        return (f'Mean artifact peak on channel: {self.channel!s}'
                f'\n - peak location inside artifact epoch: {self.peak_loc!s}'
                f'\n - peak magnitude: {self.peak_magnitude!s}'
                f'\n - artifact magnitude over threshold: {self.artif_over_threshold!s}\n')


# Magnetometer / gradiometer names (None when the data set has no channel of that type).
mag_ch_names = raw.copy().pick_types(meg='mag').ch_names if 'mag' in raw else None
grad_ch_names = raw.copy().pick_types(meg='grad').ch_names if 'grad' in raw else None
channels = {'mags': mag_ch_names, 'grads': grad_ch_names}

# Epochs from -0.1 s to 0.1 s around every ECG event MNE finds in raw_cropped.
ecg_epochs = mne.preprocessing.create_ecg_epochs(raw_cropped, tmin=-0.1, tmax=0.1)

#averaging the ECG epochs together (magnetometers only):
avg_ecg_epochs = ecg_epochs.average(picks=channels['mags'])#.apply_baseline((-0.5, -0.2))
#avg_ecg_epochs is evoked: Evoked objects typically store EEG or MEG signals that have been averaged over multiple epochs.
#The data in an Evoked object are stored in an array of shape (n_channels, n_times)

# Time axis of the averaged epoch in seconds, one value per sample.
# (np.arange(0, n_samples, 1/sfreq) generated n_samples*sfreq points on a meaningless scale.)
t = avg_ecg_epochs.times

# Relative peak threshold per channel: (max - min of |mean epoch|) / thresh_lvl_mean.
thresh_lvl_mean = 1.3

ecg_peaks_on_channels=[]
avg_ecg_epoch_data_all=avg_ecg_epochs.data
# Pair each data row with the channel name stored in the Evoked itself, so a name can never be
# attached to another channel's data if the row order differs from channels['mags'].
for ch, avg_ecg_epoch_data in zip(avg_ecg_epochs.ch_names, avg_ecg_epoch_data_all):
    thresh_mean=(max(abs(avg_ecg_epoch_data)) - min(abs(avg_ecg_epoch_data))) / thresh_lvl_mean

    #HERE INSTEAD OF RELATIVE THRESHOLD FOR EACH CHANNEL DECIDE HOW HIGH SHOULD THE ECG PEAK BE TO SAY THAT THERE IS ARTIFACT

    mean_peak_locs, mean_peak_magnitudes = mne.preprocessing.peak_finder(abs(avg_ecg_epoch_data), extrema=1, verbose=False, thresh=thresh_mean)

    # Keep only the biggest peak of this channel's mean artifact.
    biggest_peak_ind=np.argmax(mean_peak_magnitudes)
    ecg_peaks_on_channels.append(Mean_artif_peak_on_channel(channel=ch, mean_artifact_epoch=avg_ecg_epoch_data, peak_loc=mean_peak_locs[biggest_peak_ind], peak_magnitude=mean_peak_magnitudes[biggest_peak_ind], artif_over_threshold=False))

    # To inspect a channel:
    # fig.add_trace(go.Scatter(x=t, y=abs(avg_ecg_epoch_data), name=ch))
    # fig.add_trace(go.Scatter(x=[t[mean_peak_locs[biggest_peak_ind]]], y=[mean_peak_magnitudes[biggest_peak_ind]], mode='markers', name='+peak'))
#print(ecg_peaks_on_channels)


In [None]:

#Continue: find mean ECG magnitude over all channels:
mean_ecg_magnitude = np.mean([potentially_affected_channel.peak_magnitude for potentially_affected_channel in ecg_peaks_on_channels])

# A channel counts as affected when its biggest mean-artifact peak is above the average over all channels.
affected_channels=[]
for potentially_affected_channel in ecg_peaks_on_channels:
    if potentially_affected_channel.peak_magnitude>mean_ecg_magnitude:
        potentially_affected_channel.artif_over_threshold=True
        affected_channels.append(potentially_affected_channel)


#Plot each affected channel:
fig = go.Figure()

# Use the averaged epoch's own time axis, which includes both ends (-0.1 s and 0.1 s).
# np.arange(-0.1, 0.1, 1/sfreq) can be one sample shorter than the epoch, so t[peak_loc]
# could go out of range for a peak on the last sample.
t = avg_ecg_epochs.times
fig.add_trace(go.Scatter(x=t, y=[(mean_ecg_magnitude)]*len(t), name='mean ECG magnitude'))

for ch in affected_channels:
    fig.add_trace(go.Scatter(x=t, y=abs(ch.mean_artifact_epoch), name=ch.channel))
    fig.add_trace(go.Scatter(x=[t[ch.peak_loc]], y=[ch.peak_magnitude], mode='markers', name='+peak'))


fig.update_layout(
    xaxis_title='Time in seconds',
    yaxis = dict(
        showexponent = 'all',
        exponentformat = 'e'),
    yaxis_title='Mean artifact magnitude over epochs',
    title={
        'text': 'Channels affected by ECG artifact',
        'y':0.85,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})

fig.show()

In [None]:
#Calling as a function:
# Same affected-channel search as the cells above, via find_ecg_affected_channels() from ECG_meg_qc.

from ECG_meg_qc import find_ecg_affected_channels

# Magnetometer / gradiometer names from `raw` (None when that sensor type is absent).
mag_ch_names = raw.copy().pick_types(meg='mag').ch_names if 'mag' in raw else None
grad_ch_names = raw.copy().pick_types(meg='grad').ch_names if 'grad' in raw else None
channels = {'mags': mag_ch_names, 'grads': grad_ch_names}


# NOTE(review): the meaning of norm_lvl / thresh_lvl_mean is defined in ECG_meg_qc, not visible here.
ecg_affected_channels, all_figs=find_ecg_affected_channels(raw_cropped, channels, m_or_g_chosen=['mags'], norm_lvl=2, thresh_lvl_mean=1.3, tmin=-0.1, tmax=0.1, plotflag=True)


In [None]:
# Find ecg affected epochs: for each ECG epoch, average the absolute signal over all picked channels
# at every time point, then plot this per-epoch mean for a few example epochs.

tmin=-0.05
tmax=0.05
m_or_g='mags'
# NOTE(review): unlike the cells above, the epochs are built from the full `raw`, not raw_cropped - confirm intended.
ecg_epochs = mne.preprocessing.create_ecg_epochs(raw, tmin=tmin, tmax=tmax, picks=channels[m_or_g])

df_ecg_epochs = ecg_epochs.to_data_frame()

# Select the channel columns by excluding the metadata columns. The former positional slice
# iloc[:, 3:-1] skipped time/condition/epoch but also silently dropped the LAST channel.
channel_cols = [col for col in df_ecg_epochs.columns if col not in ('time', 'condition', 'epoch')]
df_ecg_epochs['mean'] = df_ecg_epochs[channel_cols].abs().mean(axis=1)

#Now plot the mean column of a few example epochs:

fig = go.Figure()

sfreq=raw.info['sfreq']
# Epoch time axis with exactly one value per sample. np.arange(tmin, tmax+1/sfreq, 1/sfreq)
# can gain or lose a point through floating-point rounding.
t = ecg_epochs.times

# One list of per-sample means for every epoch in the data frame. Group by the 'epoch' column instead of
# assuming labels 0..len(ecg_epochs)-1: presumably the labels follow the kept-epoch selection, which has
# gaps when epochs are dropped (e.g. at the recording edges) - TODO confirm for the MNE version in use.
all_means_of_epochs = []
for ep, df_one_ep in df_ecg_epochs.groupby('epoch', sort=True):
    all_means_of_epochs.append(list(df_one_ep.loc[:, 'mean']))
    if ep in (0, 5, 66, 80):
        fig.add_trace(go.Scatter(x=t, y=all_means_of_epochs[-1], name='epoch '+str(ep)))

fig.update_layout(
    xaxis_title='Time in seconds',
    yaxis = dict(
        showexponent = 'all',
        exponentformat = 'e'),
    yaxis_title='Mean artifact magnitude over channels',
    title={
        'text': "Mean ECG magnitude over channels: one line for each ECG epoch",
        'y':0.85,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.show()



In [None]:
#Calling as a function:
# Same affected-epoch search as the previous cell, via find_ecg_affected_epochs() from ECG_meg_qc.

from ECG_meg_qc import find_ecg_affected_epochs

# Magnetometer / gradiometer names from `raw` (None when that sensor type is absent).
mag_ch_names = raw.copy().pick_types(meg='mag').ch_names if 'mag' in raw else None
grad_ch_names = raw.copy().pick_types(meg='grad').ch_names if 'grad' in raw else None
channels = {'mags': mag_ch_names, 'grads': grad_ch_names}


# NOTE(review): the meaning of norm_lvl / thresh_lvl_mean is defined in ECG_meg_qc, not visible here.
ecg_affected_epochs, all_figs=find_ecg_affected_epochs(raw_cropped, channels, m_or_g_chosen=['mags'], norm_lvl=0.5, thresh_lvl_mean=1.3, tmin=-0.05, tmax=0.05, plotflag=True)

In [None]:
#find the unit of ECG channel.

# Look the channel up by name. The former np.where(name in ch_names) evaluated a single boolean,
# so it always yielded index 0 and reported the unit of the first channel in the recording.
ecg_ch_idx = raw_cropped.ch_names.index(ECG_channel_name[0])
raw_cropped.info['chs'][ecg_ch_idx]['unit']


In [None]:

# Re-run the initial processing so the cells below start from a fresh, consistent set of objects.
all_qc_params = get_all_config_params('settings.ini')
(dict_of_dfs_epoch, epochs_mg, channels, raw_filtered, raw_filtered_resampled,
 raw_cropped, raw, active_shielding_used) = initial_processing(
    default_settings=all_qc_params['default'],
    filtering_settings=all_qc_params['Filtering'],
    epoching_params=all_qc_params['Epoching'],
    data_file=data_file,
)

# Keep only the sensor types (out of mags / grads) that the data set actually contains.
m_or_g_chosen = ['mags', 'grads']
m_or_g_chosen = sanity_check(m_or_g_chosen, channels)

if len(m_or_g_chosen) == 0:
    raise ValueError('No channels to analyze. Check presence of mags and grads in your data set and parameter do_for in settings.')

# Crop limits from the DEFAULT section of settings.ini. An empty value means
# "from the start" for tmin and "until the last sample of raw" for tmax.
default_section = config['DEFAULT']
tmin = default_section['data_crop_tmin']
tmax = default_section['data_crop_tmax']
tmin = float(tmin) if tmin else 0
tmax = float(tmax) if tmax else raw.times[-1]

duration = tmax - tmin
print('Data duration is ', duration, ' seconds')

In [None]:
# Find channel names of the ECG and EOG channels and plot them.

picks_ECG = mne.pick_types(raw_cropped.info, ecg=True)
picks_EOG = mne.pick_types(raw_cropped.info, eog=True)

# Build both name lists independently. Previously both were only built in the `else` branch of the
# EOG check, so a recording with ECG but without EOG never got its ECG channel name.
# Names are looked up in raw_cropped, the same object the picks come from.
ECG_channel_name = [raw_cropped.ch_names[i] for i in picks_ECG]
EOG_channel_name = [raw_cropped.ch_names[i] for i in picks_EOG]

# Only plot channel types that exist: picking a type that is absent makes the plot call fail.
if picks_ECG.size == 0:
    print('No ECG channels found is this data set')
else:
    print('ECG channel: ', ECG_channel_name)
    raw_cropped.copy().pick_types(meg=False, stim=False,ecg=True).plot()

if picks_EOG.size == 0:
    print('No EOG channels found is this data set')
else:
    print('EOG channels: ',EOG_channel_name)
    raw_cropped.copy().pick_types(meg=False, stim=False,eog=True).plot()


In [None]:
#HERE WE DROP THE ECG CHANNEL TO TRY TO RECONSTRUCT ECG EVENTS FROM THE MAGNETOMETERS.
# THE RESULT IS WORSE: IT NOW FINDS 95 INSTEAD OF 62 EVENTS! THIS ALSO QUESTIONS HOW REASONABLE
# IT IS TO DO THE RECONSTRUCTION AT ALL. MAYBE WE SHOULD NOT EVEN OFFER SUCH AN OPTION?
# OR COME UP WITH AN IDEA FOR HOW TO EVALUATE THE RESULT OF THE RECONSTRUCTION.

#raw_cropped.drop_channels(ECG_channel_name)

In [None]:

#Find ECG events by localizing the R wave peaks.
# return_ecg=True additionally returns the ECG trace MNE used (the 4th value, `ecg`).

ecg_events, ch_ecg, average_pulse, ecg=mne.preprocessing.find_ecg_events(raw_cropped, return_ecg=True, verbose=True)
# can explicitly give the ECG channel here but it is not necessary. MNE can detect it automatically.

# ecg_events array
# The events corresponding to the peaks of the R waves.
# ch_ecg - str - Name of channel used.
# NOTE(review): the last print of this cell uses ch_ecg as an integer index into info['chs'],
# which contradicts "str" above - confirm against the installed MNE version.

# average_pulse float
# The estimated average pulse. If no ECG events could be found, this will be zero.

# ecg array | None
# The ECG data of the synthesized ECG channel, if any. This will only be returned if return_ecg=True was passed.

# HOW TO FIND THE TIME OF EVENTS:
# https://github.com/mne-tools/mne-python/issues/5201
# Events were found in raw_cropped, but first_samp of the uncropped `raw` is subtracted, so the
# times are presumably in seconds from the start of `raw` (not of raw_cropped) - TODO confirm.
ecg_events_times  = (ecg_events[:, 0] - raw.first_samp) / raw.info['sfreq']

#https://mne.tools/stable/glossary.html#term-first_samp
#The first_samp attribute of Raw objects is an integer representing the number of time samples that passed between 
# the onset of the hardware acquisition system and the time when data recording started. This approach to sample 
# numbering is a peculiarity of VectorView MEG systems, but for consistency it is present in all Raw objects regardless 
# of the source of the data. In other words, first_samp will be 0 in Raw objects loaded from non-VectorView data files.

# checg=ch_ecg.tolist()
# ECG_channel_name_auto=[]
# for i in range(0,len([checg])):
#     ECG_channel_name_auto.append(raw.info['chs'][checg][i]['ch_name'])

# #print('ECG events: \n', ecg_events) 
# print('Channels used to detect ECG: ', ECG_channel_name_auto) 

print('Average pulse: ', average_pulse) 
print('ECG data of the synthesized ECG channel: \n',ecg)
print('Times of ECG events: \n', ecg_events_times)
print('ECG channel used: ', raw_cropped.info['chs'][ch_ecg]['ch_name'])

In [None]:
# Number of ECG events found by the previous cell.
len(ecg_events)

# THE RECONSTRUCTION OF ECG EVENTS BY MNE FROM THE MAGNETOMETERS MAY GIVE A WRONG RESULT:
# HERE IT CREATED 95 EVENTS, WHILE WITH THE ECG CHANNEL ON THE SAME DATA IT FOUND 62 EVENTS


In [None]:
#Plot the found ECG events on one of the magnetometers:

data_mags=raw_cropped.get_data(picks = channels['mags'])
data_grads=raw_cropped.get_data(picks = channels['grads'])

# Which magnetometer (row of data_mags) to show; any index works.
example_mag_ind = 2

# Time axis with exactly one value per sample (np.arange(tmin, tmax, 1/sfreq) could differ in length
# from the data). It is shifted by tmin, assuming raw_cropped starts tmin seconds into `raw`, the same
# clock as ecg_events_times - TODO confirm against the crop done by initial_processing().
t = tmin + raw_cropped.times

fig = go.Figure()
fig.add_trace(go.Scatter(x=t, y=data_mags[example_mag_ind], name='data'))
# The loop variable is deliberately not called `ecg`: that name holds the ECG signal
# returned by find_ecg_events, which the old loop overwrote.
for ecg_event_time in ecg_events_times:
    fig.add_vline(x=ecg_event_time, line_width=1, line_color="green")

fig.update_layout(
    title={
    'text': "ECG events on data of 1 magnetometer",
    'y':0.85,
    'x':0.5,
    'xanchor': 'center',
    'yanchor': 'top'},
    xaxis_title="Time in sec",
    yaxis = dict(
        showexponent = 'all',
        exponentformat = 'e'))

fig.show()

In [None]:
#THIS FUNCTION CREATES EPOCHS BASED ON ECG EVENTS. 
# It calls find_ecg_events internally, so no need to run previous cell for it.

%matplotlib inline

m_or_g = 'mags'
ecg_epochs = mne.preprocessing.create_ecg_epochs(raw_cropped, tmin=-0.1, tmax=0.1)
# m_or_g[0:-1] turns 'mags' into 'mag'; presumably plot_image returns a list of figures, of which the first is kept.
f=ecg_epochs.plot_image(combine='mean', picks=m_or_g[0:-1])[0]
print(f)

# import mpld3
# mpld3.save_html(f,'myfig.html')

In [None]:


# Joint plot (butterfly + topomaps) of the averaged ECG epochs.
requested_times = [-0.25, -0.025, 0, 0.025, 0.25]
# avg_ecg_epochs only spans its own tmin..tmax (-0.1..0.1 s in the cell that creates it), and MNE rejects
# topomap times outside the data, so keep only the requested times that fall inside the epoch.
topomap_times = [tp for tp in requested_times if avg_ecg_epochs.tmin <= tp <= avg_ecg_epochs.tmax]
avg_ecg_epochs.plot_joint(times=topomap_times)
# I guess 0 represents the actual ECG event and the rest of the window shows how the ECG signal develops
# during the ECG epoch on average over all channels? TODO confirm.


In [None]:
# How to do correlations of the shape of mean ecg with the data. To see which channels are contaminated most.
# Jochem described in the meeting:
# - average epochs around ECG events channel wise -> vector 
# - divide vector by norm (square root of dot product of vector with itself) --> vector of length 1 (reference heartbeat vector)
# - dot product between epoch and norm vector
# (look up in the Statistical Learning course again)


# - 1. average epochs around ECG events channel wise -> vector 
ecg_evoked = mne.preprocessing.create_ecg_epochs(raw_cropped).average()
#print(ecg_evoked.data[0, :]) # for example: averaged ECG event (evoked) data of first channel, all timepoints
# NOTE(review): ecg_evoked is not used below; the template is built from avg_ecg_epochs
# (magnetometers only, created in an earlier cell) instead - confirm which one is intended.

#This is the mean epoch over all channels, all ecg epochs: is that the right idea?
# Mean over axis 0 of the (n_channels, n_times) evoked data -> one value per time sample.
ecg_event_mean_overall=np.mean(avg_ecg_epochs.data, axis=0)

import plotly.express as px
fig = px.line(ecg_event_mean_overall)
fig.show()

# Project vector of average ECG artifact onto epoch vector:
# - divide vector by norm (square root of dot product of vector with itself) --> vector of length 1 (reference heartbeat vector)
# - dot product between epoch and norm vector
# y_noise: the average ECG template used by the projection cells below.
y_noise=ecg_event_mean_overall

# So project now this average vector onto each of the channels data?




In [None]:
# projecting the data now onto the noise vector
# NOTE(review): y_noise_norm below is a scalar (the length of y_noise), so np.dot(x, y_noise_norm) only
# scales x. proj_of_x_on_y therefore equals x (up to rounding), proj_of_x_on_y_all holds the plain mean of
# each channel, and cos_x_y is x/||x||, not a cosine. A real dot product is not defined here either:
# x is a whole channel (all samples of raw_cropped) while y_noise has the length of one epoch.
# TODO: decide how the template should be compared to the data (e.g. per ECG epoch).

#y_noise_norm = np.sqrt(sum(y_noise**2))  

from numpy.linalg import norm
y_noise_norm = norm(y_noise)


proj_of_x_on_y_all=[]
cos_x_y_all = []
for x_data_with_noise in data_mags:
    proj_of_x_on_y = (np.dot(x_data_with_noise, y_noise_norm)/y_noise_norm**2)*y_noise_norm
    proj_of_x_on_y_all += [np.mean(proj_of_x_on_y)]

    #x_data_with_noise_norm = np.sqrt(sum(x_data_with_noise))  
    x_data_with_noise_norm = norm(x_data_with_noise)


    cos_x_y=np.dot(x_data_with_noise, y_noise_norm)/np.dot(x_data_with_noise_norm, y_noise_norm)
    cos_x_y_all += [cos_x_y]


# Bare expression in the middle of a cell: it is not displayed.
cos_x_y_all

print(proj_of_x_on_y_all)

fig = px.line(proj_of_x_on_y_all)
fig.show()



In [None]:
import math

def dotproduct(v1, v2):
    """Return the dot product of two sequences.

    zip() is used, so elements beyond the length of the shorter sequence are ignored.
    """
    return sum((a*b) for a, b in zip(v1, v2))

def length(v):
    """Return the Euclidean length of sequence ``v``."""
    return math.sqrt(dotproduct(v, v))

def angle(v1, v2):
    """Return the angle between ``v1`` and ``v2`` in radians.

    The cosine is clamped to [-1, 1]: for (anti)parallel vectors rounding can push it
    slightly outside that range, and math.acos would then raise ValueError.
    If the sequences differ in length, the dot product uses only the common prefix
    while each length uses the full sequence.

    Raises:
        ZeroDivisionError: if either vector has zero length.
    """
    cos_angle = dotproduct(v1, v2) / (length(v1) * length(v2))
    return math.acos(max(-1.0, min(1.0, cos_angle)))


# Angle in degrees between the average ECG template (y_noise) and the data of every magnetometer.
ang_all = [np.degrees(angle(y_noise, x_data_with_noise)) for x_data_with_noise in data_mags]

fig = px.line(ang_all)
fig.show()

In [None]:
# NOTE(review): proj_of_x_on_y_all (from the projection cell) holds one np.mean value per magnetometer,
# so element [77] is a single number rather than a series to draw, and it only exists if there
# are at least 78 magnetometers - confirm what this cell is meant to show.
fig = px.line(proj_of_x_on_y_all[77])
fig.show()


In [None]:


# General VECTOR PROJECTION technique: https://www.geeksforgeeks.org/vector-projection-using-python/
# Wrapped in a function: the bare snippet referred to undefined `u` and `v` and raised NameError on Run All.

def project_vector(u, v):
    """Return the orthogonal projection of vector ``u`` onto vector ``v``.

    Formula: proj_v(u) = (u . v / ||v||^2) * v

    Parameters:
        u, v: array-like vectors of the same length.

    Returns:
        numpy.ndarray with the projection (parallel to ``v``).

    Raises:
        ValueError: if ``v`` is the zero vector (the projection is undefined).
    """
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    # finding norm of the vector v
    v_norm = np.sqrt(np.sum(v**2))
    if v_norm == 0:
        raise ValueError('Cannot project onto a zero vector.')
    return (np.dot(u, v)/v_norm**2)*v

# But according to Jochem's paper, we need to project the epoch (data + noise) onto the noise (ecg_evoked), not the other way around?
# That would give the noise component in the epoch.


# _____________

# Actually, I found mne.preprocessing.compute_proj_ecg, which seems to already do all the described steps. Is this what we need?
#
# "Compute SSP (signal-space projection) vectors for ECG artifacts.

# This function will:

# Filter the ECG data channel.

# Find ECG R wave peaks using mne.preprocessing.find_ecg_events().

# Filter the raw data.

# Create Epochs around the R wave peaks, capturing the heartbeats.

# Optionally average the Epochs to produce an Evoked if average=True was passed (default).

# Calculate SSP projection vectors on that data to capture the artifacts.

# Evoked objects typically store EEG or MEG signals that have been averaged over multiple epochs,

# 

# _______
# Evoked: shape (n_channels, n_times)
# Epoch: shape (n_epochs, n_channels, n_times)


In [None]:
# NOTE(review): `epochs_mags` is not defined in any visible cell; initial_processing() returns `epochs_mg`
# instead - possibly epochs_mg['mags'] was meant. TODO confirm; as written this raises NameError on Run All.
df_m = epochs_mags[3].to_data_frame()

df_m

In [None]:
# Find EOG events:
# https://mne.tools/stable/generated/mne.preprocessing.find_eog_events.html

eog_events=mne.preprocessing.find_eog_events(raw, thresh=None, ch_name=None)
# threshfloat | None
# Threshold to trigger the detection of an EOG event. This controls the thresholding of the underlying 
# peak-finding algorithm. Larger values mean that fewer peaks (i.e., fewer EOG events) will be detected. 
# If None, use the default of (max(eog) - min(eog)) / 4, with eog being the filtered EOG signal.

# UNLIKE ECG, EOG EVENTS CANNOT BE RECONSTRUCTED IF NO DEDICATED CHANNEL IS PRESENT, SO EPOCHS CAN'T BE CREATED EITHER.
# As EOG channels it needs either channels of type eog or ordinary EEG channels placed close to the eyes, like Fp1 or Fp2.

# But we don't have any of these.

# Another related function, but again needs EOG channel or other channels where EOG should be extracted from:
# https://mne.tools/stable/generated/mne.preprocessing.ICA.html#mne.preprocessing.ICA.find_bads_eog

# Event times in seconds from the start of `raw` (first_samp subtracted, divided by the sampling rate).
eog_events_times  = (eog_events[:, 0] - raw.first_samp) / raw.info['sfreq']


In [None]:
# RUN THIS CELL ONLY IF YOU DONT HAVE EOG CHANNELS

# "If you don’t have an EOG channel, find_bads_eog has a ch_name parameter that you can use as a proxy for EOG. 
# You can use a single channel, or create a bipolar reference from frontal EEG sensors and use that as virtual EOG channel. 
# This carries a risk however: you must hope that the frontal EEG channels only reflect EOG and not brain dynamics in 
# the prefrontal cortex (or you must not care about those prefrontal signals)."

# So here I plot the magnetometers layout, choose 4 magnetometers near the eyes and try to extract eyeblinks from there

# What is better mags or grads for eyeblinks reconstruction? Or both?

%matplotlib inline

layout_from_raw=mne.channels.find_layout(raw.info, ch_type='mag')
fig=layout_from_raw.plot()

eog_events=mne.preprocessing.find_eog_events(raw, thresh=1e-8, ch_name=['MEG0521', 'MEG0911', 'MEG0511', 'MEG0921'])
# Threshold to trigger the detection of an EOG event. This controls the thresholding of the underlying peak-finding algorithm. 
# Larger values mean that fewer peaks (i.e., fewer EOG events) will be detected. 
# If None, use the default of (max(eog) - min(eog)) / 4, with eog being the filtered EOG signal.

# Detected only 3 EOG events in th entire 1h dataset if threshold is set to defailt: thresh=None
# -> Try to adjust threshold or, maybe it just doesnt really work.
# -> Above: set the threshold to 1e-8, detects 17 eyeblink events - also too little. But I got no idea what threshold value is optimal.

#%% Now want to see the data for these 4 channels I chose for eyeblinks:
chans = ['MEG0521', 'MEG0911', 'MEG0511', 'MEG0921']

#I only know how to plot channels by the idexes, not by names, so have to extract indexes first:
EOG_ch=[]
for c in chans:
    ch=[item for item in mags if c in item]
    EOG_ch.append(ch[0][1])

%matplotlib qt
raw.plot(order=EOG_ch, start=12, duration=4) #plot here only a part of channel."

#Lets discuss if these look like eyeblink!

In [None]:
# Extract EOG epochs:

eog_epochs = mne.preprocessing.create_eog_epochs(raw, baseline=(-0.5, -0.2))
#eog_epochs = mne.preprocessing.create_eog_epochs(raw)
#This function will internally call the find_eog_events again (no need to run previous cell). 


# Image of all EOG epochs (channels combined by their mean) and a joint plot of their average.
eog_epochs.plot_image(combine='mean')
eog_epochs.average().plot_joint()

#Looks not the same as in preprocessing tutorial:
# https://mne.tools/stable/auto_tutorials/preprocessing/10_preprocessing_overview.html#sphx-glr-auto-tutorials-preprocessing-10-preprocessing-overview-py

#Claims that it found EOG events but they don't really look like any events.

In [None]:
# Next, there is ICA used for ECG and EOG artifacts repair.
# https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html

# Functions ica_find_ecg_events and ica_find_eog_events do not work independently;
# they need the functions above to be run first.





In [None]:
# found interesting function to fix stimulus artifacts: 
# https://mne.tools/stable/generated/mne.preprocessing.fix_stim_artifact.html