In [1]:
import pickle
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
from tensorflow import keras

# Repository root, relative to the kernel's working directory (normally this notebook's folder).
# It must be on sys.path BEFORE the `from src...` imports below; appending it afterwards (as the
# previous version did) cannot help those imports, which then only worked if the root was already
# importable. The membership check keeps the cell idempotent when it is re-run.
PWD = '../../..'
if PWD not in sys.path:
    sys.path.append(PWD)

from src.cfd import CFD  # noqa: E402 - must follow the sys.path setup above
from src.dataset import X_TIME, load_expanded_dataset_train_val  # noqa: E402
from src.models import optimal_model_builder_all_ch  # noqa: E402
from src.network_utils import train_model as _base_train_model, gaussian_kernel  # noqa: E402

OPTIMAL_CFD_THRESHOLDS_TRAIN_DATA_PATH = PWD + '/data/tmp/many_channels_tests/optimal_cfd_thresholds_train_data.pkl'

In [2]:
# Shared reference CFD: one fixed threshold for every channel. compute_true_cfd_dataset uses it
# to time each waveform when building the regression targets.
CFD_N_BASELINE = 6
BASE_CFD_THRESHOLD = 0.2

base_cfd = CFD(threshold=BASE_CFD_THRESHOLD, n_baseline=CFD_N_BASELINE)

In [3]:
# n_baseline for the per-channel (optimal-threshold) CFDs.
# NOTE(review): same value as CFD_N_BASELINE; consider sharing a single constant.
N_BASELINE = 6

# Optimisation
LR = 1e-2
N_EPOCHS = 10
BATCH_SIZE = 2_048
LOSS_WEIGHT = 1_000  # forwarded unchanged to train_model

# Patience settings (presumably early stopping / LR reduction inside train_model)
ES_PATIENCE = 5
ES_MIN_DELTA = 1e-2
# NOTE(review): larger than N_EPOCHS, so the learning rate is presumably never reduced
# (the training log shows lr: 0.0100 for every epoch) - confirm this is intended.
LR_PATIENCE = 50

In [4]:
# Load the train/validation splits. Each split is later unpacked as
# compute_true_cfd_dataset(t_cfd_avg, wav_by_key, t0_by_key), so element 0 is the t_cfd_avg array
# and its length is the number of samples in that split.
train_dataset, val_dataset = load_expanded_dataset_train_val(PWD)
tuple(split[0].shape for split in (train_dataset, val_dataset))

((26719,), (6680,))

In [5]:
# NOTE: pickle.load can execute arbitrary code stored in the file - only load threshold files
# produced by this project.
with open(OPTIMAL_CFD_THRESHOLDS_TRAIN_DATA_PATH, 'rb') as thresholds_file:
    optimal_cfd_thresholds = pickle.load(thresholds_file)

print('Optimal thresholds:')
for line in (f'({plane:>1}, {channel:>2}): {threshold:0.3f}'
             for (plane, channel), threshold in optimal_cfd_thresholds.items()):
    print(line)

# One CFD per (plane, channel) key, each configured with that key's optimal threshold.
# NOTE(review): none of the cells shown here uses cfd_by_channel; the target construction
# relies on the shared base_cfd instead - confirm which one is intended.
cfd_by_channel = {
    (plane, channel): CFD(n_baseline=N_BASELINE, threshold=channel_threshold)
    for (plane, channel), channel_threshold in optimal_cfd_thresholds.items()
}

Optimal thresholds:
(1,  2): 0.180
(1, 11): 0.185
(2,  2): 0.220
(2, 11): 0.145
(3,  2): 0.130
(3, 11): 0.175


# Utils

In [6]:
def compute_true_t_cfd(t_cfd_avg: float, t_cfd: float, t_0: float) -> float:
    """Leave-one-out reference time for one channel, expressed relative to ``t_0``.

    Treats ``t_cfd + t_0`` as this channel's own absolute timestamp. It removes that
    timestamp from ``3 * t_cfd_avg``, halves what remains and shifts the result back
    by ``t_0``, so only the timestamps of the other channels contribute.

    NOTE(review): the constants 3 and 2 assume ``t_cfd_avg`` is the mean of exactly
    three absolute timestamps, one of them this channel's - confirm against how the
    dataset computes the average.
    """
    own_absolute_time = t_cfd + t_0
    others_total = (3 * t_cfd_avg) - own_absolute_time
    return others_total / 2 - t_0


def compute_true_cfd_dataset(dataset_t_cfd_avg: np.ndarray, dataset_wav: dict, dataset_t0: dict) -> dict:
    """Build the per-key list of leave-one-out target times.

    For every key of ``dataset_wav`` and every sample index ``i`` of ``dataset_t_cfd_avg``:
      * if ``dataset_t0[key][i]`` is None, the target is None;
      * otherwise the waveform is timed with the shared ``base_cfd`` and combined with the
        sample's average via ``compute_true_t_cfd``.

    Returns a ``defaultdict(list)`` mapping each key to its targets in sample order. A key only
    appears once at least one target has been appended for it.

    NOTE(review): every channel is timed with ``base_cfd`` (fixed threshold), not with a
    per-channel CFD - confirm this is intended.
    """
    targets_by_key = defaultdict(list)
    for key, waveforms in dataset_wav.items():
        for sample_idx, t_cfd_avg in enumerate(dataset_t_cfd_avg):
            t_0 = dataset_t0[key][sample_idx]
            if t_0 is None:
                target = None
            else:
                t_cfd = base_cfd.predict_single(X_TIME, waveforms[sample_idx])
                target = compute_true_t_cfd(t_cfd_avg, t_cfd, t_0)
            targets_by_key[key].append(target)

    return targets_by_key


def build_nn_dataset(dataset_wav: dict, dataset_true_cfd: dict) -> tuple[np.ndarray, np.ndarray]:
    """Flatten every key into (waveform, target) training pairs, skipping missing targets.

    Pairs are emitted key by key (in ``dataset_wav`` order) and, within a key, in sample order.
    A sample whose ``dataset_true_cfd[key][i]`` is None is dropped entirely. Each kept scalar
    target is then encoded with ``gaussian_kernel`` (the U-Net target representation, per the
    original "UNet" note).

    Returns ``(x, y)``: the stacked waveforms and the stacked kernel-encoded targets.
    """
    waveforms, scalar_targets = [], []
    for key, key_waveforms in dataset_wav.items():
        for sample_idx, true_t_cfd in enumerate(dataset_true_cfd[key]):
            if true_t_cfd is not None:
                waveforms.append(key_waveforms[sample_idx])
                scalar_targets.append(true_t_cfd)

    x = np.array(waveforms)
    # The scalars are converted to an ndarray first, so gaussian_kernel receives numpy scalars.
    y = np.array([gaussian_kernel(t) for t in np.array(scalar_targets)])
    return x, y


def build_and_train_network(iteration: int, x_train: np.ndarray, x_val: np.ndarray, y_train: np.ndarray,
                            y_val: np.ndarray, train: bool = True,
                            verbose: int = 2) -> tuple[keras.Model, pd.DataFrame]:
    """Create a fresh model and hand it to the project's ``train_model`` helper.

    The model comes from ``optimal_model_builder_all_ch()`` and is registered under the name
    ``optimal_it_<iteration>`` in the "iti_experiments" group, rooted at ``PWD + '/data'``.
    All remaining settings come from the notebook's hyperparameter constants.

    Returns ``(model, history)``, where ``history`` is whatever ``train_model`` returns.
    NOTE(review): annotated as a DataFrame, and ``train=False`` presumably skips fitting -
    neither is verifiable from this notebook.
    """
    network = optimal_model_builder_all_ch()
    run_name = f"optimal_it_{iteration}"
    # Positional order matters: train_model expects (x_train, y_train, x_val, y_val),
    # which differs from this function's own parameter order.
    training_history = _base_train_model(
        network, run_name, "iti_experiments",
        x_train, y_train, x_val, y_val,
        LR, train, N_EPOCHS, verbose, BATCH_SIZE,
        LR_PATIENCE, ES_PATIENCE, ES_MIN_DELTA, LOSS_WEIGHT,
        root=PWD + '/data',
    )
    return network, training_history

In [7]:
# Targets first (train, then val), then flatten each split into (waveform, target) pairs.
true_cfd_train, true_cfd_val = (compute_true_cfd_dataset(*split) for split in (train_dataset, val_dataset))
(x_train, y_train), (x_val, y_val) = (
    build_nn_dataset(split[1], split_targets)
    for split, split_targets in ((train_dataset, true_cfd_train), (val_dataset, true_cfd_val))
)
x_train.shape, x_val.shape, y_train.shape, y_val.shape

((80157, 24), (20040, 24), (80157, 24), (20040, 24))

# Test: train the model on the true-CFD targets

In [8]:
# 1. train model
# Seed the Python, NumPy and TensorFlow RNGs so that weight initialisation and any shuffling done
# during training repeat across Restart & Run All. GPU kernels may still add small nondeterminism.
TRAINING_SEED = 42
keras.utils.set_random_seed(TRAINING_SEED)
model, _ = build_and_train_network(iteration=0, x_train=x_train, x_val=x_val, y_train=y_train, y_val=y_val, train=True)

Epoch 1/10
40/40 - 9s - loss: 688.6419 - val_loss: 271608.9062 - lr: 0.0100 - 9s/epoch - 215ms/step
Epoch 2/10
40/40 - 3s - loss: 102.5286 - val_loss: 38.2720 - lr: 0.0100 - 3s/epoch - 82ms/step
Epoch 3/10
40/40 - 3s - loss: 52.6943 - val_loss: 42.8193 - lr: 0.0100 - 3s/epoch - 78ms/step
Epoch 4/10
40/40 - 3s - loss: 31.2414 - val_loss: 34.8021 - lr: 0.0100 - 3s/epoch - 83ms/step
Epoch 5/10
40/40 - 3s - loss: 24.3111 - val_loss: 31.1145 - lr: 0.0100 - 3s/epoch - 83ms/step
Epoch 6/10
40/40 - 3s - loss: 21.2316 - val_loss: 28.5692 - lr: 0.0100 - 3s/epoch - 83ms/step
Epoch 7/10
40/40 - 3s - loss: 19.6412 - val_loss: 27.6635 - lr: 0.0100 - 3s/epoch - 83ms/step
Epoch 8/10
40/40 - 3s - loss: 18.7672 - val_loss: 26.3875 - lr: 0.0100 - 3s/epoch - 84ms/step
Epoch 9/10
40/40 - 3s - loss: 18.1920 - val_loss: 25.1725 - lr: 0.0100 - 3s/epoch - 84ms/step
Epoch 10/10
40/40 - 3s - loss: 17.6821 - val_loss: 23.4594 - lr: 0.0100 - 3s/epoch - 81ms/step
