In [None]:
import mne
import os
from preprocessor.filehandling import get_files
import os.path as op
from mne import chpi
from mne.preprocessing import maxwell_filter
import numpy as np
from preprocessor.cleaning import ICA
from os import makedirs
import json
import matplotlib.pyplot as plt

### !!! Process only one raw fif file at a time !!! ###
# Root folder holding one sub-folder per subject.
raw_subjects_dir = '/data/sheng/MEG/'
subject = "case_0000"
# Output root for this subject's derived files (bad-channel lists, ICA figures).
# Built from the raw_subjects_dir *variable*; joining the literal string
# 'raw_subjects_dir' would create a relative folder of that name in the
# current working directory instead of writing under the data root.
meg_subjects_dir = op.join(raw_subjects_dir, subject, 'spontaneous')

## these are maxfiltered by Paul's tech
# get_files is a project helper; it is used here as returning two sequences
# (paths, names) that are indexed with the same position later on
# -- presumably parallel lists, verify against preprocessor.filehandling.
filepath, filename = get_files(op.join(raw_subjects_dir, subject, ''), incl_strings = ['spont_rsa', 'tsss.fif'],  ftype='.fif')
print('Select the fif file from blew list to be processed **********************************')

# List the candidates so the index for SELECTED_FILE can be chosen in the next cell.
for item in filepath:
    print(item)


In [None]:
# Position of the recording to process in the listing printed by the previous cell.
SELECTED_FILE = 1
FILEPATH, FILENAME = filepath[SELECTED_FILE], filename[SELECTED_FILE]

# Load the whole recording into memory (preload=True) and show its metadata.
raw = mne.io.read_raw_fif(FILEPATH, allow_maxshield=True, preload=True)
print(raw.info)

########################################################################
# PSD check on the raw
########################################################################
figu = raw.plot_psd(fmin=1, fmax=249, n_jobs=2)

# Give the first two PSD axes the same dense grid: major ticks every 10 Hz,
# minor ticks every 2 Hz, grid lines on both.
for axis_index in (0, 1):
    ax = figu.axes[axis_index]
    ax.xaxis.set_major_locator(plt.MultipleLocator(10))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(2))
    ax.grid(True, which='both')

plt.show()

In [None]:
#%% Plot data to check bad channels
# Visually inspect a 1-100 Hz band-passed copy and click channels to mark them
# as bad; the original `raw` is left unfiltered.
raw_copy = raw.copy().filter(1, 100)
# raw_copy = mne.chpi.filter_chpi(raw_copy)
# block=True halts execution until the browser window is closed. Without it
# the lines below run immediately and copy the bad-channel list before
# anything has been marked.
raw_copy.plot(duration=10.0, bad_color='red', block=True)


#%% print bad channels and copy them to raw
# Copy the list so later edits on raw_copy do not silently change raw.
raw.info['bads'] = raw_copy.info['bads'].copy()
print(raw.info['bads'])

In [None]:
# Persist the marked bad channels as a JSON list in a per-recording folder.
# FILENAME[:-4] drops the last four characters (presumably the '.fif' suffix).
recording_stem = FILENAME[:-4]
bad_channel_dir = op.join(meg_subjects_dir, recording_stem, 'bad_channels', '')
makedirs(bad_channel_dir, exist_ok=True)  # no error if the folder already exists
fpath_bad_channel = op.join(bad_channel_dir, recording_stem + "_bad_channels.txt")
bad_channels = raw.info['bads']
with open(fpath_bad_channel, 'w') as file:
    json.dump(bad_channels, file)

In [None]:
# The triple-quoted string below is NOT executed: it holds a disabled,
# work-in-progress cHPI head-position + Maxwell-filter pipeline. Nothing in it
# defines raw_ts; the notch-filtering cell assigns raw_ts = raw instead.
# NOTE(review): the calibration/cross-talk paths inside are built by plain
# string concatenation from placeholder prefixes ('calibration_filepath',
# 'crosstalk_filepath') and must be replaced with real paths before enabling.
# NOTE(review): as the only expression in the cell, the string is presumably
# echoed as cell output when run -- confirm and clear if unwanted.
'''
######################################################  
# Motion compensation (WIP)                       ##
# Compute all necessary CHPI stuffs, CHECK with Paul/Matias about this!!!
######################################################  
chpi_amplitudes = chpi.compute_chpi_amplitudes(raw)
chpi_locs = chpi.compute_chpi_locs(raw.info, chpi_amplitudes)
head_pos = chpi.compute_head_pos(raw.info, chpi_locs)

########################################################## 
# Run the Maxfilter process 
##########################################################
calibration_path = 'calibration_filepath' + 'sss_cal_BioMag_TRIUX_3126.dat'
crosstalk_path = 'crosstalk_filepath' + 'ct_sparse_BioMag_TRIUX_3126.fif'      
raw_ts = maxwell_filter(raw, head_pos=head_pos, st_correlation=0.9, st_duration = 20,\
                        calibration=calibration_path, cross_talk=crosstalk_path)

# PSD check
raw_ts.plot_psd(fmin=1, fmax=249, n_jobs=8)
'''

In [None]:
############################################################################################
###                                            Notch Filtering                          ###
############################################################################################
# Input to this stage: the recording as loaded (already maxfiltered upstream).
raw_ts = raw ## if raw is maxfiltered by Paul's tech

# Notch targets: 60, 120, 180, 240, 300 Hz (presumably power-line harmonics)
# plus a 27.5 Hz component (origin not documented here -- confirm).
freqs = np.arange(60, 301, 60)
freqs = np.concatenate((np.array([27.5]), freqs))

# MNE's notch_filter()/filter() modify the Raw object in place and return it.
# Filtering a copy keeps `raw`/`raw_ts` as the unfiltered input, so this cell
# can be re-run without filtering already-filtered data.
raw_filt = raw_ts.copy().notch_filter(freqs, n_jobs=4)
# Separate pass for the 59 Hz line, applied on top of the first notch result.
raw_filt = raw_filt.notch_filter([59], n_jobs=4)

# Band-pass filtering: l_freq=None means no high-pass, i.e. low-pass at h_freq.
l_freq = None
h_freq = 249
raw_filt = raw_filt.filter(l_freq, h_freq, n_jobs=8)

# PSD check
raw_filt.plot_psd(fmin=1, fmax=h_freq, n_jobs=8)

In [None]:
##############################################
# Save filtered data
##############################################
# The filtered recording is written into the subject folder, named after the
# selected file with its last four characters replaced by "_filt.fif".
filt_basename = FILENAME[:-4] + "_filt.fif"
fpath_filt = op.join(raw_subjects_dir, subject, '', filt_basename)
raw_filt.save(fpath_filt, overwrite=True)  # replaces any previous export

In [None]:
##############################################
#%% ICA
##############################################
subj_dir = op.join(raw_subjects_dir, subject, '')
# Figures go in a folder named after the recording chosen via SELECTED_FILE
# (FILENAME), not after the first entry of the file listing.
ica_dir = op.join(meg_subjects_dir, 'ica', FILENAME[:-4], '')
makedirs(ica_dir, exist_ok=True)
# fpath_filt is already a complete path, so the ICA file name is derived from
# it directly. Wrapping it in op.join(raw_subjects_dir, subject, ...) only
# worked because os.path.join drops everything before an absolute component.
fpath_ica = fpath_filt[:-4] + '_ica.fif'

# Run ICA and save results
# ICA is the project wrapper from preprocessor.cleaning; its `.ica` attribute
# is what gets saved (presumably the underlying MNE ICA object -- confirm).
ica = ICA(raw_filt, FILENAME[:-4], subj_dir)
ica.compute_ica(n_components=70, l_freq=1, h_freq=100)
ica.make_ica_figs(ica_dir)
ica.ica.save(fpath_ica, overwrite=True)

In [None]:



#######################################################
#%% plot ICA to visually mark bad components
#######################################################
# Reload the ICA solution and the filtered recording from the files written
# by the earlier cells; from here on `ica` is the object returned by read_ica.
ica = mne.preprocessing.read_ica(fpath_ica)
raw_filt = mne.io.read_raw_fif(fpath_filt, preload=True)

# Browse the component time courses on a 1-100 Hz band-passed copy so that
# raw_filt itself is not filtered again.
raw_for_display = raw_filt.copy().filter(1, 100, n_jobs=8)
ica.plot_sources(raw_for_display)

In [None]:
# Automatic component selection based on the EOG (and, below, ECG) channels.
# NOTE: these detections are not 100% reliable -- treat them as a reference
# for the visual inspection, not as ground truth.
# This step can be very CPU intensive.

# Ocular artefacts
eog_indices, eog_scores = ica.find_bads_eog(
    raw_filt,
    threshold=0.4,
    measure='correlation',
)
ica.plot_scores(eog_scores)

In [None]:
## Cardiac artefacts: flag components using the correlation method
ecg_indices, ecg_scores = ica.find_bads_ecg(
    raw_filt,
    method='correlation',
)
ica.plot_scores(ecg_scores)

In [None]:
# Show the scores of the flagged components only: EOG first, then ECG.
flagged_eog_scores = eog_scores[eog_indices]
flagged_ecg_scores = ecg_scores[ecg_indices]
[flagged_eog_scores, flagged_ecg_scores]

In [None]:
# Visualize the scores of the components flagged by the ECG and EOG detectors.
# idx and ica_scores are built in the same order (ECG first, then EOG) so that
# each bar sits at its component index.
idx = np.array(ecg_indices + eog_indices)
ica_scores = np.concatenate((ecg_scores[ecg_indices], eog_scores[eog_indices]))

fig, ax = plt.subplots(figsize=(7.5, 2))
ax.bar(idx, ica_scores)
#ax.bar([0], [eog_scores[0]], color='red')
# Span every fitted component. A hard-coded upper limit of 69 clips the bar of
# the last of 70 components and goes stale if the component count changes.
ax.set_xlim(-1, ica.n_components_)
ax.set_xlabel('ICA component index')
ax.set_ylabel('score')
plt.show()

In [None]:
# Print the flagged component indices: EOG list first, then ECG list.
for flagged in (eog_indices, ecg_indices):
    print(flagged)

In [None]:
# Union of the automatically flagged EOG and ECG components, as sorted plain
# Python ints (np.unique would leave numpy integer scalars in the list).
toExclude = sorted({int(component) for component in eog_indices + ecg_indices})
# Component 0 is always excluded on top of the automatic picks.
# NOTE(review): presumably chosen by visual inspection -- confirm per recording.
# Building the list from a set avoids listing 0 twice when it was also flagged.
ica.exclude = sorted({0, *toExclude})
ica.save(fpath_ica, overwrite=True) # save ica once more after setting ica.excludes

In [None]:
# Print the path of the ICA file on disk.
print(fpath_ica)

In [None]:
#%% Set excluded ica components and save the ica results

print(ica.exclude)
ica.save(fpath_ica, overwrite=True) # save ica once more after setting ica.excludes

# Apply the ICs exclusion to raw_filt
# Overlay of the signal before/after removing the excluded components.
ica.plot_overlay(raw_filt)
# ica.apply() modifies its input in place, so clean a copy: raw_filt stays the
# un-cleaned reference and re-running this cell starts from the same data.
# The overlay/apply pair is executed once; repeating it on already-cleaned
# data would only compare the cleaned signal with itself.
raw_ica = ica.apply(raw_filt.copy())

In [None]:
#%% Final check
# Spectrum of the ICA-cleaned recording up to the low-pass edge (h_freq).
raw_ica.plot_psd(fmin=1, fmax=h_freq, n_jobs=8)

# Browse a 1-100 Hz band-passed copy to mark remaining bad segments;
# raw_ica itself is not filtered here.
raw_ica_copy = raw_ica.copy()
raw_ica_copy = raw_ica_copy.filter(1, 100)
raw_ica_copy.plot(duration=10.0, bad_color='red')

In [None]:
#%% save bad segments marked from final check
# Transfer the annotations made on the browsing copy onto the cleaned recording.
marked_segments = raw_ica_copy.annotations
raw_ica.set_annotations(marked_segments)

# save data
# Output name: the ICA file path with its last four characters replaced by
# '_cleaned.fif'.
clean_name = fpath_ica[:-4] + '_cleaned.fif'
fpath_clean = op.join(raw_subjects_dir, subject, clean_name)
raw_ica.save(fpath_clean, overwrite=True)