# Entrain Detection
Detects entrainment using the entrain-window detection method.
<!--accelerated by `multiprocessing`-->

- preview a single channel
- batch-export all channels for all days into `.mat` files
- restore the batch export (all channels, all days) from `.mat` files

In [None]:
from multiprocessing import Pool

import EEGAnalysis as ea
import h5py
import numpy as np
import scipy.signal as signal
import matplotlib.pyplot as plt
import re, os
from tqdm import tqdm
from scipy.io import loadmat, savemat

In [None]:
# ---- Patient selection ----
patient_id = 'zhangchen'

# Data manager rooted at the EEG storage directory.
ea_manager = ea.DataManager('/media/STORAGE/EEG/Data')
# Registers the patient as new; to reopen a previously created patient use
# ea_manager.get_patient(patient_id) instead.
patient = ea_manager.create_patient(patient_id)

# ---- Analysis constants ----
_freq = 2000              # sampling rate; the Nyquist frequency below is derived from it
ROI = (-3, 3)             # epoch limits around the marker
EntrainWindow = (-1, 1)   # window in which the entrain peak is searched
EntrainThresh = 1.96      # detection threshold
zbase = (-2, -1)          # baseline interval (used relative to ROI[0] when computing power)

# Epoch time axis: _freq * (ROI[1] - ROI[0]) evenly spaced points spanning ROI,
# both end points included.
tspec = np.linspace(*ROI, num=int(_freq) * (ROI[1] - ROI[0]))

# ---- Band-pass FIR filter (50-80, normalised to Nyquist) ----
nyq = 0.5 * _freq
ripple_db = 60.0          # requested stop-band attenuation in dB
width = 5.0 / nyq         # transition width as a fraction of Nyquist

# Filter length and Kaiser beta that satisfy the ripple/width specification.
N, beta = signal.kaiserord(ripple_db, width)
# pass_zero=False keeps DC out of the pass band, i.e. a band-pass design.
taps = signal.firwin(
    N,
    [50 / nyq, 80 / nyq],
    window=('kaiser', beta),
    pass_zero=False,
)


In [None]:
def detect_entrain(chidx):
    """Run entrain-window detection on one channel for every recording date.

    For each date returned by ``patient.load_isplit(chidx)`` a single epoch
    spanning ``ROI`` is cut around an extrapolated marker (last marker plus
    the mean inter-marker interval), its power is computed with
    ``ea.decomposition.dwt_power`` and smoothed with
    ``ea.decomposition.gaussianwind``.  Entrainment is reported when the
    smoothed trace exceeds ``EntrainThresh`` inside ``EntrainWindow`` while
    staying below it everywhere outside the window.

    Relies on the module-level names ``patient``, ``ROI``, ``EntrainWindow``,
    ``EntrainThresh``, ``zbase``, ``tspec`` and ``_freq``.

    Parameters
    ----------
    chidx : int
        Channel index handed to ``patient.load_isplit``.

    Returns
    -------
    dict
        ``'chidx'``     -- the channel index that was passed in.
        ``'date'``      -- list of date keys, in iteration order.
        ``'trace'``     -- array of smoothed traces, one per date.
        ``'detection'`` -- array of detection flags, one per date.
        ``'latency'``   -- array of peak latencies, one per date.
        ``'amplitude'`` -- array of in-window peak values, one per date.
    """
    _data_list = patient.load_isplit(chidx)

    # The masks depend only on module constants, so build them once instead
    # of once per comparison per date.
    _inside = (tspec > EntrainWindow[0]) & (tspec < EntrainWindow[1])
    _outside = (tspec < EntrainWindow[0]) | (tspec > EntrainWindow[1])

    _result = {'chidx': chidx, 'date': [], 'trace': [], 'detection': [], 'latency': [], 'amplitude': []}
    for _idate in _data_list.keys():
        _data = _data_list[_idate]['value']

        # first 20 markers for 10 sec trials, last 20 for 5 sec trials
        _marker = patient.get_marker(file=_idate)[20:]
        # Extrapolate one marker past the last one, spaced by the mean interval.
        _entrain_marker = [np.mean(np.diff(_marker)) + _marker[-1]]

        _epoch = ea.create_1d_epoch_bymarker(_data, _entrain_marker, ROI, _freq).reshape((1, 1, -1))
        # Baseline limits are zbase shifted so that they are measured from the epoch start.
        entry_power = ea.decomposition.dwt_power(_epoch, _freq, baseline=(zbase[0] - ROI[0], zbase[1] - ROI[0]))

        # NOTE(review): the literal 2000 is presumably the sampling rate (== _freq) -- confirm
        # against ea.decomposition.gaussianwind before replacing it.
        _smooth = ea.decomposition.gaussianwind(entry_power[0], 2000, 0.05)

        _windowed = _smooth[_inside]
        _amplitude = np.max(_windowed)
        # Detected only if the peak crosses the threshold inside the window
        # and nothing crosses it outside the window.
        _entrain = _amplitude > EntrainThresh and np.max(_smooth[_outside]) < EntrainThresh

        # Peak position in samples from the window start, converted with _freq.
        # NOTE(review): subtracting EntrainWindow[0] (negative) shifts the value by +1;
        # a marker-relative latency would add it instead -- confirm the intended reference.
        _latency = np.argmax(_windowed) / _freq - EntrainWindow[0]

        _result['latency'].append(_latency)
        _result['amplitude'].append(_amplitude)
        _result['trace'].append(_smooth)
        _result['detection'].append(_entrain)
        _result['date'].append(_idate)

    # Convert the numeric per-date lists to arrays; 'date' stays a list.
    for _key in ('trace', 'detection', 'latency', 'amplitude'):
        _result[_key] = np.array(_result[_key])

    return _result

In [None]:
# Run detect_entrain for every channel index in a pool of 5 worker processes.
_total_ch = len(patient._sgch_config['chidx'])
with Pool(processes=5) as _pool:
    # imap_unordered yields each result as soon as it is ready, which keeps the
    # progress bar moving, but the results arrive in completion order.
    _output = list(tqdm(_pool.imap_unordered(detect_entrain, range(_total_ch)), total=_total_ch))

# Restore channel order so that position i in _output (and row i of every
# array stacked from it) belongs to channel index i.
_output.sort(key=lambda item: item['chidx'])

In [None]:
# Stack the per-channel results into arrays with one row per channel.
# The exported files carry no channel labels, so rows are ordered by channel
# index explicitly instead of relying on the order in which _output was filled.
_by_channel = sorted(_output, key=lambda item: item['chidx'])
_dates = np.array([item['date'] for item in _by_channel])
_entrain = np.array([item['detection'] for item in _by_channel])
_latency = np.array([item['latency'] for item in _by_channel])
_amplitude = np.array([item['amplitude'] for item in _by_channel])
_trace = np.array([item['trace'] for item in _by_channel])

# Columns are split into three interleaved groups (every third date).
# presumably three recordings per day -- TODO confirm against the date keys.
with h5py.File('./export/%s_entrain.h5'%patient_id, 'w') as _f:
    for _offset in range(3):
        # Dataset name: last character of the first date key in this group.
        _f.create_dataset(_dates[:, _offset::3][0][0][-1], data=_entrain[:, _offset::3])

# Same three-way split for the .mat exports, stored under the keys '1', '2', '3'.
savemat('./export/%s_entrain.mat'%patient_id, {'1':_entrain[:, 0::3],'2':_entrain[:, 1::3],'3':_entrain[:, 2::3]})
savemat('./export/%s_latency.mat'%patient_id, {'1':_latency[:, 0::3],'2':_latency[:, 1::3],'3':_latency[:, 2::3]})
savemat('./export/%s_amplitude.mat'%patient_id, {'1':_amplitude[:, 0::3],'2':_amplitude[:, 1::3],'3':_amplitude[:, 2::3]})

# Per-channel smoothed traces: one HDF5 group per channel, named by the
# zero-padded channel index.
with h5py.File('./export/%s_entrain_trace.h5'%patient_id, 'w') as _f:
    for item in _output:
        _g = _f.create_group('%03d'%item['chidx'])
        # Date keys are stored as byte strings.
        _g.create_dataset('date', data=np.array(item['date'], dtype='S'))
        _g.create_dataset('trace', data=item['trace'], compression='gzip')

# Full result dicts keyed by channel, written as a .mat file.
# scipy.io.savemat is the writer imported by this notebook; a bare `save`
# is not defined and would raise NameError.
k = {'channel%03d'%item['chidx']: item for item in _output}
savemat('./export/%s_entrain_trace.mat'%patient_id, k)

---