In [None]:
from pathlib import Path
import random
import datetime

import sleepkit as sk
import numpy as np
import tensorflow as tf
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [None]:
# Demo configuration: a 4-stage sleep model evaluated on 240-epoch frames from
# the processed fs001 dataset (HDF5 handler), using the saved sleep-stage-4 model.
params = sk.defines.SKDemoParams(
    job_dir=Path("../results/demo"),
    num_sleep_stages=4,
    frame_size=240,
    model_file="../results/sleep-stage-4/model.tf",
    ds_handler="hdf5",
    ds_path=Path("../datasets/processed/fs001"),
    # Dataset keys holding the features, labels and validity mask.
    ds_params={
        "feat_key": "features",
        "label_key": "labels",
        "mask_key": "mask",
    },
)


In [None]:
# Dark-theme colors shared by the figures in this notebook.
bg_color = "rgba(38,42,50,1.0)"
primary_color = "#11acd5"
secondary_color = "#ce6cff"

def get_color_map(num_classes):
    """Return a new ``{class index: color}`` dict for a sleep-stage model.

    Args:
        num_classes: Number of sleep-stage classes; 2, 3, 4 and 5 are supported.

    Returns:
        A freshly built dict mapping class indices ``0..num_classes-1`` to
        Plotly color strings. Class 0 is always ``"gray"``.

    Raises:
        ValueError: If ``num_classes`` is not a supported class count.
    """
    # Ordered (count, colors-by-index) table; scanned with `==` like an if-chain.
    palettes = (
        (2, ("gray", "#1f77b4")),
        (3, ("gray", "#1f77b4", "#d62728")),
        (4, ("gray", "#1f77b4", "#ff7f0e", "#d62728")),
        (5, ("gray", "#2ca02c", "#1f77b4", "#ff7f0e", "#d62728")),
    )
    for count, colors in palettes:
        if num_classes == count:
            return dict(enumerate(colors))
    raise ValueError("Invalid number of classes")


In [None]:
# Plot theme plus per-class metadata for the configured number of sleep stages:
# class ids, display names, the raw-label -> class mapping, and plot colors.
plotly_template = "plotly_dark"
feat_names = sk.features.get_feature_names_001()
sleep_classes = sk.defines.get_sleep_stage_classes(params.num_sleep_stages)
class_names = sk.defines.get_sleep_stage_class_names(params.num_sleep_stages)
class_mapping = sk.defines.get_sleep_stage_class_mapping(params.num_sleep_stages)
class_colors = get_color_map(params.num_sleep_stages)


In [None]:
# Build the dataset handler selected by `params` (handler type, processed data
# location, frame size, and handler-specific options such as dataset keys).
ds = sk.sleepstage.load_dataset(
    params=params.ds_params,
    frame_size=params.frame_size,
    ds_path=params.ds_path,
    handler=params.ds_handler,
)


In [None]:
# Restore the trained model. `MultiF1Score` must be supplied as a custom object
# so the saved model can be deserialized (presumably it was a training metric).
model = sk.tflite.load_model(
    params.model_file,
    custom_objects=dict(MultiF1Score=sk.tflite.MultiF1Score),
)


In [None]:
# Pick one subject at random and load its full recording twice: raw features
# (plotted later) and normalized features, labels and mask (fed to the model).
# NOTE(review): this samples from the training split, so predictions may look
# optimistic; consider a held-out split if `ds` exposes one.
subject_id = random.choice(ds.train_subject_ids)
features, _, _ = ds.load_subject_data(subject_id=subject_id, normalize=False)
x, y, m = ds.load_subject_data(subject_id=subject_id, normalize=True)

# Only whole frames are used; trailing epochs that do not fill a frame are dropped.
num_windows = int(features.shape[0] // params.frame_size)
data_len = num_windows * params.frame_size


In [None]:
# Run model (PC or EVB): predict every whole frame of the subject, then keep
# only epochs whose mask is 1 for the summary statistics.
xx = x[:data_len, :].reshape((num_windows, params.frame_size) + ds.feature_shape[1:])
mm = m[:data_len].reshape((num_windows, params.frame_size))
y_prob = tf.nn.softmax(model.predict(xx, verbose=0)).numpy()
y_pred = np.argmax(y_prob, axis=-1).flatten()

features = features[:data_len, :]
y_mask = mm.flatten()
y_pred = y_pred[y_mask == 1]
# Apply the mask *before* mapping labels. Masked epochs may carry labels that
# are absent from `class_mapping`. `np.vectorize(class_mapping.get)` then either
# raised a TypeError on the resulting None or silently produced an object array.
# It also failed outright when no epoch was valid. Mapping only valid epochs
# gives a plain int array. A valid epoch with an unmapped label raises a
# KeyError naming that label.
y_true = np.array(
    [class_mapping[label] for label in y[:data_len].flatten()[y_mask == 1]],
    dtype=int,
)

# Synthetic wall-clock axis for display only: a random start hour between
# 12:00 and 23:00, with consecutive epochs 30 s apart.
tod = datetime.datetime(2025, 5, 24, random.randint(12, 23), 0)
ts = [tod + datetime.timedelta(seconds=30 * i) for i in range(data_len)]

pred_sleep_durations = sk.metrics.compute_sleep_stage_durations(y_pred)
act_sleep_durations = sk.metrics.compute_sleep_stage_durations(y_true)


In [None]:
# Hypnogram of the predicted sleep stages (secondary y-axis, reversed so the
# highest stage id is at the bottom) overlaid with selected raw features on the
# primary y-axis. The feature traces stay hidden until toggled in the legend.
fig = make_subplots(rows=1, cols=1, specs=[[{"secondary_y": True}]])

# `y_pred` only holds predictions for unmasked epochs, while `ts` (and the
# feature traces below) span every epoch. Scatter the predictions back onto the
# full timeline, using -1 for masked epochs. Hypnogram segments then sit at
# their real timestamps and masked gaps are left blank. Previously they were
# shifted left after every masked gap.
stage_timeline = np.full(len(y_mask), -1, dtype=int)
stage_timeline[y_mask == 1] = y_pred

# Run boundaries: the start, every stage change, and the end of the timeline.
# Without the end bound, the last run was never drawn.
sleep_bounds = np.concatenate((
    [0],
    np.diff(stage_timeline).nonzero()[0] + 1,
    [len(stage_timeline)],
))

# A segment ends at the timestamp of the next run's first epoch. The final run
# needs one edge past the last epoch, which is extrapolated from the epoch spacing.
ts_edges = list(ts)
ts_edges.append(ts[-1] + (ts[-1] - ts[-2]) if len(ts) > 1 else ts[-1])

legend_groups = set()
for start, stop in zip(sleep_bounds[:-1], sleep_bounds[1:]):
    label = stage_timeline[start]
    if label < 0:
        continue  # masked run: no prediction to draw
    name = class_names[label]
    color = class_colors.get(label, None)

    fig.add_trace(go.Scatter(
        x=[ts_edges[start], ts_edges[stop]],
        y=[label, label],
        mode='lines',
        line_shape='hv',
        name=name,
        legendgroup=name,
        showlegend=name not in legend_groups,  # one legend entry per stage
        line_color=color,
        line_width=4,
        fill='tozeroy',
        opacity=0.7,
    ), secondary_y=True)
    legend_groups.add(name)
# END FOR

# Feature columns (indices into `feat_names`) offered as optional overlays.
overlay_feat_idxs = [0, 4, 5, 8, 9, 11, 13]
for f in overlay_feat_idxs:
    name = feat_names[f].upper().replace("_", " ")
    # Blank out masked epochs so the overlays match the hypnogram's gaps.
    feat_y = np.where(y_mask == 1, features[:data_len, f], np.nan)
    fig.add_trace(go.Scatter(x=ts, y=feat_y, name=name, opacity=0.5, visible='legendonly'), secondary_y=False)
# END FOR

fig.update_yaxes(
    autorange=False,
    range=[max(sleep_classes), min(sleep_classes)],
    ticktext=class_names,
    tickvals=list(range(len(class_names))),
    secondary_y=True
)

fig.update_layout(
    template=plotly_template,
    height=500,
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    margin=dict(l=10, r=5, t=40, b=20),
)
# fig.write_html("../docs/assets/sleep-stage-demo-example.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [None]:
# Donut chart of the predicted time in each sleep stage (from
# `pred_sleep_durations`), with a caption centered in the hole.
fig = go.Figure(
    data=[
        go.Pie(
            labels=class_names,
            values=[pred_sleep_durations.get(stage, 0) for stage in sleep_classes],
            hole=0.4,
            textinfo="label+percent",
            textfont_size=15,
        )
    ]
)
fig.add_annotation(
    x=0.5,
    y=0.5,
    text="Sleep<br>Cycles",
    showarrow=False,
    font={"size": 22, "color": "rgb(240, 240, 240)"},
)
fig.update_layout(
    template=plotly_template,
    height=350,
    margin={"l": 10, "r": 5, "t": 40, "b": 20},
    paper_bgcolor=bg_color,
    plot_bgcolor=bg_color,
)
# fig.write_html(f"../docs/assets/demo-sleep-cycle-pie.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [None]:
# fig = make_subplots(rows=1, cols=2, specs=[[{'type':'domain'}, {'type':'domain'}]], subplot_titles=['Actual', 'Predicted'])
# fig.add_trace(go.Pie(labels=class_names, values=[pred_sleep_durations.get(c, 0) for c in sleep_classes], hole=.3, name='Actual'), 1, 1)
# fig.add_trace(go.Pie(labels=class_names, values=[act_sleep_durations.get(c, 0) for c in sleep_classes], hole=.3, name='Predicted'), 1, 2)
# fig.update_traces(hole=.4, hoverinfo="label+percent+name+value")
# fig.update_layout(template='plotly_dark', height=400)
# fig.show()


In [None]:
# fig = make_subplots(rows=len(feat_names), cols=1, shared_xaxes=True, vertical_spacing=0.005)
# for f in range(len(feat_names)):
#     name = feat_names[f].upper().replace("_", " ")
#     feat_y = np.where(y_mask == 1, features[:data_len, f], np.nan)
#     fig.add_trace(go.Scatter(x=ts, y=feat_y, name=name), row=f+1, col=1)
# fig.update_layout(
#     height=1200,
#     showlegend=False,
#     template=plotly_template,
#     # title_text="Stacked Subplots with Shared X-Axes"
# )
# fig.show()


In [None]:
# fig = go.Figure()
# sleep_bounds = np.concatenate(([0], np.diff(y_pred).nonzero()[0]+1))
# legend_groups = set()
# for i in range(1, len(sleep_bounds)):
#     start = sleep_bounds[i-1]
#     stop = sleep_bounds[i]
#     label = y_pred[start]
#     name = class_names[label]
#     color = class_colors.get(label, None)

#     fig.add_trace(go.Scatter(
#         x=[ts[start], ts[stop]],
#         y=[label, label],
#         mode='lines',
#         line_shape='hv',
#         name=name,
#         legendgroup=name,
#         showlegend=name not in legend_groups,
#         line_color=color,
#         line_width=4,
#         fill='tozeroy',
#         opacity=0.7,
#     ))
#     legend_groups.add(name)

# fig.update_yaxes(autorange=False, range=[max(sleep_classes), min(sleep_classes)], ticktext=class_names, tickvals=list(range(len(class_names))))
# fig.update_layout(
#     template=plotly_template,
#     height=400,
# )
# fig.show()


In [None]:
# fig = make_subplots(rows=1, cols=1, specs=[[{"secondary_y": True}]])

# sleep_bounds = np.concatenate(([0], np.diff(y_pred).nonzero()[0]+1))
# legend_groups = set()
# for i in range(1, len(sleep_bounds)):
#     start = sleep_bounds[i-1]
#     stop = sleep_bounds[i]
#     label = y_pred[start]
#     name = class_names[label]
#     color = class_colors.get(label, None)

#     fig.add_trace(go.Scatter(
#         x=[ts[start], ts[stop]],
#         y=[label, label],
#         mode='lines',
#         line_shape='hv',
#         name=name,
#         legendgroup=name,
#         showlegend=name not in legend_groups,
#         line_color=color,
#         line_width=4,
#         fill='tozeroy',
#         opacity=0.7,
#     ), secondary_y=True)
#     # fig.add_shape(type="rect", x0=ts[start], y0=0, x1=ts[stop], y1=label, legendgroup=name, showlegend=False, opacity=0.5, fillcolor=color, line_width=2, line_color=color)
#     # END IF
#     legend_groups.add(name)
# # END FOR

# for f in  [0, 4, 5, 8, 9, 11, 13]:
#     name = feat_names[f].upper().replace("_", " ")
#     feat_y = np.where(y_mask == 1, features[:data_len, f], np.nan)
#     fig.add_trace(go.Scatter(x=ts, y=feat_y, name=name, opacity=0.5, visible='legendonly'), secondary_y=False)
# # END FOR

# fig.update_yaxes(
#     autorange=False,
#     range=[max(sleep_classes), min(sleep_classes)],
#     ticktext=class_names,
#     tickvals=list(range(len(class_names))),
#     secondary_y=True
# )

# fig.update_layout(
#     template="plotly_dark",
#     height=350,
#     plot_bgcolor=bg_color,
#     paper_bgcolor=bg_color,
#     margin=dict(l=10, r=5, t=40, b=20),
# )
# fig.write_html("../docs/assets/sleep-stage-demo-example.html", include_plotlyjs='cdn', full_html=False)
# fig.show()


In [None]:
# ts = np.arange(data_len)
# fig = go.Figure()

# # fig.add_trace(go.Scatter(x=ts, y=y_mask, name="Mask"))

# for k, v in class_colors.items():
#     fig.add_trace(go.Scatter(x=[0], y=[np.nan], name=class_names[k], line_color=v))
# # END FOR

# sleep_bounds = np.concatenate(([0], np.diff(y_true).nonzero()[0]+1))
# for i in range(1, len(sleep_bounds)):
#     start = sleep_bounds[i-1]
#     stop = sleep_bounds[i]
#     color = class_colors.get(y_true[start], None)
#     if color:
#         fig.add_shape(type="rect", x0=start, y0=-1.0, x1=stop, y1=0.25,  opacity=0.5, fillcolor=color, line_width=2, line_color=color)
# # END FOR

# sleep_bounds = np.concatenate(([0], np.diff(y_pred).nonzero()[0]+1))
# for i in range(1, len(sleep_bounds)):
#     start = sleep_bounds[i-1]
#     stop = sleep_bounds[i]
#     color = class_colors.get(y_pred[start], None)
#     if color:
#         fig.add_shape(type="rect", x0=start, y0=-0.25, x1=stop, y1=1.0,  opacity=0.5, fillcolor=color, line_width=2, line_color=color)
# # END FOR
# fig.update_layout(template=plotly_template, height=400)
# fig.show()
