In [None]:
import json
from functools import partial
from pathlib import Path

import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt

from egm_analyzer.models.onnx_wrapper import OnnxModelWrapper

# CHANGE ME: point the paths in the next cell at the recording you want to inspect.

In [None]:
# Roots of the inputs; change these (and the file names below) per recording.
out_dir = Path("../out")
dataset_dir = Path(r"D:\Datasets\lab")

errors_filepath = out_dir / "25_ph1_errors.json"
signal_filepath = dataset_dir / "interim" / "X" / "X_25_ph1.npy"
ground_truth_filepath = dataset_dir / "raw" / "Y" / "Y_25_ph1.json"
prediction_filepath = out_dir / "25ph1.json"
# Set to None to skip running the model when plotting (a Path object is always
# truthy, so None is the only way to disable the `if model_path:` branches).
model_path = Path(r"D:\repos\lab\lab\25ph1.onnx")

In [None]:
# JSON is UTF-8 by spec; without an explicit encoding read_text() uses the
# platform locale encoding (e.g. cp1252 on Windows).
errors = json.loads(errors_filepath.read_text(encoding="utf-8"))
ground_truth = json.loads(ground_truth_filepath.read_text(encoding="utf-8"))  # heartbeat positions per channel
predictions = json.loads(prediction_filepath.read_text(encoding="utf-8"))

signal = np.load(signal_filepath)  # indexed as signal[channel][sample] by the plotting code

# Start from "no model" so re-running this cell after setting model_path = None
# does not keep a model left over from an earlier run in the kernel.
model = None
if model_path:
    # ONNX Runtime execution providers, presumably forwarded by OnnxModelWrapper
    # to an InferenceSession (which tries them in list order): CUDA on GPU 0
    # first, CPU as the fallback. NOTE(review): confirm against the wrapper.
    providers = [
        (
            'CUDAExecutionProvider',
            {
                'device_id': 0,
                'arena_extend_strategy': 'kNextPowerOfTwo',
                'gpu_mem_limit': 8 * 1024 * 1024 * 1024,  # 8 GiB
                'cudnn_conv_algo_search': 'EXHAUSTIVE',
                'do_copy_in_default_stream': True,
            },
        ),
        'CPUExecutionProvider',
    ]

    model = OnnxModelWrapper(model_path, providers)

In [None]:
class ErrorsView(object):
    """Navigate the error positions of one error kind, channel by channel.

    ``errors`` is one entry of the errors JSON file. It maps a channel number
    **as a string** (JSON object keys are always strings) to the list of sample
    positions flagged on that channel. The view keeps a current channel and an
    index into that channel's list; stepping clamps at both ends.
    """

    def __init__(self, errors: dict[str, list[int]]) -> None:
        self._errors = errors
        self._current_channel = 0
        self._current_index = 0

    @property
    def current_channel(self) -> int:
        """Channel currently being browsed."""
        return self._current_channel

    @current_channel.setter
    def current_channel(self, channel: int) -> None:
        # A newly selected channel always starts at its first error.
        self._current_channel = channel
        self._current_index = 0

    @property
    def current_index(self) -> int:
        """Index of the current error within the current channel's list."""
        return self._current_index

    @current_index.setter
    def current_index(self, index: int) -> None:
        self._current_index = index

    @property
    def current_position(self) -> int:
        """Sample position of the current error.

        Raises:
            KeyError: the current channel has no entry in ``errors``.
            IndexError: the current channel's list is empty.
        """
        return self._channel_positions()[self._current_index]

    def _channel_positions(self) -> list[int]:
        """Return the current channel's error positions, rejecting an empty list.

        Raises:
            KeyError: the current channel has no entry in ``errors``.
            IndexError: the current channel's list is empty.
        """
        positions = self._errors[str(self._current_channel)]
        if not positions:
            # Raise before touching the index so a failed step cannot leave
            # current_index at -1.
            raise IndexError(f"no errors recorded for channel {self._current_channel}")
        return positions

    def get_next_index(self) -> int:
        """Advance to the next error (staying on the last one) and return it.

        Despite the name, the return value is the error's sample position,
        not its index.
        """
        positions = self._channel_positions()
        self.current_index = min(self.current_index + 1, len(positions) - 1)
        return positions[self.current_index]

    def get_previous_index(self) -> int:
        """Step back to the previous error (staying on the first one) and return it.

        Despite the name, the return value is the error's sample position,
        not its index.
        """
        positions = self._channel_positions()
        self.current_index = max(self.current_index - 1, 0)
        return positions[self.current_index]
        

In [None]:
# Pick which error kind to browse: the errors file offers "fp" and "fn"
# (presumably false positives / false negatives of the detector).
errors_view = ErrorsView(errors=errors["fp"])

In [None]:
# Half-width, in samples, of the window plotted around each error.
size = 40

right_btn = widgets.Button(description="Next", icon="arrow-right")
left_btn = widgets.Button(description="Previous", icon="arrow-left")
channel_selector = widgets.BoundedIntText(
    value=errors_view.current_channel,
    min=0,
    # Highest channel with an entry in the errors file; `default` keeps an
    # empty errors dict from failing the cell with ValueError.
    max=max(map(int, errors_view._errors), default=0),
    step=1,
    description='Channel',
    disabled=False
)
hbox = widgets.HBox([channel_selector, left_btn, right_btn])
# The widget callbacks clear this output area and redraw their figure into it.
output = widgets.Output()

display(hbox, output)

def select_channel(change):
    """Show the first error of the channel picked in ``channel_selector``.

    Observer for the selector's ``value`` trait. It switches ``errors_view`` to
    ``change["new"]`` (which resets its index to 0) and redraws ``output``.
    A channel without recorded errors gets a message instead of an exception
    raised inside the widget callback.

    Args:
        change: ipywidgets change dict; only ``change["new"]`` is read.
    """
    with output:
        output.clear_output(wait=True)
        errors_view.current_channel = change["new"]
        try:
            position = errors_view.current_position
        except LookupError:
            # KeyError (channel absent from the errors file) or IndexError
            # (channel present with an empty list).
            print(f"No errors recorded for channel {errors_view.current_channel}.")
            return

        plot(position, size)

def plot(next_, size):
    """Draw the current channel's signal around one error position.

    Plots ``signal[errors_view.current_channel]`` over
    ``[next_ - size, next_ + size)``, clipped to the signal's bounds. It adds a
    red line at the error and green lines at ground-truth heartbeats inside the
    window. When ``model_path`` is set, it also draws the model output for the
    same samples together with the 0.5 threshold.

    Args:
        next_: Sample index of the error to centre the window on.
        size: Half-width of the window, in samples.
    """
    channel = errors_view.current_channel
    channel_signal = signal[channel]
    n_samples = len(channel_signal)

    # Clip the window so the x-values and the signal slice have equal length
    # even when the error is closer than `size` samples to either end. A
    # negative slice start would otherwise wrap around to the end of the array.
    lo = max(0, next_ - size)
    hi = min(n_samples, next_ + size)
    samples = range(lo, hi)

    _, ax = plt.subplots(1, 1, dpi=100, figsize=(16 * 2 / 3, 9 * 2 / 3), tight_layout=True)
    ax.plot(samples, channel_signal[lo:hi], label="Signal")
    ax.axvline(next_, color='red', label="Error")

    model_window = 10_000  # fixed input length the model is run on
    # Signals shorter than one model window cannot be fed to the model; skip it.
    if model_path and n_samples >= model_window:
        # Centre a model_window-long slice on the error, shifted inward at the
        # signal edges so it always fits.
        start = max(0, min(next_ - model_window // 2, n_samples - model_window))
        (model_output, *_), *_ = model.predict(
            channel_signal[start:start + model_window].reshape(1, 1, model_window)
        )
        # Index relative to `start`: the error sits at offset model_window // 2
        # only when the slice did not have to be shifted at an edge.
        ax.plot(samples, model_output[lo - start:hi - start], 'orange', label="Model output", marker='.')
        ax.axhline(0.5, label='Threshold', color='purple')

    in_window = [beat for beat in ground_truth[channel] if lo <= beat < hi]
    for i, beat in enumerate(in_window):
        # Label only the first line so the legend gets a single entry.
        ax.axvline(beat, color='green', alpha=0.5, label="Ground truth" if i == 0 else "_nolegend_")

    ax.set_title(f"Channel = {channel}")
    ax.set_xlabel("Sample")
    ax.legend()
    plt.show()
        
def on_button_clicked(b, direction='r'):
    match(direction):
        case 'r':
            next_ = errors_view.get_next_index()
        case 'l':
            next_ = errors_view.get_previous_index()
        case _:
            raise ValueError

    with output:
        output.clear_output(wait=True)
        
        plot(next_, size)

        

right_btn.on_click(partial(on_button_clicked, direction="r"))
left_btn.on_click(partial(on_button_clicked, direction="l"))
channel_selector.observe(select_channel, names="value")

# Draw the first error inside `output`. The callbacks call
# output.clear_output(), so they replace this figure instead of leaving a
# stale copy below the output widget.
with output:
    try:
        plot(errors_view.current_position, size)
    except LookupError:
        # The initial channel may have no recorded errors.
        print(f"No errors recorded for channel {errors_view.current_channel}.")
