In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import EEGAnalysis as ea

import dask.array as da

In [4]:
datadir = "../../Data"
resultdir = "../../Result"
patientName = "Chen Zhou"
fs = 2000

targetexp = "180901-3-10"

datacontainer = ea.CompactDataContainer(
    datadir, resultdir, patientName, 
    targetexp, fs, (-2, 10)
)

In [94]:
chidx = 108
ch_split = datacontainer.group_channel_by_marker(chidx, "grating")

frange = np.logspace(np.log10(1), np.log10(200), 40)
%time dwt_result = ea.dwt(ch_split, fs, frange, reflection=True)

CPU times: user 4.28 s, sys: 615 ms, total: 4.9 s
Wall time: 4.92 s


In [108]:
ttsplit = np.vstack((ch_split, ch_split, ch_split, ch_split, ch_split, ch_split))
np.shape(ttsplit)

(120, 24000)

In [109]:
frange = np.logspace(np.log10(1), np.log10(200), 40)
%time dwt_result = ea.dwt(ttsplit, fs, frange, reflection=True)

CPU times: user 31.7 s, sys: 14.4 s, total: 46.1 s
Wall time: 57.2 s


In [110]:
frange = np.logspace(np.log10(1), np.log10(200), 40)
%time dwt_result = da_dwt(ttsplit, fs, frange, reflection=True)

CPU times: user 2min 5s, sys: 14.4 s, total: 2min 20s
Wall time: 53.2 s


In [111]:
## wavelet
def da_morlet(F, fs):
    """Morlet wavelet"""
    wtime = da.linspace(-1, 1, 2*fs, chunks=(2*fs,))
    s = 6 / (2 * np.pi * F)
    wavelet = da.exp(2*1j*np.pi*wtime*F) * da.exp(-wtime**2/(2*s**2))
    return wavelet

def da_dwt(data, fs, frange, wavelet=da_morlet, reflection=False):

    dist_data = da.from_array(data, chunks=(1, np.size(data, 1)))

    if np.ndim(dist_data) == 1:
        dist_data = da.reshape(dist_data, (1, len(data)))

    if reflection:
        data_flip = da.fliplr(dist_data)
        data_fft = da.hstack((data_flip, dist_data, data_flip))
        data_fft = da.rechunk(data_fft, chunks=(1, np.size(data_fft, 1)))
    else:
        data_fft = dist_data

    nConv = np.size(data_fft, -1) + int(2*fs)
    fft_data = da.fft.fft(data_fft, nConv)

    Dwt = np.zeros((np.size(frange), np.size(data, 0), np.size(data, 1)), dtype="complex")

    for idx, F in enumerate(frange):
        fft_wavelet = da.fft.fft(wavelet(F, fs), nConv)
        conv_wave = da.fft.ifft(fft_wavelet * fft_data, nConv)
        conv_wave = conv_wave[:, fs:-fs]

        if reflection:
            Dwt[idx, :, :] = conv_wave[:, np.size(data, 1):-np.size(data, 1)].compute()
        else:
            Dwt[idx, :, :] = conv_wave.compute()

    return Dwt