In [None]:
import os
import random

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn import metrics
from tqdm import tqdm

from heartbeat_detector.dataset.dataset import HeartbeatDataloaders
from heartbeat_detector.models.unet_1d import UNet1d
from heartbeat_detector.utils import seed_everything

In [None]:
# Default size (inches) for figures that do not pass their own figsize.
plt.rcParams.update({'figure.figsize': [20, 7]})

In [None]:
seed_everything(420)
# Only the test split is used in this notebook; the train/validation loaders are discarded.
# os.path.join keeps the dataset path portable (a backslash-separated literal only works on Windows).
# NOTE(review): 40 and 2 are positional HeartbeatDataloaders arguments (batch size / worker count?) — confirm.
*__, test_dataloader = HeartbeatDataloaders(
    os.path.join('data', 'processed', '2s_sin_with_channels', 'dataset.csv'), 40, 2
).get_train_validation_test_dataloaders()

In [None]:
# Fall back to the CPU so the notebook still runs on machines without a CUDA device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# The checkpoint is a whole pickled model object (it is used directly as a module below),
# so torch.load unpickles arbitrary objects — only load checkpoints from trusted sources.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects full
# pickled modules; pass weights_only=False there if loading fails.
model_path = os.path.join('out', 'trusting-sow-659', 'checkpoints', 'unet1d_epoch_015.pth')
# map_location lets a checkpoint saved from a GPU session load on whatever DEVICE is configured.
model = torch.load(model_path, map_location=DEVICE).to(DEVICE)
model.eval();

In [None]:
# Run the model over the whole test split and collect per-batch numpy arrays.
# NOTE: `labels_bathes` (sic) keeps its name because later cells read it.
pred_batches = []
labels_bathes = []
signals_batches = []
filenames = []

with torch.no_grad():
    for filename_batch, signal_batch, label_batch in tqdm(test_dataloader, total=len(test_dataloader)):
        # Only the model input has to live on DEVICE; labels are only collected for
        # CPU-side evaluation, so moving them to the GPU and back is wasted work.
        preds_batch = model(signal_batch.to(DEVICE))

        # .squeeze() drops singleton axes; if a batch has size 1 this also drops the
        # batch axis, which np.vstack in a later cell tolerates (1-D arrays become rows).
        signals_batches.append(signal_batch.cpu().numpy().squeeze())
        pred_batches.append(preds_batch.cpu().numpy().squeeze())
        labels_bathes.append(label_batch.cpu().numpy().squeeze())
        filenames.extend(filename_batch)



In [None]:
# First sample of the first test batch, for single-example inspection.
pred, label, signal = (
    batches[0][0] for batches in (pred_batches, labels_bathes, signals_batches)
)

In [None]:
# Stack the per-batch arrays and flatten each quantity into one long 1-D stream,
# so global sample indices are shared across preds / labels / signals.
preds, labels, signals = (
    np.vstack(batches).flatten()
    for batches in (pred_batches, labels_bathes, signals_batches)
)

In [None]:
# Binarise the continuous label stream into a peak mask (1 = reference peak, 0 otherwise).
# NOTE(review): the threshold looks tuned so that only the apex sample(s) of each label bump
# pass — confirm how 0.9985329 was chosen.
LABEL_PEAK_THRESHOLD = 0.9985329
ground_truth = (labels >= LABEL_PEAK_THRESHOLD) * 1

In [None]:
def get_metrics(
    target: np.ndarray,
    pred: np.ndarray,
    half_window_size: int
) -> tuple[int, int, int, list[int], list[int]]:
    """Match predicted peaks to reference peaks within a +/- `half_window_size` window.

    A peak is any position where the array equals 1.

    Args:
        target: Reference peak mask.
        pred: Predicted peak mask.
        half_window_size: Maximum distance (in samples) between a predicted and a
            reference peak for them to count as matching.

    Returns:
        (tp, fp, fn, fp_indexes, fn_indexes):
        tp -- number of predicted peaks with a reference peak within the window;
        fp -- number of predicted peaks without one (positions in fp_indexes);
        fn -- number of reference peaks with no predicted peak within the window
              (positions in fn_indexes).

    Note:
        Matching is not one-to-one: tp counts predicted peaks, so several predictions
        near the same reference peak each count as a true positive.
    """
    target_peaks_indexes = np.where(target == 1)[0]
    pred_peaks_indexes = np.where(pred == 1)[0]

    # `value in ndarray` scans the whole array on every check, which made matching
    # quadratic in the number of peaks; set lookups are O(1).
    target_peaks = set(target_peaks_indexes.tolist())
    pred_peaks = set(pred_peaks_indexes.tolist())

    def has_neighbour(index, peaks) -> bool:
        # True if any position in [index - w, index + w] is in `peaks`.
        return any(
            index - offset in peaks or index + offset in peaks
            for offset in range(half_window_size + 1)
        )

    fp_indexes = [i for i in pred_peaks_indexes if not has_neighbour(i, target_peaks)]
    fn_indexes = [i for i in target_peaks_indexes if not has_neighbour(i, pred_peaks)]

    fp = len(fp_indexes)
    tp = len(pred_peaks_indexes) - fp
    fn = len(fn_indexes)

    return tp, fp, fn, fp_indexes, fn_indexes

def mean_compressor(pred: np.ndarray) -> int:
    """Choose the centre of a region: index ``len(pred) // 2``; the values are ignored."""
    half_length, _ = divmod(len(pred), 2)
    return half_length

def max_compressor(pred: np.ndarray) -> int:
    """Choose the position of the highest score in a region (the first one on ties)."""
    peak_position = np.argmax(pred)
    return peak_position

from typing import Callable

def compress(
    pred: np.ndarray,
    treshold: float,
    comressor: Callable[[np.ndarray], int]
    ) -> np.ndarray:
    """Collapse every above-threshold region of `pred` into a single 1-valued sample.

    A region opens at a sample strictly greater than `treshold` and stays open until
    a sample strictly less than `treshold` is seen. Samples equal to the threshold
    neither open nor close a region. A region still open at the end of `pred` is
    closed at ``len(pred)``.

    Args:
        pred: 1-D sequence of scores.
        treshold: Score level that delimits regions (name kept as-is for callers).
        comressor: Maps a region's scores to the index, relative to the region start,
            that becomes the region's single peak.

    Returns:
        Float array of zeros with the same length as `pred`, holding 1.0 at each peak.
    """
    regions = []
    region_start = None

    for position, value in enumerate(pred):
        if region_start is None:
            if value > treshold:
                region_start = position
        elif value < treshold:
            regions.append((region_start, position))
            region_start = None

    if region_start is not None:
        regions.append((region_start, len(pred)))

    peak_positions = [start + comressor(pred[start: stop]) for start, stop in regions]

    compressed = np.zeros(len(pred))
    for position in peak_positions:
        compressed[position] = 1

    return compressed



In [None]:
# Compress the single example's prediction: one peak per region scoring above 0.3.
modified_pred_single = compress(pred, treshold=0.3, comressor=max_compressor)

In [None]:
# Zoomed view of one test sample: raw model output vs. its compressed (one-peak-per-region) version.
sample_axis = np.arange(len(pred))
fig, ax = plt.subplots(1, 1, dpi=300, tight_layout=True, figsize=(7, 3))

ax.plot(sample_axis, pred, label='Предсказание модели')
ax.plot(sample_axis, modified_pred_single, label='Модифицированное предсказание', color='red')
ax.set(xlim=(1950, 2110), ylim=(-0.002, 1))
ax.legend()
ax.axis('off');

In [None]:
# Compress the whole flattened test prediction stream with the same settings as above.
modify_pred = compress(pred=preds, treshold=0.3, comressor=max_compressor)


In [None]:
# A predicted and a reference peak match when they are at most 15 samples apart.
tp, fp, fn, fp_indexes, fn_indexes = get_metrics(
    target=ground_truth,
    pred=modify_pred,
    half_window_size=15,
)


In [None]:
# Raw match counts: (true positives, false positives, false negatives).
(tp, fp, fn)

In [None]:
# Precision: share of predicted peaks that matched a reference peak.
# Recall: share of reference peaks that were found.
# Empty denominators (no predictions / no reference peaks) would raise ZeroDivisionError;
# report 0.0 instead.
# NOTE(review): get_metrics' tp counts matched *predictions*, so when several predictions
# match one reference peak, tp + fn is not exactly the reference-peak count and recall
# can be slightly overstated.
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0

precision, recall

# Precision/recall values noted for different (threshold, half-window) settings.
# These are hard-coded literals, not recomputed here — presumably copied from earlier
# runs of the cell above. They only rebind `Precision`/`Recall`; nothing visible reads them.
# Threshold 0.5, half window size 7
Precision, Recall = (0.9462863725146492, 0.9087498647300011)

# Threshold 0.3, half window size 15
Precision, Recall = (0.9609696279046722, 0.9565265636751452)

In [None]:
import random

In [None]:
def get_filename_by_global_index(
    index: int,
    samples_per_record: int = 10_000,
    names: "list[str] | None" = None,
) -> str:
    """Return the filename of the record that contains a flat (global) sample index.

    The flattened preds/labels/signals arrays are concatenations of per-record
    arrays, so a global index belongs to record ``index // samples_per_record``.
    This assumes every record has exactly `samples_per_record` samples.

    Args:
        index: Non-negative position in the flattened arrays.
        samples_per_record: Samples per record (10_000 by default, the value the
            plotting cells also assume).
        names: One filename per record; defaults to the notebook-level `filenames`.

    Raises:
        ValueError: If `index` is negative. Python's negative indexing would
            otherwise silently return a file counted from the end.
        IndexError: If `index` points past the last record.
    """
    if index < 0:
        raise ValueError(f"index must be non-negative, got {index}")
    source = filenames if names is None else names
    return source[index // samples_per_record]

In [None]:
# Which recording does this (manually chosen) global sample index fall into?
get_filename_by_global_index(index=79842 - 1000)

In [None]:
# Show the input signal and modified labels for a sample of false negatives
# (every 5th FN among positions 20..119). Each plot covers the middle of the
# record that contains the FN, cropped by CROP samples on both sides.
# NOTE(review): the crop can hide the FN itself when it lies near a record edge.
RECORD_LEN = 10_000  # assumed samples per record
CROP = 3000
alpha = 0.6

for fn_index in fn_indexes[20:120:5]:
    fig, ax = plt.subplots(1, 1, tight_layout=True, dpi=300, figsize=(10.7, 6))

    # Bounds of the record containing this FN (the final record may be shorter).
    left = fn_index // RECORD_LEN * RECORD_LEN
    right = min(len(preds), left + RECORD_LEN)
    shown_from, shown_to = left + CROP, right - CROP

    ax.plot(range(shown_from, shown_to), signals[shown_from:shown_to], 'b--', label='Исходный сигнал', alpha=alpha)
    ax.plot(range(shown_from, shown_to), labels[shown_from:shown_to], 'g', label='Модифицированная разметка', alpha=alpha)
    ax.set_ylim(-1, 1)
    ax.set_xlim(shown_from, shown_to)
    ax.legend()
    ax.axis('off')
    plt.show()
    # pyplot retains figures until they are closed; close each one so the loop
    # does not accumulate 300-dpi figures regardless of the active backend.
    plt.close(fig)

In [None]:
# Show whole records around the first few false negatives (every 5th of the first 20):
# input signal, original labels, and the compressed predictions as red diamonds at height 1.
RECORD_LEN = 10_000  # assumed samples per record
alpha = 0.6

for fn_index in fn_indexes[:20:5]:
    fig, ax = plt.subplots(1, 1, tight_layout=True, dpi=300, figsize=(10.7, 6))

    # Bounds of the record containing this FN (the final record may be shorter).
    left = fn_index // RECORD_LEN * RECORD_LEN
    right = min(len(preds), left + RECORD_LEN)

    # Absolute positions of the compressed peaks in this record. The markers sit exactly
    # on the peak sample (previously shifted one sample left by a stray `- 1`), so they
    # line up with signals/labels, which are plotted against the same absolute indices.
    peak_positions = left + np.flatnonzero(modify_pred[left:right] > 0.5)
    ax.plot(peak_positions, np.ones(len(peak_positions)), 'rD', label='Мод. предсказание', alpha=1)
    ax.plot(range(left, right), signals[left:right], 'b--', label='Исходный сигнал', alpha=alpha)
    ax.plot(range(left, right), labels[left:right], 'g', label='Исходная разметка', alpha=alpha)
    ax.set_ylim(-1, 2)
    ax.set_xlim(left, right)
    ax.legend()
    ax.axis('off')
    plt.show()
    # pyplot retains figures until they are closed; close each one explicitly.
    plt.close(fig)

In [None]:
# Inspect a sample of false positives (every 10th of the first 100): model output,
# input signal and original labels in a +/- HALF_SPAN window around each FP.
HALF_SPAN = 1000
alpha = 0.6

for fp_index in fp_indexes[:100:10]:
    fig, ax = plt.subplots(1, 1, dpi=300, tight_layout=True, figsize=(10.7, 6))

    # Clip the window to the bounds of the flattened arrays.
    left = max(0, fp_index - HALF_SPAN)
    right = min(len(preds), fp_index + HALF_SPAN)
    window = range(left, right)

    ax.plot(window, preds[left:right], 'r', label='Предсказание', alpha=alpha)
    ax.plot(window, signals[left:right], 'b--', label='Исходный сигнал', alpha=alpha)
    ax.plot(window, labels[left:right], 'g', label='Исходная разметка', alpha=alpha)
    ax.set_ylim(-2, 2)
    ax.set_xlim(left, right)
    ax.legend()
    ax.axis('off')
    plt.show()
    # pyplot retains figures until they are closed; close each one explicitly.
    plt.close(fig)