In [None]:
import os; os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import random
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa # Need to import since imported models use addons
import neurokit2 as nk
from heartkit.defines import HeartTask, HeartSegment, HeartBeat
from heartkit.datasets import IcentiaDataset
from heartkit.datasets.preprocess import preprocess_signal
from heartkit.hrv import compute_hrv
from neuralspot.tflite.model import get_strategy, load_model

import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Best-effort: on a CPU-only machine the list is empty and nothing is configured.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # TensorFlow raises RuntimeError once the runtime is initialised; memory growth can no longer change.
        print(f"Could not enable memory growth on {gpu.name}: {err}")


In [None]:
# --- Paths & display ---
ds_path = "../datasets"
plotly_template = "plotly_dark"  # any built-in plotly template name

# --- Signal / windowing (all lengths are in samples) ---
sampling_rate = 250                 # Hz
data_len = 10 * sampling_rate       # length of the analysed ECG strip (10 s)
seg_len = 624                       # segmentation-model window length
seg_olp = 25                        # samples dropped at each edge of a segmentation window
arr_len = 1000                      # arrhythmia-model window length
beat_len = 200                      # window length cut around each R peak for the beat model
frame_size = seg_len                # frame size handed to the dataset loader (same value as seg_len)
arrhythmia_frame_size = arr_len     # kept as an alias of arr_len

# --- Decision thresholds (softmax probabilities) ---
arrhythmia_threshold = 0.75
qrs_threshold = 0.7

# --- Reproducibility ---
# Seeds Python's and NumPy's global RNGs; the sample-selection cell draws from `random`.
# NOTE(review): the dataset shuffle presumably uses one of these RNGs — confirm.
# Set random_seed to None to draw a fresh sample on every run.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)

results = dict(heart_rate=0, num_beats=0, num_pac=0, num_pvc=0, rhythm=False, arrhythmia=False)


### Load trained models

In [None]:
# Restore the segmentation, arrhythmia and beat models (loaded in that order).
seg_model, arr_model, beat_model = map(
    load_model,
    (
        "../results/segmentation/model.tf",
        "../results/arrhythmia/model.tf",
        "../results/beat/model.tf",
    ),
)

### Load sample data

In [None]:
# Build the Icentia dataset handle with the CPU as the TensorFlow device.
# target_rate follows the notebook's sampling_rate (was a hard-coded 250), so that every
# window length expressed in samples stays consistent if sampling_rate is changed.
with tf.device('/CPU:0'):
    ds_icdb = IcentiaDataset(
        ds_path=ds_path,
        task=HeartTask.arrhythmia,
        frame_size=frame_size,
        target_rate=sampling_rate,
    )

In [None]:
# Draw a random test patient, pick one of their segments, and cut a data_len-sample strip from it.
pt_gen = ds_icdb.uniform_patient_generator(patient_ids=ds_icdb.get_test_patient_ids(), repeat=False, shuffle=True)
# The strip starts at least 5 s into the segment and leaves at least one extra data_len at the end,
# so a segment must be at least this long for random.randint below to have a valid range.
min_seg_size = 5 * sampling_rate + 2 * data_len
data = None
true_arr_labels = None
true_blabels = None
for pt, segments in pt_gen:
    # Same key order as the original list(segments.keys()); short segments are excluded.
    seg_keys = [key for key in segments.keys() if segments[key]["data"].size >= min_seg_size]
    if not seg_keys:
        continue  # every segment of this patient is too short; try the next patient
    seg_key = random.choice(seg_keys)
    frame_start = random.randint(5*sampling_rate, segments[seg_key]["data"].size - 2*data_len)
    frame_end = frame_start + data_len
    data = segments[seg_key]["data"][frame_start:frame_end]
    # NOTE(review): this copies the beat labels of the whole segment, not only [frame_start, frame_end) — confirm intent.
    true_blabels = segments[seg_key]["blabels"][:]
    break
if data is None:
    raise RuntimeError("No test patient has a segment long enough to sample a data_len strip")
data = preprocess_signal(data=data, sample_rate=sampling_rate)

### Plot ECG data

In [None]:
# Preprocessed ECG strip plotted against time in seconds.
fig = go.Figure()
# Use sampling_rate rather than a hard-coded 250 so the time axis stays correct if the rate changes.
t = np.linspace(0, data.shape[0]/sampling_rate, data.shape[0], endpoint=False)
fig.add_trace(go.Scatter(x=t, y=data.squeeze()))
fig.update_layout(template=plotly_template, height=480, xaxis_title="Time (sec)")
fig.show()

### Apply arrhythmia model

In [None]:
# Classify the strip in consecutive, non-overlapping arr_len windows. A window is flagged (1)
# when the softmax probability of class 1 exceeds arrhythmia_threshold, and every sample in
# the window inherits that label.
arr_labels = np.zeros((data_len,))

def _arrhythmia_window(start):
    """Run arr_model on data[start:start + arr_len]; return (0/1 prediction, class-1 probability)."""
    test_x = np.expand_dims(data[start : start + arr_len], axis=(0,1))
    y_prob = tf.nn.softmax(arr_model.predict(test_x)).numpy()
    p_arr = y_prob[0][1]
    return (1 if p_arr > arrhythmia_threshold else 0), p_arr

arr_last_start = data_len - arr_len
for i in range(0, arr_last_start + 1, arr_len):
    y_pred, p_arr = _arrhythmia_window(i)
    arr_labels[i : i + arr_len] = y_pred
    print(f"{i} : {i + arr_len} = {y_pred} ({p_arr:0.1%})")
# END FOR

# The stride above never reaches the last data_len % arr_len samples (2000-2500 with the
# default settings), so classify one extra window aligned to the end of the strip. Where it
# overlaps an earlier window, a positive label is kept.
if arr_last_start >= 0 and data_len % arr_len:
    y_pred, p_arr = _arrhythmia_window(arr_last_start)
    arr_labels[arr_last_start:] = np.maximum(arr_labels[arr_last_start:], y_pred)
    print(f"{arr_last_start} : {data_len} = {y_pred} ({p_arr:0.1%})")

arrhythmia_detected = np.any(arr_labels)
results["arrhythmia"] = arrhythmia_detected
if arrhythmia_detected:
    print("Arrhythmia onset detected")
else:
    print("No arrhythmia detected")

### Apply segmentation model

In [None]:
# Segment the strip with seg_model using overlapping seg_len windows (stride seg_len - 2*seg_olp).
# Each window's outer seg_olp samples are discarded because they lack context, except at the very
# start and end of the strip, where no other window covers them. qrs_mask marks samples whose
# QRS probability exceeds qrs_threshold.
seg_mask = np.zeros((data_len,))
qrs_mask = np.zeros((data_len,))

def _segment_window(start):
    """Return per-sample class ids and the thresholded QRS mask for data[start:start + seg_len]."""
    test_x = np.expand_dims(data[start : start + seg_len], axis=(0, 1))
    y_prob = tf.nn.softmax(seg_model.predict(test_x)).numpy()
    y_pred = np.argmax(y_prob, axis=2)[0]
    # NOTE(review): class index 2 is treated as QRS, as in the original code — confirm it equals HeartSegment.qrs.
    y_qrs = np.where(y_prob[0, :, 2] > qrs_threshold, 1, 0)
    return y_pred, y_qrs

seg_last_start = data_len - seg_len
seg_starts = list(range(0, seg_last_start + 1, seg_len - 2 * seg_olp))
# The stride rarely lands exactly on the end of the strip; add a final window aligned to it.
# (The old check `(data_len-seg_olp)-i` was always truthy and relied on a leaked loop variable.)
if seg_starts and seg_starts[-1] != seg_last_start:
    seg_starts.append(seg_last_start)

for i in seg_starts:
    y_pred, y_qrs = _segment_window(i)
    lo = 0 if i == 0 else seg_olp
    hi = seg_len if i == seg_last_start else seg_len - seg_olp
    seg_mask[i + lo : i + hi] = y_pred[lo:hi]
    qrs_mask[i + lo : i + hi] = y_qrs[lo:hi]
    print(f"{i}:{i + seg_len}")
# END FOR


### Compute heart rate and HRV from the QRS mask

In [None]:
# Derive heart rate (BPM), R-R intervals and R-peak locations from the predicted QRS mask.
hr, rr_lens, rpeaks = compute_hrv(data, qrs_mask, sampling_rate)
# Without a positive, finite heart rate avg_rr would divide by zero, and a NaN rate would be
# silently classified as "tachycardia" below.
if not np.isfinite(hr) or hr <= 0:
    raise ValueError(f"compute_hrv returned an unusable heart rate ({hr}); check the QRS mask")
avg_rr = int(sampling_rate/(hr/60))   # mean R-R interval in samples
results["heart_rate"] = hr
results["num_beats"] = len(rpeaks)
# < 60 BPM: bradycardia, 60-100 BPM: normal, > 100 BPM: tachycardia
results["rhythm"] = "bradycardia" if hr < 60 else "normal" if hr <= 100 else "tachycardia"

### Classify beats (normal / PAC / PVC)

In [None]:
# Classify interior R peaks with beat_model. Each input is three beat_len windows joined with
# np.hstack: the window one avg_rr before the beat, the beat itself, and one avg_rr after it.
# The first and last peaks, and any peak whose context would leave the strip, keep the
# default label 0 (presumably HeartBeat.normal — confirm).
blabels = np.zeros_like(rpeaks)
half_beat = beat_len // 2
for i in range(1, len(rpeaks) - 1):
    frame_start = rpeaks[i] - half_beat
    frame_end = frame_start + beat_len
    # A slice may end exactly at data.shape[0]; only skip when the context would run past it
    # (the previous `>=` check rejected beats whose context ended exactly at the strip end).
    if frame_start - avg_rr < 0 or frame_end + avg_rr > data.shape[0]:
        continue
    test_x = np.hstack((
        data[frame_start - avg_rr: frame_end - avg_rr],
        data[frame_start : frame_end],
        data[frame_start + avg_rr: frame_end + avg_rr],
    ))
    test_x = np.expand_dims(test_x, axis=(0, 1))
    y_prob = tf.nn.softmax(beat_model.predict(test_x)).numpy()
    blabels[i] = np.argmax(y_prob, axis=1)[0]
results["num_pac"] = int(np.count_nonzero(blabels == HeartBeat.pac))
results["num_pvc"] = int(np.count_nonzero(blabels == HeartBeat.pvc))

In [None]:
# Metrics gathered by the cells above (displayed as the cell output)
{metric: value for metric, value in results.items()}

In [None]:
# HeartKit summary report: segmented ECG with abnormal-beat markers (left column),
# metrics table (top right) and R-R interval distribution (bottom right).
# The figure is also written to ../results/report.html.
fig = make_subplots(
    rows=2, cols=2, column_widths=[3, 1], horizontal_spacing=0.05, vertical_spacing=0.05,
    specs=[[{"rowspan": 2}, {"type": "table"}], [None, {}]],
    subplot_titles=[None, None, "R-R Distribution"]
)

# 1. ECG data with one overlay trace per predicted segment class. Samples of other classes
#    are NaN, so each overlay only draws its own class.
ecg = data.squeeze()
t = np.linspace(0, data.shape[0]/sampling_rate, data.shape[0], endpoint=False)
# np.nan rather than the np.NAN alias, which NumPy 2.0 removed.
pwave = np.where(seg_mask == HeartSegment.pwave, ecg, np.nan)
qrs = np.where(seg_mask == HeartSegment.qrs, ecg, np.nan)
twave = np.where(seg_mask == HeartSegment.twave, ecg, np.nan)
fig.add_trace(go.Scatter(x=t, y=ecg, name='ECG'), row=1, col=1)
fig.add_trace(go.Scatter(x=t, y=pwave, name='P wave'), row=1, col=1)
fig.add_trace(go.Scatter(x=t, y=qrs, name='QRS'), row=1, col=1)
fig.add_trace(go.Scatter(x=t, y=twave, name='T Wave'), row=1, col=1)
fig.update_xaxes(title_text="Time (sec)", row=1, col=1)

# Mark every non-normal beat with a labelled vertical line at its R peak.
for rpeak, blabel in zip(rpeaks, blabels):
    if blabel == HeartBeat.normal:
        continue
    if blabel == HeartBeat.pac:
        label = "PAC"
    elif blabel == HeartBeat.pvc:
        label = "PVC"
    else:
        label = "Other"  # any other class is no longer mislabelled as "PVC"
    fig.add_vline(x=rpeak/sampling_rate, line_color="red", annotation_text=label, annotation_font_color="red", annotation_textangle=90, row=1, col=1)

# 2. Metrics table
header = ['Metric', 'Value']
cells = [[
        "Heart Rate",
        "Heart Rhythm",
        "Total Beats",
        "Normal Beats",
        "PAC Beats",
        "PVC Beats",
        "Arrhythmia"],
    [
        f'{results["heart_rate"]:0.0f} BPM',
        results["rhythm"],
        results["num_beats"],
        results["num_beats"]-(results["num_pac"] + results["num_pvc"]),
        results["num_pac"], results["num_pvc"],
        "Onset detected" if results["arrhythmia"] else "Not detected"],
]
fig.add_trace(go.Table(
    header=dict(values=header, height=40, font_size=16, align='left'),
    cells=dict(values=cells, height=30, font_size=14, align='left')
), row=1, col=2)

# 3. R-R interval distribution (rr_lens presumably in seconds; plotted in ms).
# np.asarray guards against rr_lens being a list, where `1000 * list` would repeat it instead of scaling.
dist_fig = ff.create_distplot([1000 * np.asarray(rr_lens)], ['R Peaks'], bin_size=10, show_rug=False, colors=['#835AF1'])
for trace in dist_fig.select_traces():
    fig.add_trace(trace, row=2, col=2)
fig.update_xaxes(title_text="Time (ms)", row=2, col=2)

fig.update_layout(
    template=plotly_template,
    height=720,
    title_text="HeartKit Summary",
    title_font_size=20,
    legend_orientation="h",
    margin=dict(l=40, r=40, t=60, b=30),
)
fig.write_html("../results/report.html")
fig.show()

In [None]:
# Overlay every detected beat, each cut as a beat_len-sample window centred on its R peak.
fig = go.Figure()
fig.update_layout(template=plotly_template, height=480, title_text="Beat overlay", xaxis_title="Sample within window")
t = np.arange(beat_len)
half_beat = beat_len // 2
for i, rpeak in enumerate(rpeaks):
    frame_start = rpeak - half_beat
    frame_end = frame_start + beat_len
    # Beats too close to either end of the strip are skipped: a negative start wraps around in
    # numpy slicing (giving an empty trace), and a late start gives a truncated trace.
    if frame_start < 0 or frame_end > data.shape[0]:
        continue
    fig.add_trace(go.Scatter(x=t, y=data[frame_start:frame_end].squeeze(), name=f"Beat {i+1}"))
fig.show()

In [None]:
# fig, ax = plt.subplots(figsize=(12, 5), layout="constrained")
# for i in range(len(rpeaks)):
#     frame_start = rpeaks[i] - int(0.5 * beat_len)
#     frame_end = frame_start + beat_len
#     ax.plot(data[frame_start:frame_end])

In [None]:
# color_map = {HeartBeat.normal: 'gray', HeartBeat.pac: 'purple', HeartBeat.pvc: 'red'}

# fig, ax = plt.subplots(figsize=(12, 5), layout="constrained")
# # ax.plot(rpeaks, data[rpeaks], '*', color='black')
# for i in range(len(rpeaks)):
#     c = color_map.get(blabels[i], 'black')
#     ax.axvline(x=rpeaks[i], color=c)
# plot_segmentations(data, preds=seg_mask, fig=fig, ax=ax)
# ax.set_xlim(40, 4500)
# ax.autoscale(enable=True, axis="y")
# ax.set_title(f"HR={hr:0.0f} bpm | RR={1000*avg_rr/sampling_rate:0.0f} ms")

In [None]:
# Time-domain and non-linear HRV analyses of the model-detected R peaks via NeuroKit2, with its
# built-in figures enabled. The non-linear result is the cell's displayed output.
nk.hrv_time(
    rpeaks,
    sampling_rate=sampling_rate,
    show=True,
)
nk.hrv_nonlinear(
    rpeaks,
    sampling_rate=sampling_rate,
    show=True,
)

In [None]:
# Scratch: explore window/overlap arithmetic for tiling the data_len strip with segmentation windows.
# NOTE(review): data_len / (data_len / seg_win) simplifies to seg_win, so num_steps is ceil(seg_win)
# (up to float rounding) regardless of data_len, and seg_olp2 then comes out ~data_win. Neither looks
# like a window count or an overlap — confirm what was intended before relying on these values.
seg_win = seg_len+2*seg_olp  # window length including a seg_olp margin on each side
data_win = data_len/(seg_win)
num_steps = np.ceil(data_len/data_win)
seg_olp2 = data_len/num_steps
print(seg_win, data_win, num_steps, seg_olp2)

In [None]:
# Inspect the scratch step count (a NumPy float from np.ceil)
num_steps

In [None]:
# Length of the analysed strip in samples (10 * sampling_rate)
data_len

In [None]:
for i in range(0, data_len/4):
    sx = 
    ex = sx + seg_len + seg_olp