In [None]:
import random

from matplotlib import pyplot as plt
import numpy as np
import torch

from heartbeat_detector.dataset.dataset import HeartbeatDataloaders
from heartbeat_detector.utils import seed_everything

In [None]:
# Wide default figure size so long 1-D signals stay readable in every plot below.
plt.rcParams.update({'figure.figsize': [20, 7]})

In [None]:
# Seed before building the dataloaders so any random split/shuffling is reproducible
# (presumably seed_everything seeds `random`, numpy and torch — TODO confirm).
seed_everything(420)

# Forward slashes resolve on both Windows and POSIX; a backslash raw string only works on Windows.
DATASET_CSV = 'data/processed/2s_sin_with_channels/dataset.csv'

# Only the last loader (test) is used; the others are bound to a throwaway name instead of `__`,
# which IPython reserves for its output history.
# NOTE(review): the meaning of the positional args 40 and 1 is defined by HeartbeatDataloaders
# (batch size / workers?) — confirm and pass them by keyword.
*_unused_dataloaders, test_dataloader = HeartbeatDataloaders(DATASET_CSV, 40, 1).get_train_validation_test_dataloaders()

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load a fully pickled model object (not a state_dict) and switch it to inference mode.
# Alternative checkpoint that was also evaluated: out/charming-fish-930/checkpoints/unet1d_epoch_050.pth
model_path = 'out/trusting-sow-659/checkpoints/unet1d_epoch_015.pth'

# map_location lets a checkpoint saved from a GPU load onto whatever DEVICE is (e.g. a CPU-only box).
# torch.load unpickles arbitrary objects: only use it on checkpoints you trust.
# NOTE: PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled nn.Module objects;
# on those versions pass weights_only=False (trusted files only).
model = torch.load(model_path, map_location=DEVICE).to(DEVICE)
model.eval();

In [None]:
# Run the model on the first test batch only; later cells read element [0] of each list.
# NOTE(review): `.squeeze()` also drops the batch axis if that batch holds a single sample.
pred_batches = []
labels_bathes = []
signals_batches = []

# Same batch a `for` loop with an unconditional `break` would yield; raises StopIteration
# if the loader is empty. The first field of each batch is not used here.
_unused_field, signal_batch, label_batch = next(iter(test_dataloader))

with torch.no_grad():
    preds_batch = model(signal_batch.to(DEVICE))

# Labels are never fed to the model, so they stay on the host (no DEVICE round-trip).
signals_batches.append(signal_batch.cpu().numpy().squeeze())
pred_batches.append(preds_batch.cpu().numpy().squeeze())
labels_bathes.append(label_batch.cpu().numpy().squeeze())

In [None]:
# Pick a random sample *within* the first collected batch (the one indexed as [0] below).
# randrange excludes its upper bound (randint includes it), so the index is always in range.
index = random.randrange(len(pred_batches[0]))

In [None]:
# Pull the chosen sample's input, prediction and ground truth out of the first batch.
signal, pred, label = (
    batch_list[0][index]
    for batch_list in (signals_batches, pred_batches, labels_bathes)
)

In [None]:
# Plot the model's predictions on the unmodified input against the raw signal and ground truth.
alpha = 0.5  # line transparency; also reused by the zoomed plot further down

fig, ax = plt.subplots()
ax.plot(signal, 'b', label='Raw Signal', alpha=alpha)
ax.plot(label, 'g', label='Ground Truth', alpha=alpha)
ax.plot(pred, 'r', label='Predictions', alpha=alpha)
ax.set(
    ylim=(-1, 1),
    xlabel='Signal index',
    ylabel='Value',
    title=f'Sample {index} of the first test batch',
)
ax.legend()
plt.show()

In [None]:
# Probe the receptive field empirically: inject a huge spike at one input position and re-run
# the model. Presumably only outputs whose receptive field covers INDEX can change (holds for a
# purely convolutional model in eval mode — confirm for this architecture).
INDEX = 3400
SPIKE_VALUE = 1_000_000_000_000
assert 0 <= INDEX < len(signal), f'INDEX={INDEX} is outside the signal (length {len(signal)})'

signal_copy = np.copy(signal)
signal_copy[INDEX] = SPIKE_VALUE
# Shape (1, 1, len(signal)), float32 — the same layout/dtype torch.Tensor(np.array([[...]])) produced.
signal_tensor = torch.as_tensor(signal_copy, dtype=torch.float32).reshape(1, 1, -1).to(DEVICE)

with torch.no_grad():
    new_pred = model(signal_tensor).cpu().numpy().flatten()

In [None]:
# Zoom into [start, stop) of the predictions made on the spiked input. The x values stay absolute
# sample indexes, so the dashed marker at INDEX lines up for any choice of start.
start = 0
stop = 10000

fig, ax = plt.subplots()
for series, color, name in (
    (signal, 'b', 'Raw Signal'),
    (label, 'g', 'Ground Truth'),
    (new_pred, 'r', 'Predictions (perturbed input)'),
):
    # range(...)[start:stop] keeps absolute indexes; range(len(series[start:stop])) would restart at 0.
    ax.plot(range(len(series))[start:stop], series[start:stop], color, label=name, alpha=alpha)
ax.axvline(INDEX, color='k', linestyle='--', label=f'Spike at {INDEX}')
ax.set(ylim=(-1, 1), xlabel='Signal index')
ax.legend()
plt.show()

In [None]:
# Output positions whose prediction moved by more than eps when the spike was injected.
diff = new_pred - pred

eps = 1e-4
# NaN also counts as "changed": a huge spike may overflow activations, and `abs(nan) > eps` is False.
changed = (np.abs(diff) > eps) | np.isnan(diff)
diff_indexes = np.flatnonzero(changed)
assert diff_indexes.size > 0, (
    f'no prediction changed by more than eps={eps}; try a larger spike or a smaller eps'
)

In [None]:
# Number of samples spanned by the affected outputs, counting both ends (a single changed
# position has width 1, not 0).
# NOTE(review): assumes the changed positions form one contiguous run.
receptive_field_width = diff_indexes[-1] - diff_indexes[0] + 1

In [None]:
# Show how much each prediction changed, zoomed onto the span of affected positions.
fig, ax = plt.subplots()
ax.plot(diff)
ax.set(
    xlim=(diff_indexes[0], diff_indexes[-1]),
    xlabel='Signal index',
    ylabel='Predictions difference',
    title=f'Receptive field (width={receptive_field_width})',
);