In [None]:
import gc

# Make sure automatic garbage collection is on, then run one full collection
# up front (the batch loop further down collects again after every file).
gc.enable()
gc.collect()

In [None]:
import joblib

In [None]:
import glob
from emgdv_utils import *

In [None]:
# --- Source acceptance thresholds -------------------------------------------
# Forwarded to fastICA() as sil_thresh / pnr_thresh / cov_thresh.
SIL_THRESH = 0.93
PNR_THRESH = 10
CoV_THRESH = 2

# Forwarded to fastICA() as MIN_PEAKS_SRC; also reused for MIN_STANDALONE below.
MIN_PEAKS_SRC = 30

N_COMPONENTS = 300  # maximum number of attempts to detect MU
N_ITERATIONS = 100  # maximum iterations during MU detection attempt

# --- Per-channel filtering --------------------------------------------------
# Band edges handed to butterworth_filter() together with FS.
# Alternative setting: lowcut, highcut = 20, 500 (as in Farina tutorial).
lowcut, highcut = 30, 400
apply_butterworth_filter = True

# NOTE(review): run() never reads this module-level value; the notch frequency
# it uses is taken from the strongest peak of each channel's spectrum.
notch_filter_hz = 50
apply_notch_filter = True

apply_wavelet_filter = False  # if True, run() passes each channel through wavelet_filter()
apply_zscore = True  # if True, run() standardises the channels with StandardScaler
apply_winsorize = False  # if True, run() calls skew_filter(data, FS, 3) before filtering

# --- Remaining fastICA() options --------------------------------------------
MIN_STANDALONE = MIN_PEAKS_SRC
SENSITIVITY_STANDALONE = 1.5
# MUs are considered identical if their firings ts overlap in 30% of time
# (+-1 ts). Negro advocates for 0.3.
DROP_ROA = 0.3

cluster_method = 'k-means'

# --- Signal extension -------------------------------------------------------
# run() handles 'spline' (SplineTransformer) and 'extend' (extend() helper).
extension_method = 'spline'
apply_extension = True
# One shared value: spline degree for 'spline', second argument of extend()
# for 'extend'.
NUM_EXTEND = SPLINE_DEGREE = 72

WHITEN_SOLVER = 'eigh'  # alternative: 'svd'

# --- Recording layout -------------------------------------------------------
FS = 2222  # sampling rate in Hz (used as 1 / FS sample spacing for the FFT frequency axis)
# True: the first CSV column is the force channel and is removed from the EMG
# data. False: force is approximated by the row-wise mean absolute value.
FORCE_CHANNEL_FIRST = True

# NOTE(review): not referenced anywhere in the visible notebook cells.
save_json = True

In [None]:
def _save_spectrum_plot(y, ch, out_file):
    """Save the FFT amplitude spectrum of one channel as an image.

    Parameters
    ----------
    y : 1-D array of samples for the channel.
    ch : zero-based channel index; the plot title shows ``ch + 1``.
    out_file : path of the image to write.
    """
    yf = rfft(y)
    xf = rfftfreq(len(y), 1 / FS)

    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
    ax.plot(xf, np.abs(yf), color='darkred')
    ax.set_xlabel('Frequency, Hz', labelpad=0)
    ax.set_ylabel('Amplitude', labelpad=0)
    ax.tick_params(pad=0)
    ax.set_title(f'CH{ch+1}')
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    # Close explicitly: one figure per channel would otherwise pile up in memory.
    plt.close(fig)


def _remove_spectral_peaks(s, max_passes=100):
    """Notch out narrow-band peaks that dominate the spectrum of one channel.

    A peak counts as an outlier when its amplitude exceeds
    ``mean + 10 * std`` of the initial amplitude spectrum. The threshold is
    computed once; after every notch the spectrum is recomputed so that the
    stop condition always reflects the current signal.

    Parameters
    ----------
    s : 1-D array, the channel to clean.
    max_passes : safety bound on the number of notch filters applied, so a
        peak that the notch cannot push below the threshold does not hang
        the batch run.

    Returns
    -------
    The filtered channel (``s`` itself if no peak exceeded the threshold).
    """
    if not apply_butterworth_filter:
        # The threshold is derived from a band-passed copy of the signal.
        # NOTE(review): called with a single argument, so it relies on
        # defaults of butterworth_filter() in emgdv_utils - confirm.
        y = butterworth_filter(s)
    else:
        y = s

    spectrum = np.abs(rfft(y))
    threshold = np.mean(spectrum) + np.std(spectrum) * 10

    passes = 0
    while np.max(spectrum) > threshold:
        if passes >= max_passes:
            print('notch filter: peak still above threshold after',
                  max_passes, 'passes')
            break
        freqs = rfftfreq(len(y), 1 / FS)
        peak_hz = freqs[np.argmax(spectrum)]
        print('powerlinepeak:', peak_hz)
        s = y = notch_filter(s, FS, peak_hz)
        # Refresh the spectrum so the loop condition sees the filtered signal.
        spectrum = np.abs(rfft(y))
        passes += 1

    return s


def run(FILE, PREFIX):
    """Decompose one recording and write all results below ``PREFIX``.

    The output directory is ``PREFIX`` followed by the part of the file name
    that precedes the first underscore.

    Steps performed:
      1. Read the CSV and resample every column to the length of the longest
         column (NaNs dropped first).
      2. Obtain the force channel (first column, or row-wise mean absolute
         value when ``FORCE_CHANNEL_FIRST`` is False), smooth it and rescale
         it to 0-100 using its 1%/99% quantiles.
      3. Save the raw amplitude spectrum of every channel.
      4. Filter every channel according to the module-level ``apply_*`` flags
         and save a channel overview, the filtered data and their spectra.
      5. Optionally z-score and extend the signal, whiten it, drop rows with
         fewer than 10 distinct values and run ``fastICA``.
      6. Save statistics, diagnostic plots and the fitted objects
         (``bin/*.bin`` via joblib).

    Parameters
    ----------
    FILE : path of the input CSV file.
    PREFIX : string prepended to the recording name to form the output
        directory (expected to end with a path separator).

    Returns
    -------
    None. All results are written to disk.
    """
    data = pd.read_csv(FILE)

    # basename() copes with both '/' and '\\' separators.
    name = os.path.basename(FILE).split('_')[0]
    results_save_path = f'{PREFIX}{name}'

    for sub_dir in ('', '/bin/', '/AP/raw/', '/AP/smooth/'):
        os.makedirs(results_save_path + sub_dir, exist_ok=True)

    # Columns may hold different numbers of valid samples; bring all of them
    # to the length of the longest one.
    ch_len = max(len(data[col].dropna()) for col in list(data))
    for col in list(data):
        data[col] = scipy_signal.resample(data[col].dropna(), ch_len)

    if not FORCE_CHANNEL_FIRST:
        FORCE_CHANNEL = data.abs().mean(axis=1).values
    else:
        FORCE_CHANNEL = data.iloc[:, 0].dropna().values
        data = data.iloc[:, 1:]

    for ch, y in enumerate(data.T.values):
        _save_spectrum_plot(
            y, ch,
            f'{results_save_path}/CH{ch+1}-frequencies-spectrum-raw.png')

    # Smooth the force trace and express it as a percentage of its
    # 1%-99% quantile range, clipped at zero.
    FORCE_CHANNEL = moving_average(FORCE_CHANNEL, FS)
    FORCE_CHANNEL[FORCE_CHANNEL < 0] = 0

    force_min = np.quantile(FORCE_CHANNEL, 0.01)
    force_max = np.quantile(FORCE_CHANNEL, 0.99)

    FORCE_CHANNEL = ((FORCE_CHANNEL - force_min) /
                     (force_max - force_min)) * 100
    FORCE_CHANNEL[FORCE_CHANNEL < 0] = 0

    # RMS across channels at every sample, computed on the unfiltered data.
    SIGNAL = copy.deepcopy(data.T.values)
    CH, TS = SIGNAL.shape
    RMS = ((SIGNAL**2).mean(axis=0))**0.5
    RMS = moving_average(RMS, FS)

    if apply_winsorize:
        data = skew_filter(data, FS, 3)

    SIGNAL = copy.deepcopy(data.T.values)

    for ch_row in range(CH):
        s = moving_average(SIGNAL[ch_row, :], 3)
        if apply_butterworth_filter:
            s = butterworth_filter(s, lowcut, highcut, FS, order=2)
        if apply_notch_filter:
            s = _remove_spectral_peaks(s)
        if apply_wavelet_filter:
            s = wavelet_filter(s.reshape(1, -1)).flatten()

        SIGNAL[ch_row, :] = s

    DATA_FILTERED = copy.deepcopy(SIGNAL)

    SHAPE = np.array([6, CH / 2])
    # squeeze=False keeps `axes` a 2-D array even for a single channel.
    fig, axes = plt.subplots(CH,
                             1,
                             sharey=True,
                             sharex=True,
                             figsize=SHAPE,
                             squeeze=False)
    axes = axes[:, 0]
    sns.despine()
    plt.subplots_adjust(wspace=0, hspace=0)

    for ch_row, ax in enumerate(axes):
        ax.plot(SIGNAL[ch_row, :], lw=1, color='k')
        # 1-based label, consistent with the CSV columns and spectrum plots.
        ax.set_ylabel(f'CH{ch_row+1}', labelpad=0)
        ax.set_yticklabels([])
        ax.tick_params(pad=0)

    axes[-1].set_xlabel('Time (samples)')

    fig.savefig(f'{results_save_path}/data-channels.pdf',
                dpi=300,
                bbox_inches='tight')
    plt.close(fig)

    # Both objects are dumped at the end regardless of the flags, so they
    # must exist (as None) even when the corresponding step is skipped.
    scaler = None
    spline = None

    if apply_zscore:
        scaler = StandardScaler()
        SIGNAL = scaler.fit_transform(SIGNAL.T).T

    if apply_extension:
        if extension_method == 'spline':

            spline = SplineTransformer(degree=SPLINE_DEGREE,
                                       n_knots=SPLINE_DEGREE + 1,
                                       extrapolation='periodic',
                                       knots='quantile',
                                       include_bias=True)

            SIGNAL = spline.fit_transform(SIGNAL.T).T

        if extension_method == 'extend':
            SIGNAL = extend(SIGNAL, NUM_EXTEND)

    SIGNAL = whiten(SIGNAL, WHITEN_SOLVER)

    # Discard rows that are (almost) constant after whitening.
    keep = [
        i for i in range(SIGNAL.shape[0])
        if len(np.unique(SIGNAL[i, :])) >= 10
    ]
    SIGNAL = SIGNAL[keep, :]

    data_filtered = pd.DataFrame(DATA_FILTERED.T)
    data_filtered.columns = [f'ch{k+1}' for k in range(data_filtered.shape[1])]
    data_filtered['force_channel'] = FORCE_CHANNEL
    data_filtered.to_csv(f'{results_save_path}/data-prepared.csv.gz',
                         decimal='.',
                         sep=',',
                         index=False,
                         compression='gzip')

    for ch, y in enumerate(DATA_FILTERED):
        _save_spectrum_plot(
            y, ch, f'{results_save_path}/CH{ch+1}-frequencies-spectrum.png')

    print(SIGNAL.shape)
    WEIGHTS, SOURCES_INFO, DCR = fastICA(
        SIGNAL,
        N_COMPONENTS,
        sil_thresh=SIL_THRESH,
        pnr_thresh=PNR_THRESH,
        cov_thresh=CoV_THRESH,
        iterations_main=N_ITERATIONS,
        MIN_PEAKS_SRC=MIN_PEAKS_SRC,
        FS=FS,
        MIN_STANDALONE=MIN_STANDALONE,
        SENSITIVITY_STANDALONE=SENSITIVITY_STANDALONE,
        DROP_ROA=DROP_ROA,
        cluster_method=cluster_method)

    MU_FOUND = len(SOURCES_INFO)

    res, stage_res = get_stats_sources_info_full(SOURCES_INFO, FORCE_CHANNEL,
                                                 RMS, TS, FS, {})

    res.to_excel(f'{results_save_path}/statistics.xlsx')

    if MU_FOUND > 0:
        # Plots are best effort: a failing figure is reported but must not
        # prevent the remaining figures and the binary dumps from being saved.
        plot_jobs = (
            (plot_regression, (results_save_path, res)),
            (plot_raster, (results_save_path, FS, TS, SOURCES_INFO,
                           FORCE_CHANNEL, cluster_method, extension_method)),
            (plot_ap_raw, (results_save_path, CH, FS, MU_FOUND, DATA_FILTERED,
                           SOURCES_INFO)),
            (plot_ap_smooth, (results_save_path, CH, FS, MU_FOUND,
                              DATA_FILTERED, SOURCES_INFO)),
        )
        for plot_fn, plot_args in plot_jobs:
            try:
                plot_fn(*plot_args)
            except Exception as e:
                print(e)

    joblib.dump(spline, results_save_path + '/bin/spline.bin')
    joblib.dump(scaler, results_save_path + '/bin/zscorer.bin')
    joblib.dump(WEIGHTS, results_save_path + '/bin/WEIGHTS.bin')
    # Only the firing time stamps of every source are persisted.
    sources_info = [{'mu_ts': src['mu_ts']} for src in SOURCES_INFO]
    joblib.dump(sources_info, results_save_path + '/bin/sources_info.bin')
    joblib.dump(DCR, results_save_path + '/bin/DCR.bin')
    joblib.dump(FORCE_CHANNEL, results_save_path + '/bin/FORCE_CHANNEL.bin')
    joblib.dump(keep, results_save_path + '/bin/KEEP_CHANNEL.bin')

In [None]:
import time
from tqdm.auto import tqdm

In [None]:
run_durations = []  # wall-clock seconds spent in run() for every file

# glob returns matches in arbitrary order; sort for reproducible runs.
FILES1 = sorted(glob.glob('raw_data/rep1/*.csv'))
PREFIX1 = 'decomposition/rep1/'

FILES2 = sorted(glob.glob('raw_data/rep2/*.csv'))
PREFIX2 = 'decomposition/rep2/'

# Alternate between rep1 and rep2 files. Unlike zip(), this keeps the
# surplus files when the two directories hold different numbers of CSVs.
JOBS = []
for idx in range(max(len(FILES1), len(FILES2))):
    if idx < len(FILES1):
        JOBS.append((FILES1[idx], PREFIX1))
    if idx < len(FILES2):
        JOBS.append((FILES2[idx], PREFIX2))

for FILE, PREFIX in JOBS:

    # perf_counter() is monotonic, so the duration cannot be distorted by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()
    name = os.path.basename(FILE).split('_')[0]
    print(name)
    run(FILE, PREFIX)
    run_duration = time.perf_counter() - start_time
    print(run_duration/60)
    run_durations.append(run_duration)
    # Release the large per-file arrays before the next recording is loaded.
    gc.collect()

In [None]:
# Convert the collected durations from seconds to minutes. The batch cell
# produces a plain list; once converted the value is an ndarray, so running
# this cell a second time does not divide by 60 again.
if isinstance(run_durations, list):
    run_durations = np.array(run_durations)/60

In [None]:
def f_median(x, precision=1):
    """Format the median and interquartile range of ``x`` as a string.

    NaN entries are ignored. Each quartile is rounded to ``precision``
    decimals and the result has the form ``"median[q25-q75]"``.
    """
    lower, middle, upper = (
        round(np.nanquantile(x, level), precision)
        for level in (0.25, 0.5, 0.75)
    )
    return f"{middle}[{lower}-{upper}]"

# Cell output: median[q25-q75] of the per-file run time.
f_median(run_durations)