# Benchmark Demo
This demo is for reproducing the benchmark in:
<br>Zhou et al. <b>"Effective and Efficient Neural Networks for Spike Inference from In Vivo Calcium Imaging"</b>


# Requirements:

<b>Hardware:</b>
- CUDA-enabled GPU
- 24 GB of system RAM (recommended)

<b>Software and packages:</b>
- python == 3.6
- torch  >= 1.7.1 (with CUDA and cuDNN toolkits)
- numpy  >= 1.19.2
- scipy  >= 1.5.2
- tqdm   >= 4.59.0

- MATLAB (2017b - 2020b)
    - Please install MATLAB engine for Python.
    - Then add the <i>spikes-master</i> and <i>brick-master</i> folders to the MATLAB search path.

<b>Note: </b>
- The above requirements are for running the large-scale simulations in the benchmark.
- Benchmark results may vary slightly on setups that differ from the one listed above.
- For inference purposes in actual applications, any regular PC with a <b>CPU</b> is sufficient (see <i>ENS2_demo.ipynb</i> for details). 

# Database

Before running this benchmark demo, please download the database (Rupprecht et al. 2021) from the repository below:
<br>https://github.com/HelmchenLabSoftware/Cascade/tree/master/Ground_truth

<b>Please put the downloaded datasets 1–27 into a "ground_truth" folder at the same directory level as this notebook.</b>
<br>For example:
<br>.../ Benchmark_demo.ipynb
<br>.../ ground_truth/ DS01-OGB1-m-V1/ (files)
<br>.../ ground_truth/ DS02-OGB1-2-m-V1/ (files)
<br>(etc...)
<br>.../ ground_truth/ DS21-jGECO1a-m-V1/ (files)
<br>(etc...)
<br>.../ ground_truth/ DS27-GCaMP6f-m-PV-vivo-V1/ (files)

In [1]:
from __future__ import print_function, division
import os
import time
import datetime
import sys

import argparse
import random
import math

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable

import torch.utils.data
from collections import OrderedDict
from torch.utils.data import TensorDataset, DataLoader
from pytorchtools import EarlyStopping

import numpy as np
import scipy
from scipy import signal
import scipy.io as scio
import copy
import glob as glob
from tqdm.auto import trange, tqdm

# Seed every random number generator used in this notebook so runs are repeatable.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# Request deterministic cuDNN kernels and disable cuDNN autotuning, which may
# pick different algorithms from run to run.
cudnn.deterministic = True
cudnn.benchmark = False

# Legacy CUDA tensor-type aliases; batches are moved to the GPU with
# ``tensor.type(FloatTensor)`` in the training loop.
HalfTensor = torch.cuda.HalfTensor
FloatTensor = torch.cuda.FloatTensor
LongTensor = torch.cuda.LongTensor

# Shared MATLAB engine used by the ER / VPD metric helpers, which call the
# MATLAB functions ``computeER``, ``spkd`` and ``cell2mat`` through it.
import matlab.engine
eng = matlab.engine.start_matlab()

# Initialize Hyper-parameters

In [2]:
# Hyper-parameters are kept on a plain attribute namespace bound to the
# module-level name ``opt``; the helper functions read it at call time, and
# ``BENCHMARK.train`` overwrites several fields per run.
opt = argparse.Namespace()
opt.sampling_rate = 60     # re-sampling rate of data (Hz)
opt.smoothing_std = 0.025  # std of the spike-smoothing kernel (unit: s)
opt.smoothing = opt.smoothing_std * opt.sampling_rate  # same std, in samples

opt.causal_kernel = False
opt.gaussian_kernel = not opt.causal_kernel  # use Gaussian smoothing kernel

opt.signal_len = 96  # length (samples) of each input/output window
opt.classes = 6      # per-bin spike counts >= classes are clipped to classes-1

opt.lr = 0.001
opt.epochs = 5000
opt.patience = 500   # early-stopping patience
opt.batch_size = 1024
opt.sample_interval = opt.epochs

# Pre-process Training Data

In [3]:
def load_all_ground_truth(sampling_rate, ground_truth_folder='./ground_truth/'):
    """Load every ``DS*`` dataset folder under *ground_truth_folder*.

    Dataset folders are numbered 1, 2, ... in sorted name order. With the
    zero-padded ``DS01-...``, ``DS02-...`` naming, each key equals the
    dataset number. Files that fail to load are reported and skipped
    (best effort).

    Args:
        sampling_rate: target re-sampling rate (Hz) passed to
            ``load_recordings_from_file``.
        ground_truth_folder: folder containing the ``DS*`` dataset folders.

    Returns:
        ``(dataset_neuron, dataset_trial, dataset_prop)``, three dicts keyed
        by dataset number. They hold, respectively, the list of per-neuron
        records, the list of per-trial records, and a summary dict with keys
        ``name``, ``dset``, ``frame_rate``, ``noise_level``, ``firing_rate``
        (means over files; nan when no file loaded), ``neuron_number``,
        ``trial_number`` and ``duration`` (minutes of samples).
    """
    def _mean_or_nan(values):
        # np.mean([]) emits a RuntimeWarning before returning nan; be explicit.
        return np.mean(values) if len(values) else np.nan

    # glob returns entries in filesystem order. Sort so the dataset numbers
    # (used as leave-one-dataset-out indices downstream) are reproducible.
    datasets = sorted(glob.glob(os.path.join(ground_truth_folder, 'DS*')))
    dataset_trial = dict()
    dataset_neuron = dict()
    dataset_prop = dict()
    for dset, dataset in enumerate(datasets, start=1):
        name = os.path.basename(dataset)
        files = sorted(glob.glob(os.path.join(dataset, '*')))
        recording_trial_list = []
        recording_neuron_list = []
        recording_sample_count = 0
        recording_raw_fs = []
        recording_noise_level = []
        recording_firing_rate = []
        for i, file_path in enumerate(files):
            try:
                recording_neuron, recording_trial, count = load_recordings_from_file(file_path, sampling_rate)
            except Exception as err:
                # best effort: report the unreadable file (with the cause) and move on
                print('Problem loading file {} from {}: {!r}'.format(i+1, dataset, err))
                continue
            recording_neuron_list.extend(recording_neuron)
            recording_trial_list.extend(recording_trial)
            recording_sample_count += count
            recording_raw_fs.append(recording_neuron[0]['frame_rate'])
            recording_noise_level.append(recording_neuron[0]['noise_level'])
            recording_firing_rate.append(recording_neuron[0]['firing_rate'])

        frame_rate = _mean_or_nan(recording_raw_fs)
        noise_level = _mean_or_nan(recording_noise_level)
        firing_rate = _mean_or_nan(recording_firing_rate)
        duration = int(recording_sample_count/sampling_rate/60)  # minutes

        dataset_neuron[dset] = recording_neuron_list
        dataset_trial[dset] = recording_trial_list
        dataset_prop[dset] = dict(name=name,
                                  dset=dset,
                                  frame_rate=frame_rate,
                                  noise_level=noise_level,
                                  firing_rate=firing_rate,
                                  neuron_number=len(recording_neuron_list),
                                  trial_number=len(recording_trial_list),
                                  duration=duration)

        print(f'[D{dset:2d}]  Raw Fs: {frame_rate:>5.1f}  Noise: {noise_level:>4.1f}  FR: {firing_rate:>4.1f}  #Neuron: {len(recording_neuron_list):3d}  #Trial: {len(recording_trial_list):3d}  #Sample: {recording_sample_count:7d}  Duration: {duration:4d}  Name: {name[:20]}')
    return dataset_neuron, dataset_trial, dataset_prop

def build_causal_kernel(sampling_rate, smoothing=None):
    """Build a 199-tap one-sided (inverse-Gaussian) spike-smoothing kernel.

    The kernel is an inverse-Gaussian pdf whose support starts at sample 101.
    It is zero before that point. The pdf is rolled so that its peak moves
    part of the way (1/1.5 of the distance) back toward sample 99, and it is
    then normalised to sum to 1.

    Args:
        sampling_rate: sampling rate (Hz) of the binned spike train.
        smoothing: kernel width in samples. Defaults to the global
            ``opt.smoothing``, read at call time.

    Returns:
        1-D float array of length 199 whose (NaN-ignoring) sum is 1.
    """
    import scipy.stats  # ``import scipy`` alone does not guarantee the submodule is loaded
    if smoothing is None:
        smoothing = opt.smoothing
    xx = np.arange(0, 199) / sampling_rate
    # invgauss(mu, loc, scale): mu = 2 * smoothing expressed in seconds, onset at sample 101
    yy = scipy.stats.invgauss.pdf(xx, smoothing / sampling_rate * 2.0, 101 / sampling_rate, 1)
    ix = np.argmax(yy)
    yy = np.roll(yy, int((99 - ix) / 1.5))
    return yy / np.nansum(yy)
        
def load_recordings_from_file(file_path, sampling_rate):
    """Load one ground-truth ``.mat`` recording and cut it into training windows.

    The file is read with ``scipy.io.loadmat`` and its ``'CAttached'`` entry
    is iterated trial by trial. Trials lacking any of the ``fluo_mean``,
    ``fluo_time`` or ``events_AP`` fields are skipped. For each remaining
    trial the function:

    1. converts spike events to seconds (dividing by 1e4) and keeps only
       those inside the fluorescence recording;
    2. resamples the dF/F trace to ``sampling_rate``, then removes
       ``floor(frame_rate_resampled)`` samples in total, starting
       ``ceil(0.5 * frame_rate_resampled)`` samples in;
    3. bins the spikes on the resampled time grid, zero-pads the trace and
       the spike train by ``opt.signal_len // 2`` samples on each side, and
       smooths the spike train with the kernel selected by
       ``opt.causal_kernel`` / ``opt.gaussian_kernel``;
    4. cuts one ``opt.signal_len``-sample window per time point.

    Trials with no spikes inside the (trimmed) recording are skipped.

    Args:
        file_path: path of a ground-truth ``.mat`` file.
        sampling_rate: target sampling rate (Hz) after resampling.

    Returns:
        ``(recording_neuron, recording_trial, n_windows)``:
        ``recording_neuron`` is a one-element list holding a dict with all
        kept trials concatenated plus neuron-level statistics;
        ``recording_trial`` is a list with one dict per kept trial;
        ``n_windows`` is ``trace_seg.shape[0]``.

    Reads globals ``opt.signal_len``, ``opt.causal_kernel``,
    ``opt.gaussian_kernel``, ``opt.smoothing_std`` and ``opt.classes``.
    """

    data = scio.loadmat(file_path)['CAttached'][0]

    recording_trial = list()

    # Accumulators over all kept trials:
    #   *_seq: per-sample sequences (trace_seq has shape (N, 1))
    #   *_seg: (n_windows, opt.signal_len) sliding windows
    #   *_num: (n_windows, 1) value at each window's centre sample
    trace_seq, spike_seq, rate_seq = np.zeros((0, 1),dtype='float32'), np.zeros((0,),dtype='int'), np.zeros((0,),dtype='float32')
    trace_seg, spike_seg, rate_seg = np.zeros((0, opt.signal_len),dtype='float32'), np.zeros((0, opt.signal_len),dtype='int'), np.zeros((0, opt.signal_len),dtype='float32')
    spike_num, rate_num, class_num = np.zeros((0, 1),dtype='int'), np.zeros((0, 1),dtype='float32'), np.zeros((0, 1),dtype='int')

    # record sampling rate after processing
    FS, FS_resampled = [],[]

    # for calculation of neuron-wise noise level
    concat_traces_mean = np.zeros((0,),dtype='float32')

    # for calculation of neuron-wise firing rate
    concat_events = []
    concat_times = []

    # for calibration of ER computation
    # (event times of all trials laid end to end, offset by cum_times seconds)
    concat_events_times = []
    cum_times = 0

    for i,trial in enumerate(data):
        # find the relevant elements in the data structure
        # (dF/F traces; spike events; time stamps of fluorescence recording)
        keys = trial[0][0].dtype.descr
        # flatten the (name, format) pairs so that position // 2 is the field index
        # NOTE(review): assumes every descr entry is a 2-tuple (no sub-array shapes)
        keys_unfolded = list(sum(keys, ()))

        try:
            traces_index = int(keys_unfolded.index("fluo_mean")/2)
            fluo_time_index = int(keys_unfolded.index("fluo_time")/2)
            events_index = int(keys_unfolded.index("events_AP")/2)
        except:
            # a required field is missing (list.index raised): skip this trial
            continue

        # spikes
        events = trial[0][0][events_index]
        events = events[~np.isnan(events)]
        # events appear to be stored as ephys sample indices at 10 kHz; convert to seconds
        ephys_sampling_rate = 1e4
        event_time = events/ephys_sampling_rate

        # fluorescence
        fluo_times = np.squeeze(trial[0][0][fluo_time_index])
        traces_mean = np.squeeze(trial[0][0][traces_index])
        traces_mean = traces_mean[:fluo_times.shape[0]]

        # drop frames without a valid time stamp
        traces_mean = traces_mean[~np.isnan(fluo_times)]
        fluo_times = fluo_times[~np.isnan(fluo_times)]

        frame_rate = 1/np.mean(np.diff(fluo_times))

        # concatenate for statistics
        # NOTE(review): len(events) counts all spikes of the trial (including those
        # outside the imaging window, dropped below), and fluo_times[-1] is the last
        # absolute time stamp rather than the trial duration, so the neuron-level
        # firing rate computed from these is approximate.
        concat_traces_mean = np.concatenate([concat_traces_mean,traces_mean], axis=0)
        concat_events.append(len(events))
        concat_times.append(fluo_times[-1])

        # calibrate onset time: keep spikes within the imaging window and shift
        # both clocks so that the first frame sits at t = 1/frame_rate
        event_time = event_time[np.logical_and(fluo_times[0]<=event_time,event_time<=fluo_times[-1])]
        event_time = event_time - fluo_times[0] + 1/frame_rate
        fluo_times = fluo_times - fluo_times[0] + 1/frame_rate
        if event_time.size==0:
            continue

        # resampling (Fourier method; also returns the resampled time stamps)
        num_samples = int(round(traces_mean.shape[0]*sampling_rate/frame_rate))
        (traces_mean,fluo_times_resampled) = scipy.signal.resample(traces_mean,num_samples,np.squeeze(fluo_times),axis=0)
        frame_rate_resampled = 1/np.nanmean(np.diff(fluo_times_resampled))

        # calibrate bin size: rescale time stamps so their spacing is exactly 1/sampling_rate
        fluo_times_resampled = fluo_times_resampled*frame_rate_resampled/sampling_rate
        frame_rate_resampled = 1/np.nanmean(np.diff(fluo_times_resampled))

        # cleaning data: remove floor(frame_rate_resampled) samples in total,
        # starting ceil(0.5*frame_rate_resampled) samples in
        num_samples -= int(np.floor(frame_rate_resampled))
        traces_mean = traces_mean[int(np.ceil(0.5*frame_rate_resampled)):int(np.ceil(0.5*frame_rate_resampled))+num_samples]
        fluo_times_resampled = fluo_times_resampled[int(np.ceil(0.5*frame_rate_resampled)):int(np.ceil(0.5*frame_rate_resampled))+num_samples]

        # calibration again (same re-basing as above, on the trimmed grid)
        if event_time.size==0 or fluo_times_resampled.size==0:
            continue
        event_time = event_time[np.logical_and(fluo_times_resampled[0]<=event_time,event_time<=fluo_times_resampled[-1])]
        event_time = event_time - fluo_times_resampled[0] + 1/frame_rate_resampled
        fluo_times_resampled = fluo_times_resampled - fluo_times_resampled[0] + 1/frame_rate_resampled
        if event_time.size==0:
            continue

        # bin the ground truth (spike times) into time bins determined by the resampled calcium recording
        # (bin edges lie half a sample on either side of each time stamp)
        fluo_times_bin_centers = fluo_times_resampled
        fluo_times_bin_edges = np.append(fluo_times_bin_centers,fluo_times_bin_centers[-1]+1/frame_rate_resampled/2) - 1/frame_rate_resampled/2

        [events_binned,event_bins] = np.histogram(event_time, bins=fluo_times_bin_edges)

        # concatenate event_time
        concat_events_times = np.concatenate([concat_events_times, event_time+cum_times])
        cur_times = len(traces_mean)/frame_rate_resampled
        cum_times += cur_times

        # trial-wise firing rate
        # NOTE(review): len(event_bins) is the number of bin *edges* (bins + 1)
        firing_rate = sum(events_binned)/(len(event_bins)/frame_rate_resampled)

        # trial-wise noise level: median |first difference| * 100 / sqrt(original frame rate)
        noise_level = np.nanmedian(np.abs(np.diff(traces_mean)))*100/np.sqrt(frame_rate)

        # do pre-processing here if needed
        traces_mean = np.expand_dims(traces_mean, 0)

        # padding in case recording is too short (signal_len//2 zeros on each side)
        traces_mean = np.concatenate([np.zeros((1,opt.signal_len//2)), traces_mean, np.zeros((1,opt.signal_len//2))], axis=1)
        events_binned = np.concatenate([np.zeros((opt.signal_len//2,)), events_binned, np.zeros((opt.signal_len//2,))], axis=0)

        # smooth spikes to facilitate gradient descents
        # NOTE(review): if neither opt flag is set, events_binned_smooth is undefined below
        if opt.causal_kernel:
            # built once per call of this function, then reused for later trials
            if 'causal_smoothing_kernel' not in locals():
                causal_smoothing_kernel = build_causal_kernel(sampling_rate)
            events_binned_smooth = np.convolve(events_binned.astype(float),causal_smoothing_kernel,mode='same')
        if opt.gaussian_kernel:
            # NOTE(review): ``scipy.ndimage.filters`` is a deprecated alias of ``scipy.ndimage``
            events_binned_smooth = scipy.ndimage.filters.gaussian_filter(events_binned.astype(float), sigma=opt.smoothing_std*sampling_rate)

        # format data into segments: one window per time point of the unpadded trace
        data_len = len(events_binned)-opt.signal_len
        before = int(opt.signal_len//2)
        after = int(opt.signal_len//2-1)  # NOTE(review): unused

        X = np.zeros((data_len, opt.signal_len), dtype='float32')
        YY_spike = np.zeros((data_len, opt.signal_len), dtype='int')
        YY_rate = np.zeros((data_len, opt.signal_len), dtype='float32')
        Y_spike = np.zeros((data_len, 1), dtype='int')
        Y_rate = np.zeros((data_len, 1), dtype='float32')
        Y_class = np.zeros((data_len, 1), dtype='int')

        traces_mean = traces_mean.astype('float32')
        events_binned = events_binned.astype('int')
        events_binned_smooth = events_binned_smooth.astype('float32')

        # window starting at time_point; the Y_* targets are the values at its centre
        for time_point in range(data_len):
            X[time_point,:] = traces_mean[:,time_point:time_point+opt.signal_len]
            YY_spike[time_point,:] = events_binned[time_point:time_point+opt.signal_len]
            YY_rate[time_point,:] = events_binned_smooth[time_point:time_point+opt.signal_len]
            Y_spike[time_point] = events_binned[time_point+before]
            Y_rate[time_point] = events_binned_smooth[time_point+before]
            Y_class[time_point] = events_binned[time_point+before]
        # spike counts >= opt.classes are clipped into the last class
        Y_class[Y_class>=opt.classes] = opt.classes-1

        recording_trial.append(dict(time_resampled=fluo_times_resampled,
                                    frame_rate=frame_rate,frame_rate_resampled=frame_rate_resampled,
                                    firing_rate=firing_rate,
                                    noise_level=noise_level,
                                    trace_seq=traces_mean[:,before:before+data_len].T,
                                    spike_seq=events_binned[before:before+data_len],
                                    rate_seq=events_binned_smooth[before:before+data_len],
                                    trace_seg=X, spike_seg=YY_spike, rate_seg=YY_rate,
                                    spike_num=Y_spike, rate_num=Y_rate, class_num=Y_class,
                                    events_times=event_time, elapsed_times=cur_times))

        FS.append(frame_rate)
        FS_resampled.append(frame_rate_resampled)

        trace_seq = np.concatenate([trace_seq, traces_mean[:,before:before+data_len].T], axis=0)
        spike_seq = np.concatenate([spike_seq, events_binned[before:before+data_len]], axis=0)
        rate_seq = np.concatenate([rate_seq, events_binned_smooth[before:before+data_len]], axis=0)
        trace_seg = np.concatenate([trace_seg, X], axis=0)
        spike_seg = np.concatenate([spike_seg, YY_spike], axis=0)
        rate_seg = np.concatenate([rate_seg, YY_rate], axis=0)
        spike_num = np.concatenate([spike_num, Y_spike], axis=0)
        rate_num = np.concatenate([rate_num, Y_rate], axis=0)
        class_num = np.concatenate([class_num, Y_class], axis=0)

    # neuron-level record; if no trial was kept, the means below are nan (with warnings)
    recording_neuron = [dict(frame_rate=np.nanmean(FS), frame_rate_resampled=np.nanmean(FS_resampled),
                             firing_rate=np.sum(concat_events)/np.sum(concat_times),
                             noise_level=np.nanmedian(np.abs(np.diff(concat_traces_mean)))*100/np.sqrt(np.nanmean(FS)),
                             trace_seq=trace_seq, spike_seq=spike_seq, rate_seq=rate_seq,
                             trace_seg=trace_seg, spike_seg=spike_seg, rate_seg=rate_seg,
                             spike_num=spike_num, rate_num=rate_num, class_num=class_num,
                             events_times=concat_events_times, elapsed_times=cum_times)]

    return recording_neuron, recording_trial, trace_seg.shape[0]

In [4]:
# datasets, datasets_raw, datasets_prop = load_all_ground_truth(opt.sampling_rate)

# Define Models

In [5]:
from Models import * 

# Define Helper Functions

In [6]:
# Greedy algorithm to estimate spike-event from spike-rate

def estimate_spike(rate, std=None):
    """Greedily convert a predicted spike-rate trace into per-sample spike counts.

    Values below ``0.02 / std`` (including negatives) are zeroed. The trace
    is then split into runs of consecutive positive samples. For each run,
    one spike is placed at the maximum if the run sums to at least 0.5.
    Further spikes are then added one at a time, each at whichever position
    most reduces the squared error between the run and the spike train
    smoothed with a Gaussian of std ``std`` (samples). Adding stops when no
    single extra spike lowers the error by at least 1e-8.

    Args:
        rate: spike-rate trace (1-D after squeezing).
        std: Gaussian smoothing std in samples. Defaults to ``opt.smoothing``
            read at call time, so it follows settings applied later in the
            notebook (a value in the signature would be frozen when the
            function is defined).

    Returns:
        float32 array with the shape of the squeezed ``rate``, holding the
        estimated spike count per sample.

    NOTE(review): a positive run that starts at sample 0, or is still open at
    the last sample, is never decoded (its rising/falling edge is not seen).
    The forced split after 500 samples also fires inside zero regions and
    overwrites entries of ``rate_diff``. Both are kept as-is so benchmark
    numbers stay comparable.
    """
    import scipy.ndimage  # ensure the submodule is loaded
    if std is None:
        std = opt.smoothing
    rate = np.float32(np.array(copy.deepcopy(rate))).squeeze()
    # remove bubbles produced by neural network
    rate[rate<0.02/std] = 0
    # initialize: rate_diff[idx] == 1 means a run starts at idx+1,
    # rate_diff[idx] == -1 means a run ends at idx
    rate_diff = np.diff(np.int8(rate>0))
    est_spike = np.zeros(rate.shape, dtype='float32')
    onset, offset = 0, 0
    for idx in range(len(rate_diff)):
        # locate each piece of slices with spike rate
        if rate_diff[idx] == 1:
            onset = idx+1
        elif rate_diff[idx] == -1:
            if onset > 0:
                offset = idx
                # extract pieces of slices
                slices = rate[onset:offset+1]
                # at least one spike is included when probability is over 0.5
                could_add = True
                cur_spike = np.zeros(slices.shape, dtype='float32')
                if np.sum(slices)>=0.5:
                    cur_spike[np.argmax(slices)] = 1
                cur_rate = scipy.ndimage.gaussian_filter(cur_spike, sigma=std, mode='constant', cval=0.)
                cur_loss = np.sum((slices-cur_rate)**2)
                # iteratively insert spikes that are best-match:
                # row k of candidate_spike is cur_spike plus one extra spike at k
                while could_add:
                    candidate_spike = cur_spike + np.eye(len(slices),len(slices),dtype='float32')
                    candidate_rate = scipy.ndimage.gaussian_filter(candidate_spike, sigma=(0,std), mode='constant', cval=0.)
                    candidate_loss = np.sum(np.power(slices-candidate_rate,2),1)
                    new_loss, new_loss_idx = np.amin(candidate_loss), np.argmin(candidate_loss)
                    if new_loss - cur_loss <= -0.00000001:
                        cur_spike = candidate_spike[new_loss_idx,:]
                        cur_loss = new_loss
                        could_add = True
                    else:
                        est_spike[onset:offset+1] = cur_spike
                        could_add = False
        # force estimation with maximum slice length of 500 data points
        elif idx - onset >= 500-1:
            if len(rate_diff) > idx+1:
                rate_diff[idx+1] = -1
            if len(rate_diff) > idx+2:
                rate_diff[idx+2] = 1
    return est_spike

In [7]:
# Function to extract spike-event timestamps from spike trains

def extract_event(spike, sampling_rate=None):
    """Convert a binned spike-count train into a sorted list of event times (s).

    Bin ``i`` (0-based, along the first axis after squeezing) maps to time
    ``(i + 1) / sampling_rate``. A bin with value ``v > 0`` contributes
    ``ceil(v)`` copies of its time, which is the count the previous
    "subtract one until non-positive" loop produced. Non-finite values are
    ignored.

    Args:
        spike: spike counts per bin (any array-like; squeezed first).
        sampling_rate: bin rate in Hz. Defaults to ``opt.sampling_rate``,
            read at call time.

    Returns:
        Sorted list of Python floats (empty if there are no spikes).
    """
    if sampling_rate is None:
        sampling_rate = opt.sampling_rate
    # Work on a float64 copy: the old in-place ``-= 1`` never terminated for
    # unsigned integer input (0 - 1 wraps around) and raised on bool input.
    values = np.atleast_1d(np.squeeze(np.asarray(spike, dtype=np.float64)))
    positive = (values > 0) & np.isfinite(values)
    counts = np.ceil(values[positive]).astype(np.int64)
    bins = np.nonzero(positive)[0]
    event_output = ((np.repeat(bins, counts) + 1) / sampling_rate).tolist()
    event_output.sort()
    return event_output

In [8]:
# Function to initialize models

def weights_init_normal(m):
    """Draw Conv*/Linear* weights from N(0, 0.02**2); intended for ``Module.apply``.

    Layers are matched by a substring of the class name. A module whose name
    matches but which has no ``weight`` tensor (e.g. a user-defined
    ``ConvBlock`` container, which ``apply`` also visits) is skipped instead
    of raising AttributeError. Biases are left untouched.
    """
    classname = m.__class__.__name__
    if "Conv" not in classname and "Linear" not in classname:
        return
    weight = getattr(m, "weight", None)
    if weight is None:
        return
    torch.nn.init.normal_(weight.data, 0.0, 0.02)

In [9]:
# Custom correlation-based loss function

def pearson_corr_loss(output, target):
    """Pearson correlation between all elements of ``output`` and ``target``.

    Both tensors are flattened and mean-centred. The result lies in [-1, 1];
    1e-7 is added to the denominator so constant inputs do not divide by
    zero. This returns the correlation itself, not ``1 - corr``; presumably
    the caller negates it when minimising (confirm at the call site).
    """
    def _centered(tensor):
        flat = torch.squeeze(tensor.reshape(-1, 1))
        return flat - torch.mean(flat)

    x = _centered(output)
    y = _centered(target)
    denominator = torch.norm(x) * torch.norm(y) + 1e-7
    return torch.dot(x, y) / denominator

# Custom van Rossum distance based loss function

def eucd_loss(output, target_rate, target_spike, smoothing_std=None):
    """van Rossum-style distance between predicted and target spike rates.

    Computes ``sqrt(sum((output - target_rate)**2) / smoothing_std / sum(target_spike))``
    over all elements. The result is inf (or nan) when ``target_spike`` sums
    to zero.

    Args:
        output: predicted rate tensor.
        target_rate: smoothed ground-truth rate tensor (same number of elements).
        target_spike: ground-truth spike-count tensor, used for normalisation.
        smoothing_std: smoothing std (s). Defaults to ``opt.smoothing_std``,
            read at call time.
    """
    if smoothing_std is None:
        smoothing_std = opt.smoothing_std
    vx = torch.squeeze(output.reshape(-1, 1))
    vy = torch.squeeze(target_rate.reshape(-1, 1))
    vz = torch.squeeze(target_spike.reshape(-1, 1))
    distance = torch.sqrt(torch.sum((vx - vy) ** 2) / smoothing_std / torch.sum(vz))
    return distance

In [10]:
# Measurement used in CASCADE (Correlation, Error, and Bias)

def evaluate_cascade(predict_input, truth_input):
    """CASCADE-style metrics between a predicted and a ground-truth spike rate.

    Negative values in both inputs are set to 0 on squeezed copies; the
    inputs themselves are not modified.

    Returns:
        ``(corr, error, bias)``. ``corr`` is the Pearson correlation. With FP
        the summed positive deviation and FN the summed magnitude of the
        negative deviation of prediction from truth:
        ``error = (FP + FN) / sum(truth)`` and ``bias = (FP - FN) / sum(truth)``.
    """
    def _clip_negative(values):
        clipped = np.squeeze(copy.deepcopy(values))
        clipped[clipped < 0] = 0
        return clipped

    predict = _clip_negative(predict_input)
    truth = _clip_negative(truth_input)
    corr = np.corrcoef(predict, truth)[0, -1]

    residual = predict - truth
    over = residual > 0
    under = residual < 0
    false_pos = np.sum(np.abs(residual[over]))
    false_neg = np.sum(np.abs(residual[under]))
    total = np.sum(np.abs(truth))
    return corr, (false_pos + false_neg) / total, (false_pos - false_neg) / total

# van Rossum distance measurement

def evaluate_eucd(predict, truth_rate, truth_spike, smoothing_std=None):
    """van Rossum-style distance between predicted and ground-truth rates (NumPy).

    Computes ``sqrt(sum((predict - truth_rate)**2) / smoothing_std / sum(truth_spike))``
    after squeezing both rates. ``smoothing_std`` (s) defaults to
    ``opt.smoothing_std``, read at call time.
    """
    if smoothing_std is None:
        smoothing_std = opt.smoothing_std
    squared_error = np.sum((predict.squeeze() - truth_rate.squeeze()) ** 2)
    test_eucd = np.sqrt(squared_error / smoothing_std / np.sum(truth_spike))
    return test_eucd

# Error rate measurement

def evaluate_er(predict_input, truth_input, win=0.05):
    """Error rate between predicted and ground-truth event times via MATLAB ``computeER``.

    Both inputs are iterables of event times. They are converted to sorted
    lists of Python floats and passed, with the matching window ``win``
    (presumably in seconds, like the event times; confirm against
    ``computeER``), to the shared MATLAB engine ``eng``. Returns whatever
    ``computeER`` returns.
    """
    predicted_times = sorted(float(t) for t in copy.deepcopy(predict_input))
    truth_times = sorted(float(t) for t in copy.deepcopy(truth_input))
    return eng.computeER(eng.cell2mat(predicted_times), eng.cell2mat(truth_times), eng.double(win))

# Victor-Purpura distance measurement

def evaluate_vpd(predict_input, truth_input, cost=1):
    """Victor-Purpura distance (MATLAB ``spkd``) normalised by the number of true events.

    Both inputs are iterables of event times. They are converted to sorted
    lists of Python floats and passed, with ``cost``, to the shared MATLAB
    engine ``eng``.

    Raises:
        ZeroDivisionError: if ``truth_input`` is empty.
    """
    predicted_times = sorted(float(t) for t in copy.deepcopy(predict_input))
    truth_times = sorted(float(t) for t in copy.deepcopy(truth_input))
    # builtin float: the ``np.float`` alias was deprecated in NumPy 1.20 and removed in 1.24
    distance = float(eng.spkd(eng.cell2mat(predicted_times), eng.cell2mat(truth_times), eng.double(cost)))
    return distance / len(truth_times)

# Implement Benchmark in a Leave-one-dataset-out Manner

In [11]:
class BENCHMARK(object):
    def __init__(self):
        
        self.DATA = [[]]*27
        self.MODEL = {}
        
        self.cluster = {2:[2,4],3:[3,5],4:[2,4],5:[3,5],
                        6:[6,7,8,9,10,11],7:[6,7,8,9,10,11],8:[6,7,8,9,10,11],9:[6,7,8,9,10,11],10:[6,7,8,9,10,11],11:[6,7,8,9,10,11],
                        12:[12,13,14,15,16],13:[12,13,14,15,16],14:[12,13,14,15,16],15:[12,13,14,15,16],16:[12,13,14,15,16],
                        17:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                        18:[18,19],19:[18,19],
                        20:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],21:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                        22:[22,23],23:[22,23],
                        24:[24,25,26,27],25:[24,25,26,27],26:[24,25,26,27],27:[24,25,26,27]}
        
        self.anticluster = {2:[2,3,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            3:[2,3,4,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            4:[3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            5:[2,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            6:[2,3,4,5,6,12,13,14,15,16,17,18,19,20,21],
                            7:[2,3,4,5,7,12,13,14,15,16,17,18,19,20,21],
                            8:[2,3,4,5,8,12,13,14,15,16,17,18,19,20,21],
                            9:[2,3,4,5,9,12,13,14,15,16,17,18,19,20,21],
                            10:[2,3,4,5,10,12,13,14,15,16,17,18,19,20,21],
                            11:[2,3,4,5,11,12,13,14,15,16,17,18,19,20,21],
                            12:[2,3,4,5,6,7,8,9,10,11,12,17,18,19,20,21],
                            13:[2,3,4,5,6,7,8,9,10,11,13,17,18,19,20,21],
                            14:[2,3,4,5,6,7,8,9,10,11,14,17,18,19,20,21],
                            15:[2,3,4,5,6,7,8,9,10,11,15,17,18,19,20,21],
                            16:[2,3,4,5,6,7,8,9,10,11,16,17,18,19,20,21],
                            17:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            18:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,20,21],
                            19:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,19,20,21],
                            20:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            21:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21],
                            22:[22,24,25,26,27],
                            23:[23,24,25,26,27],
                            24:[22,23,24],25:[22,23,25],26:[22,23,26],27:[22,23,27]}
        
        
    def train(self, neuron='Exc', inputs='Raw', nets='UNet', losses='MSE', Fs='60', smoothing_std='0.025', smoothing_kernel='gaussian',
              cluster='None', hour='all', lr='0.001', kernel='3', node='150K', seg='96', batch='1024', es='500', verbose=0):
        
        global opt
        #### define test mode
        self.TEST = 'C'+datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d')[2:]+'_'+neuron+'_'+Fs+'Hz_'+inputs+'_'+nets+'_'+losses+'_Test'

        if smoothing_std != '0.025':
            self.TEST = self.TEST+'_'+str(int(float(smoothing_std)*1000))+'ms'
        if smoothing_kernel == 'gaussian':
            opt.causal_kernel = False
            opt.gaussian_kernel = not opt.causal_kernel # use Gaussian smoothing kernel
        elif smoothing_kernel == 'causal':
            opt.causal_kernel = True
            opt.gaussian_kernel = not opt.causal_kernel # use causal smoothing kernel
            self.TEST = self.TEST+'_'+'causal'
        else:
            print("Smoothing kernel error...")
            return
            
        if cluster != 'None':
            self.TEST = self.TEST+'_'+cluster
        if hour != 'all':
            self.TEST = self.TEST+'_'+hour+'hr'
        
        if kernel != '3':
            self.TEST = self.TEST+'_'+kernel+'kernel'
        if node != '150K':
            self.TEST = self.TEST+'_'+node+'node'
            
        if lr != '0.001':
            opt.lr = float(lr)
            self.TEST = self.TEST+'_'+lr+'lr'
        if seg != '96':
            opt.signal_len = int(seg)
            self.TEST = self.TEST+'_'+seg+'seg'
        if batch != '1024':
            opt.batch_size = int(batch)
            self.TEST = self.TEST+'_'+batch+'batch'
        if es != '500':
            opt.patience = int(es)
            self.TEST = self.TEST+'_'+es+'es'
            
        print('【'+self.TEST+'】')
        fs_spec = (0,0)
        datasets, datasets_prop = [], []
        
        if neuron == 'Both':
            ds_on, ds_off = 2, 27
        elif neuron == 'Exc':
            ds_on, ds_off = 2, 21
        elif neuron == 'Inh':
            ds_on, ds_off = 22, 27
        else:
            print("Neuron type error...")
            return
        
        for dsets in range(ds_on, ds_off+1):
                        
            tqdm.write(f'dataset {dsets}: preparing data...')
            
            #### load or re-compile data
            opt.sampling_rate = float(Fs)
            opt.smoothing_std = float(smoothing_std)
            opt.smoothing = opt.smoothing_std * opt.sampling_rate
                    
            tqdm.write(f'Sampling rate is: {opt.sampling_rate}Hz, smoothing window is: {opt.smoothing_std*1000}ms')
            
            if fs_spec == (opt.sampling_rate, opt.smoothing_std):
                print('Using previous datasets...')
            else:
                print('Re-compiling datasets...')
                del datasets, datasets_prop
                datasets, _, datasets_prop = load_all_ground_truth(opt.sampling_rate)
                fs_spec = (opt.sampling_rate, opt.smoothing_std)

            #### initialize vault
            self.DATA[dsets-1] = {}
            self.DATA[dsets-1]['inputs'] = inputs
            self.DATA[dsets-1]['nets'] = nets
            self.DATA[dsets-1]['losses'] = losses
            self.DATA[dsets-1]['dataset'] = dsets
            self.DATA[dsets-1]['frame_rate'] = datasets_prop[dsets]['frame_rate']
            self.DATA[dsets-1]['noise_level'] = datasets_prop[dsets]['noise_level']
            self.DATA[dsets-1]['firing_rate'] = datasets_prop[dsets]['firing_rate']
            self.DATA[dsets-1]['neuron_number'] = datasets_prop[dsets]['neuron_number']
            
            self.DATA[dsets-1]['sampling_rate'] = opt.sampling_rate
            self.DATA[dsets-1]['correlation'] = []
            self.DATA[dsets-1]['error'] = []
            self.DATA[dsets-1]['bias'] = []
            self.DATA[dsets-1]['eucd'] = []
            self.DATA[dsets-1]['vpd'] = []
            self.DATA[dsets-1]['er50'] = []
            self.DATA[dsets-1]['er100'] = []
            self.DATA[dsets-1]['er500'] = []
            self.DATA[dsets-1]['gter50'] = []
            self.DATA[dsets-1]['loss'] = []

            self.DATA[dsets-1]['calcium'] = []
            self.DATA[dsets-1]['gt_rate'] = []
            self.DATA[dsets-1]['gt_spike'] = []
            self.DATA[dsets-1]['pd_rate'] = []
            self.DATA[dsets-1]['pd_spike'] = []
            self.DATA[dsets-1]['gt_event'] = []
            self.DATA[dsets-1]['pd_event'] = []
            self.DATA[dsets-1]['events_times'] = []
            
            start_time = datetime.datetime.now()
            self.model_ver = datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d%H%M%S')[2:]+'_dsets'+str(dsets)+'_'+str(opt.sampling_rate)+'Hz_'+inputs+'_'+nets+'_'+losses

            self.DATA[dsets-1]['model_ver'] = self.model_ver
            
            #### prepare data
            random.seed(0)
            torch.manual_seed(0)
            np.random.seed(0)

            torch.cuda.empty_cache()
            cudnn.deterministic = True
            cudnn.benchmark = False
            
            test_trace = []
            test_rate = []
            test_spike = []
            test_event = []
            train_trace = np.zeros((800*10000,opt.signal_len),dtype='float32')
            train_rate = np.zeros((800*10000,opt.signal_len),dtype='float32')
            train_spike = np.zeros((800*10000,opt.signal_len),dtype='float32')
            train_trace[:] = np.nan
            train_rate[:] = np.nan
            train_spike[:] = np.nan
            count = 0
            cum_time = 0

            if cluster == 'cluster':
                dset_indexes = self.cluster[dsets]
            elif cluster == 'anticluster':
                dset_indexes = self.anticluster[dsets]
            else:
                dset_indexes = np.arange(ds_on, ds_off+1)
                
            for dset_index in dset_indexes:
                if dset_index == dsets:
                    for trial_index in range(len(datasets[dset_index])):
                        test_trace.append(datasets[dset_index][trial_index]['trace_seg'])
                        test_rate.append(datasets[dset_index][trial_index]['rate_seg'])
                        test_spike.append(datasets[dset_index][trial_index]['spike_seg'])
                        test_event.append(datasets[dset_index][trial_index]['events_times'])
                else:
                    if dset_index == 1:
                        continue
                    for trial_index in range(len(datasets[dset_index])):
  
                        tmp = datasets[dset_index][trial_index]['trace_seg']
                        increment = tmp.shape[0]
                        if not increment:
                            continue

                        train_trace[count:count+increment,:] = tmp
                        train_rate[count:count+increment,:] = datasets[dset_index][trial_index]['rate_seg']
                        train_spike[count:count+increment,:] = datasets[dset_index][trial_index]['spike_seg']
                        count += increment

            train_trace = train_trace[0:count,:]
            train_rate = train_rate[0:count,:]
            train_spike = train_spike[0:count,:]

            if hour != 'all':
                np.random.seed(dsets)
                shrink_idx = (np.random.random_sample(np.ceil(float(hour)*3600/(opt.signal_len/opt.sampling_rate)).astype('int')) * count).astype('int')
                train_trace = train_trace[shrink_idx,:]
                train_rate = train_rate[shrink_idx,:]
                train_spike = train_spike[shrink_idx,:]
            
            if np.sum(train_trace==np.nan) or np.sum(train_rate==np.nan) or np.sum(train_spike==np.nan):
                print('NaN error...')
                break

            Training_dataset = TensorDataset(torch.FloatTensor(train_trace),torch.FloatTensor(train_rate),torch.FloatTensor(train_spike))
            Training_dataloader = DataLoader(Training_dataset, shuffle=True, batch_size=opt.batch_size)

            tqdm.write(f'dataset {dsets}, pair sample {len(Training_dataset)}: start training...')
            
            #### initiate network
            
            if nets=='UNet':
                if kernel == '3':
                    kernel_size, padding_size = 3, 1
#                     print('Using UNet with 3-size kernels')
                elif kernel == '5':
                    kernel_size, padding_size = 5, 2
                    print('Using UNet with 5-size kernels')
                elif kernel == '7':
                    kernel_size, padding_size = 7, 3
                    print('Using UNet with 7-size kernels')
                else:
                    print('Network error...')
                    break

                if node == '50K':
                    init_features_num = 5
                    print('Using UNet with 50K nodes')
                elif node == '150K':
                    init_features_num = 9
#                     print('Using UNet with 150K nodes')
                elif node == '450K':
                    init_features_num = 16
                    print('Using UNet with 450K nodes')
                else:
                    print('Network error...')
                    break
            
                C = UNet(init_features=init_features_num, kernel_size=kernel_size, padding=padding_size).cuda()
            elif nets=='LeNet':
                C = LeNet().cuda()
            elif nets=='FCNet':
                C = FCNet().cuda()
            else:
                print('Network error...')
                break
                
            C.apply(weights_init_normal)
            C_optimizer = optim.Adam(C.parameters(), lr=opt.lr)
#             print('Learning rate is '+str(opt.lr))

            mse_loss = torch.nn.MSELoss().cuda()

            early_stopping = EarlyStopping(patience=opt.patience, verbose=True, delta=0.0000)
            is_earlystop = 0
            
            #### start training
            start_time = datetime.datetime.now()
            t = trange(1, opt.epochs+1, leave=True,  ncols=1000)
            for epoch in t:

                # extract training batch
                np.random.seed(epoch)
                rand_idx = (np.random.random_sample(opt.batch_size) * len(Training_dataset)).astype('int')
                [train_trace, train_rate, train_spike] = Training_dataset[rand_idx]
                train_trace = Variable(train_trace.type(FloatTensor))
                train_rate = Variable(train_rate.type(FloatTensor))
                train_spike = Variable(train_spike.type(FloatTensor))
                
                # train model
                C_optimizer.zero_grad()
                C.train()

                if losses=='MSE':
                    loss = mse_loss(C(train_trace), train_rate)
                elif losses=='EucD':
                    loss = eucd_loss(C(train_trace), train_rate, train_spike)
                elif losses=='Corr':
                    loss = -pearson_corr_loss(C(train_trace), train_rate)
                else:
                    print('Loss error...')
                    break

                loss.backward()
                C_optimizer.step()

                early_stopping(loss.item(), C)
                if early_stopping.early_stop:
                    print('Early stopping...Epoch '+str(epoch))
                    is_earlystop = 1
                    
                # gather status
                t.set_description(inputs+' '+nets+' ['+losses+': %0.3f]' % loss.item())

                self.DATA[dsets-1]['loss'].append(loss.item())
                
                # check performance
                if (epoch % (opt.sample_interval) == 0) or is_earlystop:

                    elapsed_time = datetime.datetime.now() - start_time
                    tqdm.write(f'[{inputs}][{nets}][{losses}]【Dset {dsets}】[Ep {epoch}][Time {str(elapsed_time):.7s}]')
                
                    for sub_i, (sub_trace,sub_rate,sub_spike,sub_event) in enumerate(zip(test_trace,test_rate,test_spike,test_event)):
                    
                        Testing_dataset = TensorDataset(torch.FloatTensor(sub_trace),torch.FloatTensor(sub_rate),torch.FloatTensor(sub_spike))
                        Testing_dataloader = DataLoader(Testing_dataset, shuffle=False, batch_size=opt.batch_size*4)
                        
                        # testing
                        start_time = datetime.datetime.now()
                        C.eval()
                        with torch.no_grad():
                            calcium = torch.zeros((0,opt.signal_len))
                            pd_rate_tmp = torch.zeros((0,opt.signal_len)).type(FloatTensor).cuda()
                            gt_rate = torch.zeros((0,opt.signal_len))
                            gt_spike = torch.zeros((0,opt.signal_len))

                            for batch, (test_data, test_rate_seg, test_spike_seg) in enumerate(Testing_dataloader):
                                test_output = C(Variable(test_data.type(FloatTensor)))

                                calcium = torch.cat([calcium, test_data], axis=0)
                                pd_rate_tmp = torch.cat([pd_rate_tmp, test_output], axis=0)
                                gt_rate = torch.cat([gt_rate, test_rate_seg], axis=0)
                                gt_spike = torch.cat([gt_spike, test_spike_seg], axis=0)

                        calcium = calcium[:, [opt.signal_len//2]].cpu().numpy()
                        gt_rate = gt_rate[:,[opt.signal_len//2]].cpu().numpy()
                        gt_spike = gt_spike[:,[opt.signal_len//2]].cpu().numpy()
                        pd_rate_tmp = pd_rate_tmp.cpu().numpy()

                        pd_rate = np.zeros((1,gt_rate.shape[0]+opt.signal_len-1))
                        for align_idx in range(pd_rate_tmp.shape[0]):
                            pd_rate[:,align_idx:align_idx+opt.signal_len] += pd_rate_tmp[align_idx,:]
                        pd_rate = pd_rate[:,opt.signal_len//2:opt.signal_len//2+gt_rate.shape[0]].transpose()/opt.signal_len

                        if (epoch == opt.epochs) or is_earlystop:
                            if verbose:
                                print('Estimate spike...')
                            pd_spike = estimate_spike(pd_rate, std=opt.smoothing)
                            pd_event = extract_event(pd_spike)
                            gt_event = extract_event(gt_spike)
                        else:
                            pd_spike = pd_rate * 0

                        elapsed_time = datetime.datetime.now() - start_time

                        self.DATA[dsets-1]['calcium'].append(calcium.squeeze())
                        self.DATA[dsets-1]['gt_rate'].append(gt_rate.squeeze())
                        self.DATA[dsets-1]['gt_spike'].append(gt_spike.squeeze())
                        self.DATA[dsets-1]['pd_rate'].append(pd_rate.squeeze())
                        self.DATA[dsets-1]['pd_spike'].append(pd_spike.squeeze())
                        self.DATA[dsets-1]['gt_event'].append(gt_event)
                        self.DATA[dsets-1]['pd_event'].append(pd_event)
                        self.DATA[dsets-1]['events_times'].append(sub_event)
            
                        # compute Corr, Error, Bias
                        test_corr, test_error, test_bias = evaluate_cascade(pd_rate, gt_rate)

                        self.DATA[dsets-1]['correlation'].append(test_corr)
                        self.DATA[dsets-1]['error'].append(test_error)
                        self.DATA[dsets-1]['bias'].append(test_bias)

                        # compute van Rossum distance
                        test_eucd = evaluate_eucd(pd_rate, gt_rate, gt_spike)

                        self.DATA[dsets-1]['eucd'].append(test_eucd)

                        # compute Victor-Purpura distance
                        if verbose:
                            print('Compute vpd...')
                        test_vpd = evaluate_vpd(pd_event, sub_event)

                        self.DATA[dsets-1]['vpd'].append(test_vpd)

                        # compute Error Rate
                        if verbose:
                            print('Compute error rate...')
                        test_gter50 = evaluate_er(extract_event(estimate_spike(gt_rate, std=opt.smoothing)), sub_event)
                        test_er50 = evaluate_er(pd_event, sub_event, 0.05) # ER with  50ms window

                        self.DATA[dsets-1]['gter50'].append(test_gter50)
                        self.DATA[dsets-1]['er50'].append(test_er50)

                        tqdm.write(f'[Dset {dsets}][Neuron {sub_i+1}]【Corr {test_corr*100:.2f}%】[vRD {test_eucd:.2f}][VPd {test_vpd:.2f}]【ER50: {test_er50*100:.2f}%】[GTER50: {test_gter50*100:.2f}%][Error {test_error:.2f}][Bias {test_bias:.2f}][Time {str(elapsed_time):.7s}]')
                
                if is_earlystop:
                    break

            self.MODEL[dsets] = C

            del Training_dataset, Training_dataloader, Testing_dataset, Testing_dataloader
            del train_trace, train_rate, train_spike
            del C

In [12]:
# --- Start benchmark ---------------------------------------------------------
# Build the benchmark harness and run training plus per-neuron evaluation
# using the default arguments of BENCHMARK.train(). The `benchmark` object is
# kept alive afterwards so its per-dataset results (its DATA and MODEL
# attributes) can be inspected in later cells.
benchmark = BENCHMARK()
benchmark.train()
print('############### END OF TESTING ###############')

【C220817_Exc_60Hz_Raw_UNet_MSE_Test】
dataset 2: preparing data...
Sampling rate is: 60.0Hz, smoothing window is: 25.0ms
Re-compiling datasets...
[D 1]  Raw Fs:  11.5  Noise:  0.7  FR:  1.8  #Neuron:  21  #Trial:  21  #Sample:  522354  Duration:  145  Name: DS01-OGB1-m-V1
[D 2]  Raw Fs:  15.1  Noise:  0.5  FR:  0.2  #Neuron:  16  #Trial:  47  #Sample:  415612  Duration:  115  Name: DS02-OGB1-2-m-V1
[D 3]  Raw Fs: 500.0  Noise:  0.3  FR:  1.2  #Neuron:   8  #Trial: 134  #Sample:   60217  Duration:   16  Name: DS03-Cal520-m-S1
[D 4]  Raw Fs:   7.7  Noise:  1.0  FR:  0.4  #Neuron:  15  #Trial:  42  #Sample:  269902  Duration:   74  Name: DS04-OGB1-zf-pDp
[D 5]  Raw Fs:   7.8  Noise:  1.6  FR:  1.3  #Neuron:   5  #Trial:  16  #Sample:  112243  Duration:   31  Name: DS05-Cal520-zf-pDp
[D 6]  Raw Fs:  30.0  Noise:  1.3  FR:  1.9  #Neuron:   8  #Trial:  23  #Sample:  163944  Duration:   45  Name: DS06-GCaMP6f-zf-aDp
[D 7]  Raw Fs:  30.0  Noise:  0.6  FR:  1.5  #Neuron:  10  #Trial:  35  #Sampl

  0%|                                                                                                         …

Early stopping...Epoch 2581
[Raw][UNet][MSE]【Dset 2】[Ep 2581][Time 0:02:31]
[Dset 2][Neuron 1]【Corr 65.23%】[vRD 2.71][VPd 0.88]【ER50: 78.63%】[GTER50: 0.43%][Error 0.91][Bias -0.76][Time 0:00:01]
[Dset 2][Neuron 2]【Corr 72.17%】[vRD 2.66][VPd 0.76]【ER50: 61.46%】[GTER50: 1.60%][Error 0.81][Bias -0.63][Time 0:00:01]
[Dset 2][Neuron 3]【Corr 51.71%】[vRD 3.08][VPd 0.93]【ER50: 93.33%】[GTER50: 1.75%][Error 1.07][Bias -0.70][Time 0:00:00]
[Dset 2][Neuron 4]【Corr 67.31%】[vRD 3.51][VPd 0.97]【ER50: 95.51%】[GTER50: 0.00%][Error 0.97][Bias -0.86][Time 0:00:00]
[Dset 2][Neuron 5]【Corr 58.97%】[vRD 2.74][VPd 0.96]【ER50: 92.86%】[GTER50: 0.00%][Error 1.06][Bias -0.67][Time 0:00:01]
[Dset 2][Neuron 6]【Corr 65.40%】[vRD 3.11][VPd 0.76]【ER50: 65.71%】[GTER50: 0.88%][Error 1.00][Bias -0.59][Time 0:00:00]
[Dset 2][Neuron 7]【Corr 52.17%】[vRD 2.73][VPd 0.90]【ER50: 87.50%】[GTER50: 0.00%][Error 1.20][Bias -0.49][Time 0:00:00]
[Dset 2][Neuron 8]【Corr 53.69%】[vRD 2.61][VPd 0.83]【ER50: 70.00%】[GTER50: 0.00%][Error 1.12

  0%|                                                                                                         …

Early stopping...Epoch 2656
[Raw][UNet][MSE]【Dset 3】[Ep 2656][Time 0:02:34]
[Dset 3][Neuron 1]【Corr 31.88%】[vRD 2.89][VPd 0.57]【ER50: 75.00%】[GTER50: 1.05%][Error 1.16][Bias -0.28][Time 0:00:00]
[Dset 3][Neuron 2]【Corr 38.39%】[vRD 3.29][VPd 0.53]【ER50: 75.61%】[GTER50: 2.33%][Error 1.91][Bias 0.55][Time 0:00:00]
[Dset 3][Neuron 3]【Corr 45.87%】[vRD 2.77][VPd 0.47]【ER50: 58.40%】[GTER50: 0.62%][Error 1.43][Bias 0.04][Time 0:00:00]
[Dset 3][Neuron 4]【Corr 27.65%】[vRD 3.20][VPd 0.34]【ER50: 80.52%】[GTER50: 1.28%][Error 1.57][Bias 0.21][Time 0:00:00]
[Dset 3][Neuron 5]【Corr 41.94%】[vRD 3.64][VPd 0.27]【ER50: 58.38%】[GTER50: 0.81%][Error 1.44][Bias 0.21][Time 0:00:00]
[Dset 3][Neuron 6]【Corr 35.28%】[vRD 3.54][VPd 0.36]【ER50: 63.11%】[GTER50: 0.56%][Error 1.24][Bias -0.10][Time 0:00:00]
[Dset 3][Neuron 7]【Corr 32.73%】[vRD 3.90][VPd 0.33]【ER50: 61.81%】[GTER50: 1.41%][Error 1.32][Bias -0.06][Time 0:00:00]
[Dset 3][Neuron 8]【Corr 39.91%】[vRD 3.97][VPd 0.29]【ER50: 65.35%】[GTER50: 2.33%][Error 1.65][Bi

  0%|                                                                                                         …

Early stopping...Epoch 1579
[Raw][UNet][MSE]【Dset 4】[Ep 1579][Time 0:01:32]
[Dset 4][Neuron 1]【Corr 58.53%】[vRD 2.67][VPd 0.48]【ER50: 46.67%】[GTER50: 0.00%][Error 1.67][Bias 0.81][Time 0:00:00]
[Dset 4][Neuron 2]【Corr 58.99%】[vRD 2.55][VPd 0.21]【ER50: 33.33%】[GTER50: 0.00%][Error 3.80][Bias 3.04][Time 0:00:00]
[Dset 4][Neuron 3]【Corr 66.16%】[vRD 2.53][VPd 0.24]【ER50: 31.82%】[GTER50: 0.00%][Error 2.57][Bias 1.94][Time 0:00:00]
[Dset 4][Neuron 4]【Corr 62.33%】[vRD 2.29][VPd 0.40]【ER50: 39.78%】[GTER50: 0.00%][Error 2.00][Bias 1.04][Time 0:00:00]
[Dset 4][Neuron 5]【Corr 51.83%】[vRD 2.36][VPd 0.52]【ER50: 56.06%】[GTER50: 0.13%][Error 1.31][Bias 0.07][Time 0:00:01]
[Dset 4][Neuron 6]【Corr 44.85%】[vRD 2.53][VPd 0.55]【ER50: 57.14%】[GTER50: 0.00%][Error 1.51][Bias 0.28][Time 0:00:00]
[Dset 4][Neuron 7]【Corr 50.93%】[vRD 2.60][VPd 0.67]【ER50: 64.09%】[GTER50: 1.54%][Error 1.15][Bias -0.23][Time 0:00:01]
[Dset 4][Neuron 8]【Corr 40.00%】[vRD 4.60][VPd 4.06]【ER50: 83.33%】[GTER50: 0.00%][Error 8.60][Bias

  0%|                                                                                                         …

Early stopping...Epoch 1709
[Raw][UNet][MSE]【Dset 5】[Ep 1709][Time 0:01:39]
[Dset 5][Neuron 1]【Corr 54.10%】[vRD 4.50][VPd 1.64]【ER50: 56.13%】[GTER50: 0.00%][Error 2.45][Bias 2.04][Time 0:00:00]
[Dset 5][Neuron 2]【Corr 40.60%】[vRD 5.49][VPd 3.50]【ER50: 77.29%】[GTER50: 0.00%][Error 4.66][Bias 4.06][Time 0:00:00]
[Dset 5][Neuron 3]【Corr 63.20%】[vRD 4.03][VPd 1.60]【ER50: 53.67%】[GTER50: 0.34%][Error 2.45][Bias 2.06][Time 0:00:00]
[Dset 5][Neuron 4]【Corr 58.85%】[vRD 2.90][VPd 0.58]【ER50: 47.15%】[GTER50: 0.13%][Error 1.48][Bias 0.77][Time 0:00:01]
[Dset 5][Neuron 5]【Corr 74.00%】[vRD 10.93][VPd 6.13]【ER50: 75.38%】[GTER50: 0.00%][Error 10.27][Bias 10.27][Time 0:00:00]
dataset 6: preparing data...
Sampling rate is: 60.0Hz, smoothing window is: 25.0ms
Using previous datasets...
dataset 6, pair sample 4239680: start training...


  0%|                                                                                                         …

Early stopping...Epoch 1492
[Raw][UNet][MSE]【Dset 6】[Ep 1492][Time 0:01:27]
[Dset 6][Neuron 1]【Corr 62.87%】[vRD 2.17][VPd 0.37]【ER50: 33.93%】[GTER50: 0.71%][Error 0.98][Bias 0.30][Time 0:00:01]
[Dset 6][Neuron 2]【Corr 55.71%】[vRD 3.63][VPd 1.14]【ER50: 47.16%】[GTER50: 0.47%][Error 1.78][Bias 1.35][Time 0:00:00]
[Dset 6][Neuron 3]【Corr 57.77%】[vRD 4.00][VPd 1.65]【ER50: 54.13%】[GTER50: 0.66%][Error 2.40][Bias 2.00][Time 0:00:00]
[Dset 6][Neuron 4]【Corr 70.90%】[vRD 2.58][VPd 0.73]【ER50: 35.92%】[GTER50: 1.44%][Error 1.40][Bias 0.93][Time 0:00:00]
[Dset 6][Neuron 5]【Corr 73.91%】[vRD 2.06][VPd 0.40]【ER50: 28.68%】[GTER50: 1.22%][Error 1.00][Bias 0.38][Time 0:00:01]
[Dset 6][Neuron 6]【Corr 66.78%】[vRD 2.17][VPd 0.40]【ER50: 30.07%】[GTER50: 0.00%][Error 1.32][Bias 0.71][Time 0:00:00]
[Dset 6][Neuron 7]【Corr 78.68%】[vRD 1.79][VPd 0.23]【ER50: 20.38%】[GTER50: 1.06%][Error 0.79][Bias 0.16][Time 0:00:00]
[Dset 6][Neuron 8]【Corr 79.58%】[vRD 1.68][VPd 0.25]【ER50: 19.54%】[GTER50: 0.74%][Error 0.65][Bias 

  0%|                                                                                                         …

Early stopping...Epoch 889
[Raw][UNet][MSE]【Dset 7】[Ep 889][Time 0:00:52]
[Dset 7][Neuron 1]【Corr 59.35%】[vRD 3.11][VPd 0.59]【ER50: 49.68%】[GTER50: 1.00%][Error 1.30][Bias 0.33][Time 0:00:00]
[Dset 7][Neuron 2]【Corr 61.46%】[vRD 2.29][VPd 0.36]【ER50: 39.42%】[GTER50: 0.33%][Error 1.08][Bias 0.29][Time 0:00:01]
[Dset 7][Neuron 3]【Corr 64.51%】[vRD 2.17][VPd 0.45]【ER50: 26.96%】[GTER50: 1.10%][Error 1.21][Bias 0.59][Time 0:00:00]
[Dset 7][Neuron 4]【Corr 62.59%】[vRD 2.24][VPd 0.39]【ER50: 34.45%】[GTER50: 0.00%][Error 1.11][Bias 0.41][Time 0:00:01]
[Dset 7][Neuron 5]【Corr 66.71%】[vRD 2.10][VPd 0.29]【ER50: 31.62%】[GTER50: 0.17%][Error 1.01][Bias 0.24][Time 0:00:00]
[Dset 7][Neuron 6]【Corr 52.34%】[vRD 2.24][VPd 0.45]【ER50: 47.31%】[GTER50: 0.12%][Error 0.94][Bias -0.19][Time 0:00:00]
[Dset 7][Neuron 7]【Corr 68.04%】[vRD 2.16][VPd 0.34]【ER50: 23.18%】[GTER50: 0.08%][Error 1.12][Bias 0.48][Time 0:00:01]
[Dset 7][Neuron 8]【Corr 77.37%】[vRD 2.01][VPd 0.30]【ER50: 20.22%】[GTER50: 0.20%][Error 1.10][Bias 0

  0%|                                                                                                         …

Early stopping...Epoch 702
[Raw][UNet][MSE]【Dset 8】[Ep 702][Time 0:00:41]
[Dset 8][Neuron 1]【Corr 72.08%】[vRD 3.89][VPd 0.68]【ER50: 52.48%】[GTER50: 1.42%][Error 0.74][Bias -0.63][Time 0:00:00]
[Dset 8][Neuron 2]【Corr 67.84%】[vRD 2.65][VPd 0.79]【ER50: 65.64%】[GTER50: 2.57%][Error 0.78][Bias -0.71][Time 0:00:00]
[Dset 8][Neuron 3]【Corr 65.53%】[vRD 3.07][VPd 0.81]【ER50: 68.58%】[GTER50: 3.48%][Error 0.80][Bias -0.75][Time 0:00:00]
[Dset 8][Neuron 4]【Corr 78.69%】[vRD 2.31][VPd 0.59]【ER50: 45.57%】[GTER50: 2.90%][Error 0.67][Bias -0.50][Time 0:00:01]
[Dset 8][Neuron 5]【Corr 67.62%】[vRD 2.15][VPd 0.29]【ER50: 31.61%】[GTER50: 1.51%][Error 0.71][Bias -0.14][Time 0:00:00]
[Dset 8][Neuron 6]【Corr 84.46%】[vRD 1.83][VPd 0.42]【ER50: 30.13%】[GTER50: 2.24%][Error 0.56][Bias -0.32][Time 0:00:01]
[Dset 8][Neuron 7]【Corr 60.18%】[vRD 2.60][VPd 0.82]【ER50: 71.32%】[GTER50: 2.65%][Error 0.82][Bias -0.74][Time 0:00:00]
[Dset 8][Neuron 8]【Corr 68.50%】[vRD 1.95][VPd 0.36]【ER50: 30.55%】[GTER50: 0.30%][Error 0.75][

  0%|                                                                                                         …

Early stopping...Epoch 1249
[Raw][UNet][MSE]【Dset 9】[Ep 1249][Time 0:01:12]
[Dset 9][Neuron 1]【Corr 71.84%】[vRD 2.85][VPd 0.45]【ER50: 28.92%】[GTER50: 0.89%][Error 1.09][Bias 0.57][Time 0:00:01]
[Dset 9][Neuron 2]【Corr 75.51%】[vRD 2.11][VPd 0.40]【ER50: 32.01%】[GTER50: 0.56%][Error 0.84][Bias -0.10][Time 0:00:00]
[Dset 9][Neuron 3]【Corr 69.74%】[vRD 3.74][VPd 0.80]【ER50: 35.88%】[GTER50: 0.34%][Error 1.33][Bias 0.92][Time 0:00:00]
[Dset 9][Neuron 4]【Corr 82.81%】[vRD 1.81][VPd 0.49]【ER50: 32.45%】[GTER50: 1.39%][Error 0.64][Bias -0.23][Time 0:00:02]
[Dset 9][Neuron 5]【Corr 79.93%】[vRD 2.61][VPd 0.52]【ER50: 23.46%】[GTER50: 0.00%][Error 1.38][Bias 1.00][Time 0:00:01]
[Dset 9][Neuron 6]【Corr 71.95%】[vRD 2.00][VPd 0.28]【ER50: 20.78%】[GTER50: 0.17%][Error 1.08][Bias 0.34][Time 0:00:01]
[Dset 9][Neuron 7]【Corr 86.64%】[vRD 1.65][VPd 0.31]【ER50: 22.04%】[GTER50: 0.75%][Error 0.77][Bias 0.11][Time 0:00:01]
[Dset 9][Neuron 8]【Corr 74.60%】[vRD 2.20][VPd 0.28]【ER50: 26.61%】[GTER50: 0.93%][Error 0.80][Bia

  0%|                                                                                                         …

Early stopping...Epoch 1583
[Raw][UNet][MSE]【Dset 10】[Ep 1583][Time 0:01:32]
[Dset 10][Neuron 1]【Corr 92.76%】[vRD 2.30][VPd 0.48]【ER50: 23.13%】[GTER50: 1.08%][Error 0.74][Bias 0.53][Time 0:00:00]
[Dset 10][Neuron 2]【Corr 93.11%】[vRD 1.62][VPd 0.23]【ER50: 13.07%】[GTER50: 1.27%][Error 0.50][Bias 0.27][Time 0:00:00]
[Dset 10][Neuron 3]【Corr 87.07%】[vRD 1.84][VPd 0.32]【ER50: 23.08%】[GTER50: 0.49%][Error 0.75][Bias 0.09][Time 0:00:00]
[Dset 10][Neuron 4]【Corr 89.74%】[vRD 1.97][VPd 0.37]【ER50: 22.00%】[GTER50: 1.10%][Error 1.16][Bias 0.74][Time 0:00:00]
[Dset 10][Neuron 5]【Corr 93.00%】[vRD 1.63][VPd 0.27]【ER50: 14.65%】[GTER50: 0.59%][Error 0.48][Bias 0.03][Time 0:00:00]
[Dset 10][Neuron 6]【Corr 89.15%】[vRD 1.86][VPd 0.34]【ER50: 23.54%】[GTER50: 0.44%][Error 0.50][Bias -0.16][Time 0:00:00]
[Dset 10][Neuron 7]【Corr 88.29%】[vRD 1.98][VPd 0.31]【ER50: 20.44%】[GTER50: 1.34%][Error 0.54][Bias -0.04][Time 0:00:00]
[Dset 10][Neuron 8]【Corr 89.74%】[vRD 1.58][VPd 0.27]【ER50: 17.76%】[GTER50: 0.50%][Error 

  0%|                                                                                                         …

Early stopping...Epoch 1150
[Raw][UNet][MSE]【Dset 11】[Ep 1150][Time 0:01:07]
[Dset 11][Neuron 1]【Corr 78.97%】[vRD 2.79][VPd 0.63]【ER50: 52.50%】[GTER50: 0.65%][Error 0.97][Bias -0.35][Time 0:00:00]
[Dset 11][Neuron 2]【Corr 83.47%】[vRD 2.77][VPd 0.55]【ER50: 42.64%】[GTER50: 0.62%][Error 1.19][Bias 0.11][Time 0:00:00]
[Dset 11][Neuron 3]【Corr 74.79%】[vRD 2.62][VPd 0.70]【ER50: 51.43%】[GTER50: 0.00%][Error 1.18][Bias -0.14][Time 0:00:00]
[Dset 11][Neuron 4]【Corr 47.85%】[vRD 3.37][VPd 1.22]【ER50: 61.76%】[GTER50: 0.00%][Error 2.03][Bias 0.95][Time 0:00:00]
[Dset 11][Neuron 5]【Corr 88.47%】[vRD 1.94][VPd 0.32]【ER50: 21.86%】[GTER50: 1.89%][Error 0.68][Bias 0.16][Time 0:00:00]
[Dset 11][Neuron 6]【Corr 85.09%】[vRD 1.93][VPd 0.39]【ER50: 22.54%】[GTER50: 0.00%][Error 0.83][Bias 0.32][Time 0:00:00]
[Dset 11][Neuron 7]【Corr 34.97%】[vRD 2.84][VPd 0.96]【ER50: 70.43%】[GTER50: 0.00%][Error 2.22][Bias 0.77][Time 0:00:00]
[Dset 11][Neuron 8]【Corr 31.86%】[vRD 2.68][VPd 0.91]【ER50: 87.10%】[GTER50: 0.00%][Error 

  0%|                                                                                                         …

Early stopping...Epoch 2317
[Raw][UNet][MSE]【Dset 12】[Ep 2317][Time 0:02:15]
[Dset 12][Neuron 1]【Corr 90.25%】[vRD 2.31][VPd 0.53]【ER50: 40.05%】[GTER50: 0.77%][Error 0.73][Bias -0.31][Time 0:00:00]
[Dset 12][Neuron 2]【Corr 77.45%】[vRD 3.59][VPd 1.49]【ER50: 48.19%】[GTER50: 0.00%][Error 2.63][Bias 2.30][Time 0:00:00]
[Dset 12][Neuron 3]【Corr 87.69%】[vRD 1.71][VPd 0.25]【ER50: 14.39%】[GTER50: 0.80%][Error 0.86][Bias 0.48][Time 0:00:00]
[Dset 12][Neuron 4]【Corr 93.43%】[vRD 1.53][VPd 0.29]【ER50: 12.39%】[GTER50: 0.00%][Error 0.76][Bias 0.66][Time 0:00:00]
[Dset 12][Neuron 5]【Corr 90.60%】[vRD 1.48][VPd 0.21]【ER50: 11.56%】[GTER50: 1.04%][Error 0.77][Bias 0.37][Time 0:00:00]
[Dset 12][Neuron 6]【Corr 79.97%】[vRD 2.03][VPd 0.43]【ER50: 24.43%】[GTER50: 0.00%][Error 1.23][Bias 0.63][Time 0:00:00]
dataset 13: preparing data...
Sampling rate is: 60.0Hz, smoothing window is: 25.0ms
Using previous datasets...
dataset 13, pair sample 4180479: start training...


  0%|                                                                                                         …

Early stopping...Epoch 1856
[Raw][UNet][MSE]【Dset 13】[Ep 1856][Time 0:01:48]
[Dset 13][Neuron 1]【Corr 80.63%】[vRD 2.17][VPd 0.50]【ER50: 27.17%】[GTER50: 0.44%][Error 1.08][Bias 0.70][Time 0:00:00]
[Dset 13][Neuron 2]【Corr 86.95%】[vRD 2.83][VPd 0.80]【ER50: 30.88%】[GTER50: 0.00%][Error 1.41][Bias 1.27][Time 0:00:00]
[Dset 13][Neuron 3]【Corr 87.04%】[vRD 1.94][VPd 0.45]【ER50: 26.87%】[GTER50: 1.05%][Error 0.99][Bias 0.55][Time 0:00:00]
[Dset 13][Neuron 4]【Corr 74.03%】[vRD 2.80][VPd 0.51]【ER50: 38.10%】[GTER50: 0.43%][Error 0.87][Bias -0.02][Time 0:00:00]
[Dset 13][Neuron 5]【Corr 87.91%】[vRD 2.07][VPd 0.33]【ER50: 25.08%】[GTER50: 0.62%][Error 0.74][Bias 0.07][Time 0:00:00]
[Dset 13][Neuron 6]【Corr 88.84%】[vRD 2.40][VPd 0.34]【ER50: 28.05%】[GTER50: 0.61%][Error 0.88][Bias 0.23][Time 0:00:00]
[Dset 13][Neuron 7]【Corr 88.60%】[vRD 2.14][VPd 0.43]【ER50: 28.50%】[GTER50: 0.47%][Error 0.82][Bias 0.13][Time 0:00:00]
[Dset 13][Neuron 8]【Corr 94.73%】[vRD 3.44][VPd 1.09]【ER50: 35.20%】[GTER50: 0.95%][Error 1

  0%|                                                                                                         …

Early stopping...Epoch 1573
[Raw][UNet][MSE]【Dset 14】[Ep 1573][Time 0:01:31]
[Dset 14][Neuron 1]【Corr 65.61%】[vRD 4.48][VPd 3.40]【ER50: 62.86%】[GTER50: 0.00%][Error 5.39][Bias 5.26][Time 0:00:00]
[Dset 14][Neuron 2]【Corr 79.66%】[vRD 2.09][VPd 0.24]【ER50: 23.48%】[GTER50: 0.11%][Error 1.22][Bias 0.55][Time 0:00:01]
[Dset 14][Neuron 3]【Corr 77.88%】[vRD 2.78][VPd 1.29]【ER50: 41.05%】[GTER50: 0.00%][Error 4.41][Bias 4.07][Time 0:00:00]
[Dset 14][Neuron 4]【Corr 77.53%】[vRD 3.29][VPd 1.26]【ER50: 39.92%】[GTER50: 1.00%][Error 2.08][Bias 1.94][Time 0:00:01]
[Dset 14][Neuron 5]【Corr 83.82%】[vRD 1.77][VPd 0.15]【ER50: 15.16%】[GTER50: 0.33%][Error 0.98][Bias 0.39][Time 0:00:01]
[Dset 14][Neuron 6]【Corr 67.18%】[vRD 5.15][VPd 4.06]【ER50: 66.93%】[GTER50: 1.18%][Error 8.29][Bias 8.28][Time 0:00:01]
[Dset 14][Neuron 7]【Corr 74.85%】[vRD 2.70][VPd 0.83]【ER50: 33.21%】[GTER50: 1.04%][Error 1.38][Bias 1.07][Time 0:00:01]
dataset 15: preparing data...
Sampling rate is: 60.0Hz, smoothing window is: 25.0ms
Using 

  0%|                                                                                                         …

Early stopping...Epoch 1605
[Raw][UNet][MSE]【Dset 15】[Ep 1605][Time 0:01:33]
[Dset 15][Neuron 1]【Corr 61.40%】[vRD 2.84][VPd 0.41]【ER50: 45.82%】[GTER50: 1.95%][Error 0.88][Bias -0.23][Time 0:00:01]
[Dset 15][Neuron 2]【Corr 55.12%】[vRD 2.75][VPd 0.36]【ER50: 48.02%】[GTER50: 1.26%][Error 0.99][Bias -0.06][Time 0:00:01]
[Dset 15][Neuron 3]【Corr 45.91%】[vRD 3.05][VPd 0.51]【ER50: 65.63%】[GTER50: 1.05%][Error 1.19][Bias -0.15][Time 0:00:01]
[Dset 15][Neuron 4]【Corr 57.60%】[vRD 3.20][VPd 0.47]【ER50: 48.20%】[GTER50: 1.67%][Error 1.02][Bias -0.08][Time 0:00:01]
[Dset 15][Neuron 5]【Corr 55.86%】[vRD 2.99][VPd 0.57]【ER50: 55.16%】[GTER50: 1.99%][Error 1.05][Bias -0.18][Time 0:00:00]
[Dset 15][Neuron 6]【Corr 61.72%】[vRD 2.79][VPd 0.44]【ER50: 50.39%】[GTER50: 1.95%][Error 0.91][Bias -0.24][Time 0:00:00]
[Dset 15][Neuron 7]【Corr 55.25%】[vRD 3.31][VPd 0.49]【ER50: 54.14%】[GTER50: 1.36%][Error 0.98][Bias -0.28][Time 0:00:02]
[Dset 15][Neuron 8]【Corr 52.56%】[vRD 3.11][VPd 0.47]【ER50: 56.75%】[GTER50: 0.83%][E

  0%|                                                                                                         …

Early stopping...Epoch 808
[Raw][UNet][MSE]【Dset 16】[Ep 808][Time 0:00:47]
[Dset 16][Neuron 1]【Corr 56.87%】[vRD 2.40][VPd 0.61]【ER50: 57.47%】[GTER50: 1.05%][Error 0.88][Bias -0.44][Time 0:00:00]
[Dset 16][Neuron 2]【Corr 60.05%】[vRD 2.23][VPd 0.53]【ER50: 49.29%】[GTER50: 0.84%][Error 0.89][Bias -0.32][Time 0:00:00]
[Dset 16][Neuron 3]【Corr 61.39%】[vRD 3.58][VPd 0.85]【ER50: 76.55%】[GTER50: 1.33%][Error 0.88][Bias -0.77][Time 0:00:00]
[Dset 16][Neuron 4]【Corr 66.32%】[vRD 3.24][VPd 0.87]【ER50: 77.56%】[GTER50: 1.89%][Error 0.84][Bias -0.81][Time 0:00:00]
[Dset 16][Neuron 5]【Corr 69.68%】[vRD 4.35][VPd 0.83]【ER50: 74.61%】[GTER50: 0.69%][Error 0.95][Bias -0.72][Time 0:00:00]
[Dset 16][Neuron 6]【Corr 58.07%】[vRD 3.76][VPd 0.81]【ER50: 75.00%】[GTER50: 0.84%][Error 0.92][Bias -0.70][Time 0:00:00]
[Dset 16][Neuron 7]【Corr 67.84%】[vRD 3.85][VPd 0.81]【ER50: 71.05%】[GTER50: 1.20%][Error 0.88][Bias -0.74][Time 0:00:00]
[Dset 16][Neuron 8]【Corr 58.42%】[vRD 4.54][VPd 0.88]【ER50: 81.91%】[GTER50: 1.28%][Err

  0%|                                                                                                         …

Early stopping...Epoch 1093
[Raw][UNet][MSE]【Dset 17】[Ep 1093][Time 0:01:03]
[Dset 17][Neuron 1]【Corr 72.48%】[vRD 2.24][VPd 0.56]【ER50: 43.03%】[GTER50: 0.30%][Error 1.23][Bias 0.31][Time 0:00:00]
[Dset 17][Neuron 2]【Corr 83.36%】[vRD 2.19][VPd 0.51]【ER50: 38.07%】[GTER50: 2.29%][Error 0.78][Bias -0.18][Time 0:00:00]
[Dset 17][Neuron 3]【Corr 89.62%】[vRD 2.30][VPd 0.54]【ER50: 37.26%】[GTER50: 1.52%][Error 0.61][Bias -0.45][Time 0:00:00]
[Dset 17][Neuron 4]【Corr 69.35%】[vRD 2.09][VPd 0.43]【ER50: 37.79%】[GTER50: 0.98%][Error 0.89][Bias -0.12][Time 0:00:00]
[Dset 17][Neuron 5]【Corr 84.39%】[vRD 2.00][VPd 0.59]【ER50: 35.38%】[GTER50: 0.00%][Error 1.32][Bias 0.75][Time 0:00:00]
[Dset 17][Neuron 6]【Corr 81.63%】[vRD 1.82][VPd 0.46]【ER50: 35.17%】[GTER50: 0.53%][Error 0.84][Bias -0.07][Time 0:00:00]
[Dset 17][Neuron 7]【Corr 80.34%】[vRD 2.63][VPd 0.59]【ER50: 46.49%】[GTER50: 1.80%][Error 0.84][Bias -0.32][Time 0:00:00]
[Dset 17][Neuron 8]【Corr 83.38%】[vRD 2.18][VPd 0.50]【ER50: 38.55%】[GTER50: 2.33%][Err

  0%|                                                                                                         …

Early stopping...Epoch 896
[Raw][UNet][MSE]【Dset 18】[Ep 896][Time 0:00:52]
[Dset 18][Neuron 1]【Corr 83.93%】[vRD 4.42][VPd 0.47]【ER50: 40.63%】[GTER50: 0.56%][Error 0.81][Bias -0.32][Time 0:00:02]
[Dset 18][Neuron 2]【Corr 77.50%】[vRD 3.53][VPd 0.24]【ER50: 34.10%】[GTER50: 0.85%][Error 0.98][Bias 0.10][Time 0:00:01]
[Dset 18][Neuron 3]【Corr 85.77%】[vRD 3.74][VPd 0.64]【ER50: 53.55%】[GTER50: 0.33%][Error 0.86][Bias -0.47][Time 0:00:00]
[Dset 18][Neuron 4]【Corr 56.45%】[vRD 4.42][VPd 0.63]【ER50: 53.48%】[GTER50: 0.31%][Error 1.46][Bias 0.41][Time 0:00:00]
dataset 19: preparing data...
Sampling rate is: 60.0Hz, smoothing window is: 25.0ms
Using previous datasets...
dataset 19, pair sample 4236787: start training...


  0%|                                                                                                         …

Early stopping...Epoch 1684
[Raw][UNet][MSE]【Dset 19】[Ep 1684][Time 0:01:38]
[Dset 19][Neuron 1]【Corr 70.13%】[vRD 2.77][VPd 0.80]【ER50: 70.31%】[GTER50: 0.47%][Error 0.87][Bias -0.64][Time 0:00:00]
[Dset 19][Neuron 2]【Corr 59.38%】[vRD 2.43][VPd 0.73]【ER50: 61.40%】[GTER50: 0.83%][Error 0.79][Bias -0.54][Time 0:00:00]
[Dset 19][Neuron 3]【Corr 76.66%】[vRD 2.05][VPd 0.48]【ER50: 30.72%】[GTER50: 0.60%][Error 1.07][Bias 0.17][Time 0:00:01]
[Dset 19][Neuron 4]【Corr 78.21%】[vRD 2.02][VPd 0.31]【ER50: 26.21%】[GTER50: 0.52%][Error 0.79][Bias 0.02][Time 0:00:01]
[Dset 19][Neuron 5]【Corr 83.07%】[vRD 1.84][VPd 0.32]【ER50: 18.18%】[GTER50: 0.00%][Error 1.16][Bias 0.31][Time 0:00:00]
[Dset 19][Neuron 6]【Corr 72.23%】[vRD 2.49][VPd 0.68]【ER50: 56.91%】[GTER50: 0.73%][Error 0.86][Bias -0.50][Time 0:00:01]
[Dset 19][Neuron 7]【Corr 66.32%】[vRD 2.48][VPd 0.62]【ER50: 53.49%】[GTER50: 1.82%][Error 1.08][Bias -0.12][Time 0:00:00]
[Dset 19][Neuron 8]【Corr 71.07%】[vRD 2.69][VPd 0.73]【ER50: 61.54%】[GTER50: 1.40%][Erro

  0%|                                                                                                         …

Early stopping...Epoch 797
[Raw][UNet][MSE]【Dset 20】[Ep 797][Time 0:00:46]
[Dset 20][Neuron 1]【Corr 62.89%】[vRD 4.24][VPd 0.66]【ER50: 63.74%】[GTER50: 1.09%][Error 1.15][Bias -0.33][Time 0:00:00]
[Dset 20][Neuron 2]【Corr 73.39%】[vRD 3.50][VPd 0.61]【ER50: 57.02%】[GTER50: 0.66%][Error 1.34][Bias 0.06][Time 0:00:00]
[Dset 20][Neuron 3]【Corr 43.55%】[vRD 6.37][VPd 3.43]【ER50: 70.37%】[GTER50: 0.00%][Error 21.90][Bias 21.29][Time 0:00:00]
[Dset 20][Neuron 4]【Corr 77.36%】[vRD 2.87][VPd 0.61]【ER50: 58.54%】[GTER50: 0.00%][Error 0.98][Bias -0.41][Time 0:00:00]
[Dset 20][Neuron 5]【Corr 73.35%】[vRD 2.35][VPd 0.53]【ER50: 43.48%】[GTER50: 0.00%][Error 1.50][Bias 0.43][Time 0:00:01]
[Dset 20][Neuron 6]【Corr 64.87%】[vRD 2.74][VPd 0.57]【ER50: 49.67%】[GTER50: 0.14%][Error 1.15][Bias -0.02][Time 0:00:01]
[Dset 20][Neuron 7]【Corr 76.96%】[vRD 3.43][VPd 0.50]【ER50: 46.34%】[GTER50: 0.00%][Error 1.53][Bias 0.40][Time 0:00:01]
[Dset 20][Neuron 8]【Corr 40.25%】[vRD 4.04][VPd 0.82]【ER50: 71.99%】[GTER50: 0.44%][Error

  0%|                                                                                                         …

Early stopping...Epoch 1716
[Raw][UNet][MSE]【Dset 21】[Ep 1716][Time 0:01:39]
[Dset 21][Neuron 1]【Corr 80.37%】[vRD 4.36][VPd 0.66]【ER50: 49.74%】[GTER50: 0.42%][Error 1.12][Bias -0.07][Time 0:00:01]
[Dset 21][Neuron 2]【Corr 74.41%】[vRD 3.60][VPd 1.33]【ER50: 55.18%】[GTER50: 0.00%][Error 2.83][Bias 2.02][Time 0:00:02]
[Dset 21][Neuron 3]【Corr 50.16%】[vRD 5.94][VPd 3.75]【ER50: 71.93%】[GTER50: 0.71%][Error 6.64][Bias 6.02][Time 0:00:01]
[Dset 21][Neuron 4]【Corr 59.15%】[vRD 3.87][VPd 0.77]【ER50: 67.57%】[GTER50: 1.53%][Error 1.04][Bias -0.47][Time 0:00:00]
[Dset 21][Neuron 5]【Corr 67.47%】[vRD 3.72][VPd 0.72]【ER50: 49.46%】[GTER50: 1.18%][Error 1.21][Bias 0.04][Time 0:00:01]
[Dset 21][Neuron 6]【Corr 80.74%】[vRD 4.04][VPd 0.65]【ER50: 43.54%】[GTER50: 0.17%][Error 1.00][Bias -0.05][Time 0:00:03]
[Dset 21][Neuron 7]【Corr 80.33%】[vRD 6.27][VPd 1.02]【ER50: 60.65%】[GTER50: 0.44%][Error 1.73][Bias 0.42][Time 0:00:01]
[Dset 21][Neuron 8]【Corr 63.81%】[vRD 3.43][VPd 0.64]【ER50: 54.46%】[GTER50: 0.78%][Error

<b>Notes:</b>
The following arguments can be changed to test other model configurations:

    - neuron = 'Exc', 'Inh', 'Both'                # select excitatory or inhibitory neurons (or both)
    - nets = 'UNet', 'LeNet', 'FCNet'              # define model structure
    - losses = 'MSE', 'EucD', 'Corr'               # define loss function
    - Fs = '60', '30', '15', '7.5'                 # define re-sampling rate (Hz)
    - smoothing_std = '0.025','0.05','0.1','0.2'   # define smoothing window size (s)
    - smoothing_kernel = 'gaussian', 'causal'      # define smoothing kernel shape
    - cluster = 'None', 'anticluster', 'cluster'   # select indicator types (corresponding to "all", "different", "same" in the manuscript)
    - hour = 'all', '20','10','5','1','0.5','0.1'  # amount of training data (in hours)
    - lr = '0.001', '0.005', '0.0002'              # define learning rate
    - kernel = '3', '5', '7'                       # set kernel size for U-Net
    - node = '50K','150K','450K'                   # set node count for U-Net
    - seg = '32','64','96'                         # define segment length
    - batch = '256','512','1024','2048'            # define batch size
    - es = '2000','1000','500','300','100'         # define patience for early stopping

<b>For example: </b>
<br>To test with a different re-sampling rate, smoothing window size, and segment length (uncomment the lines in the cell below):

In [13]:
# benchmark = BENCHMARK()
# benchmark.train(neuron='Exc', nets='UNet', losses='MSE', Fs='7.5', smoothing_std='0.2', seg='64')