In [82]:
import os
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import sleepkit as sk


In [132]:
# Shared Plotly look for every figure in this notebook: install the template
# as the global default so figures that don't set one explicitly pick it up.
plotly_template = "plotly_dark"
pio.templates.default = plotly_template

# Plot/paper background, plus series colors used for the main metric (primary)
# and the secondary y-axis series (secondary) in the figures below.
bg_color = "rgba(38,42,50,1.0)"
primary_color, secondary_color = "#11acd5", "#ce6cff"  # alternate primary kept for reference: "rgb(101, 110, 242)"


In [3]:

# Figures are exported into the docs assets folder; inputs come from one
# results folder per stage-count model.
src_path = Path("../results")
stage2_src_path, stage3_src_path, stage4_src_path, stage5_src_path = (
    src_path / folder
    for folder in ("sleep-stage-2", "sleep-stage-3", "sleep-stage-4", "sleep-stage-5")
)
dst_path = Path("../docs/assets")


In [4]:

def _load_stage_csvs(stage_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read ``metrics.csv`` then ``results.csv`` from one stage results folder."""
    metrics_df = pd.read_csv(stage_dir / "metrics.csv")
    results_df = pd.read_csv(stage_dir / "results.csv")
    return metrics_df, results_df


df_s2_metrics, df_s2_results = _load_stage_csvs(stage2_src_path)
df_s3_metrics, df_s3_results = _load_stage_csvs(stage3_src_path)
df_s4_metrics, df_s4_results = _load_stage_csvs(stage4_src_path)
df_s5_metrics, df_s5_results = _load_stage_csvs(stage5_src_path)


In [5]:
# Index-aligned collections: position k of `metrics`/`results` belongs to stages[k].
stages = list(range(2, 6))
metrics, results = (
    [df_s2_metrics, df_s3_metrics, df_s4_metrics, df_s5_metrics],
    [df_s2_results, df_s3_results, df_s4_results, df_s5_results],
)


In [142]:
def plot_tst(df: pd.DataFrame, html_path: Path | None = None, json_path: Path | None = None, num_stages: int = 5):
    """Scatter actual vs. predicted total sleep time (TST) with an OLS trendline.

    Args:
        df: Table with columns ``act_tst``, ``pred_tst``, ``subject`` and ``acc``
            (the last two are shown on hover).
        html_path: If not None, write an embeddable HTML fragment here
            (plotly.js is loaded from the CDN, no full HTML document).
        json_path: If not None, write the figure's JSON spec here.
        num_stages: Number of sleep stages; used only in the figure title.

    Returns:
        The Plotly figure.
    """
    axis_labels = {"act_tst": "Actual TST (min)", "pred_tst": "Predicted TST (min)"}
    figure = px.scatter(
        df,
        x="act_tst",
        y="pred_tst",
        labels=axis_labels,
        hover_data=["subject", "acc"],
        trendline="ols",
        color_discrete_sequence=[primary_color],
    )

    layout_opts = dict(
        title=f"{num_stages} Stage Sleep: Actual vs. Predicted TST",
        height=300,
        margin=dict(l=10, r=5, t=40, b=20),
        paper_bgcolor=bg_color,
        plot_bgcolor=bg_color,
    )
    figure.update_layout(**layout_opts)

    # Optional exports; either, both, or neither may be requested.
    if html_path is not None:
        figure.write_html(html_path, include_plotlyjs='cdn', full_html=False)
    if json_path is not None:
        figure.write_json(json_path)
    return figure


def plot_eff(df: pd.DataFrame, html_path: Path | None = None, json_path: Path | None = None, num_stages: int = 5):
    """Scatter actual vs. predicted sleep efficiency with an OLS trendline.

    Args:
        df: Table with columns ``act_eff``, ``pred_eff``, ``subject`` and ``acc``
            (the last two are shown on hover).
        html_path: If not None, write an embeddable HTML fragment here
            (plotly.js is loaded from the CDN, no full HTML document).
        json_path: If not None, write the figure's JSON spec here.
        num_stages: Number of sleep stages; used only in the figure title.

    Returns:
        The Plotly figure.
    """
    fig = px.scatter(
        df,
        x="act_eff", y="pred_eff",
        trendline="ols",
        hover_data=["subject", "acc"],
        color_discrete_sequence=[primary_color],
        labels={"act_eff": "Actual Efficiency (%)", "pred_eff": "Predicted Efficiency (%)"}
    )
    # No explicit `template=` here: the figure inherits pio.templates.default,
    # so it follows the notebook-wide `plotly_template` setting.
    fig.update_layout(
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        margin=dict(l=10, r=5, t=40, b=20),
        height=300,
        title=f"{num_stages} Stage Sleep: Actual vs. Predicted Efficiency",
    )
    if html_path is not None:
        fig.write_html(html_path, include_plotlyjs='cdn', full_html=False)
    if json_path is not None:
        fig.write_json(json_path)
    return fig


In [143]:
# Export the TST and efficiency scatter plots (HTML fragment + JSON) for each stage-count model.
for stage, stage_metrics in zip(stages, metrics):
    plot_tst(
        stage_metrics,
        html_path=dst_path / f"stage-{stage}-tst.html",
        json_path=dst_path / f"stage-{stage}-tst.json",
        num_stages=stage,
    )
    plot_eff(
        stage_metrics,
        html_path=dst_path / f"stage-{stage}-eff.html",
        json_path=dst_path / f"stage-{stage}-eff.json",
        num_stages=stage,
    )


In [136]:
# Squeeze-excitation (SE) ratio ablation. All lists are index-aligned with `ratios`.
# Ratio 64 is omitted: no flops/params/loss values are recorded for it.
# NOTE(review): add 64 back together with its measurements if they exist.
ratios = [0, 1, 2, 4, 8, 16]
flops = [4.58, 4.68, 4.67, 4.67, 4.66, 4.66]
params = [11184, 22932, 17145, 14237, 12783, 12134]
loss = [0.42717, 0.38867, 0.38385, 0.35740, 0.38308, 0.40652]
assert len(ratios) == len(flops) == len(params) == len(loss), "SE-ratio ablation lists are misaligned"

fig = make_subplots(specs=[[{"secondary_y": False}]])
fig.add_trace(go.Scatter(
    x=ratios,
    y=loss,
    mode='lines+markers',
    line_width=8,
    marker_size=8,
    line_color=primary_color,
    name="Loss"
), secondary_y=False)

fig.update_xaxes(title_text="SE Ratio")
fig.update_yaxes(title_text="Model Loss")
fig.update_layout(
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    title="Performance vs Squeeze Excitation Ratio",
    margin=dict(l=10, r=5, t=40, b=20),
    height=300,
)
fig.write_html(dst_path / "ablation-se-ratio.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [137]:
# Temporal-context ablation: model loss (left axis) and FLOPS (right axis)
# vs. input window length. Each x value is written as <window>/2. The window is
# presumably counted in 30-second epochs, so /2 gives minutes (matching the
# "min" axis label). TODO confirm.
time = [32/2, 64/2, 96/2, 128/2, 192/2, 240/2, 256/2, 320/2]
flops = [0.63, 1.25, 1.87, 2.49, 3.73, 4.67, 4.98, 6.24]
loss = [0.49171, 0.46407, 0.41749, 0.40542, 0.3965, 0.3616, 0.3963, 0.4291]
assert len(time) == len(flops) == len(loss), "temporal ablation lists are misaligned"

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig.add_trace(go.Scatter(
    x=time,
    y=loss,
    name="Loss",
    line_width=8,
    line_color=primary_color,
    marker_size=8,
), secondary_y=False)

fig.add_trace(go.Scatter(
    x=time,
    y=flops,
    name="FLOPS",
    line_width=8,
    line_color=secondary_color,
    marker_size=8,
), secondary_y=True)

fig.update_xaxes(title_text="Temporal Context (min)")
# Color each y-axis like its trace so the two scales are easy to tell apart.
fig.update_yaxes(title_text="Model Loss", secondary_y=False, color=primary_color)
fig.update_yaxes(title_text="FLOPS (M)", secondary_y=True, color=secondary_color)
fig.update_layout(
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    margin=dict(l=10, r=5, t=40, b=20),
    height=300,
    title="Performance vs Temporal Context",
)
fig.write_html(dst_path / "ablation-temporal.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [138]:
# Dilation ablation: model loss with vs. without dilated convolutions.
# The arithmetic below is kept from the experiment notes. It is presumably the
# receptive-field calculation for the two configurations (TODO confirm):
#   1*1*5 + 1*1*5 + 2*1*5 + 2*1*5 = 30
#   8*5 + 4*5 + 2*2*5 + 2*1*5 = 90
names, loss = ["Dilation", "No Dilation"], [0.3616, 0.46915]

dilation_bar = go.Bar(
    x=loss,
    y=names,
    orientation='h',
    marker_color=[primary_color, secondary_color],
)
fig = go.Figure()
fig.add_trace(dilation_bar)

fig.update_xaxes(title_text="Model Loss")
fig.update_layout(
    title="Performance vs Dilation",
    height=300,
    margin=dict(l=10, r=5, t=40, b=20),
    paper_bgcolor=bg_color,
    plot_bgcolor=bg_color,
)
fig.write_html(dst_path / f"ablation-dilation.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [139]:
# Model-width ablation: model loss (left axis) and FLOPS (right axis) vs. width
# (presumably a channel-width multiplier; 1 is the configuration reused in the
# kernel-size ablation). All lists are index-aligned with `width`.
width = [0.5, 0.75, 1, 1.25, 1.5, 2.0]
flops = [1.53, 2.89, 4.67, 6.86, 9.48, 15.97]
params = [4865, 8906, 14237, 20686, 28473, 47573]
loss = [0.42014, 0.41343, 0.36162, 0.35719, 0.35887, 0.36043]
assert len(width) == len(flops) == len(params) == len(loss), "width ablation lists are misaligned"

fig = make_subplots(specs=[[{"secondary_y": True}]])

# Named so the legend reads "Loss" instead of the default "trace 0".
fig.add_trace(go.Scatter(
    x=width,
    y=loss,
    name="Loss",
    line_width=8,
    line_color=primary_color,
    marker_size=8,
), secondary_y=False)

fig.add_trace(go.Scatter(
    x=width,
    y=flops,
    name="FLOPS",
    line_width=8,
    line_color=secondary_color,
    marker_size=8,
), secondary_y=True)

fig.update_xaxes(title_text="Model Width", tickvals=width)

# Color each y-axis like its trace so the two scales are easy to tell apart.
fig.update_yaxes(title_text="Model Loss", secondary_y=False, color=primary_color)
fig.update_yaxes(title_text="FLOPS (M)", secondary_y=True, color=secondary_color)

fig.update_layout(
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    margin=dict(l=10, r=5, t=40, b=20),
    height=300,
    title="Performance vs Model Width",
)
fig.write_html(dst_path / "ablation-width.html", include_plotlyjs='cdn', full_html=False)
fig.show()


In [140]:
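# Over-parameterization ablation, D = 1, 2, 3, 4
# Loss (presumably model loss, as in the other ablations): 0.36162, 0.36479, 0.36549
# NOTE(review): only three loss values are recorded for four D settings; confirm which D each belongs to.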


In [144]:
# Kernel size K=3,5,7,9
# Model loss (left axis) and FLOPS (right axis) vs. convolution kernel size.
# All lists are index-aligned with `size`.
size = [3, 5, 7, 9]
flops = [4.24, 4.67, 5.09, 5.52]
params = [13349, 14237, 15125, 16013]
loss = [0.38024, 0.36162, 0.36201, 0.36527]
assert len(size) == len(flops) == len(params) == len(loss), "kernel-size ablation lists are misaligned"

fig = make_subplots(specs=[[{"secondary_y": True}]])

# Loss goes on the primary (left) axis explicitly, and is named so the legend
# reads "Loss" instead of the default "trace 0".
fig.add_trace(go.Scatter(
    x=size,
    y=loss,
    name="Loss",
    line_width=8,
    line_color=primary_color,
    marker_size=8,
), secondary_y=False)

fig.add_trace(go.Scatter(
    x=size,
    y=flops,
    name="FLOPS",
    line_width=8,
    line_color=secondary_color,
    marker_size=8,
), secondary_y=True)

fig.update_xaxes(title_text="Kernel Size", tickvals=size)

# Color each y-axis like its trace so the two scales are easy to tell apart.
fig.update_yaxes(title_text="Model Loss", secondary_y=False, color=primary_color)
fig.update_yaxes(title_text="FLOPS (M)", secondary_y=True, color=secondary_color)

fig.update_layout(
    plot_bgcolor=bg_color,
    paper_bgcolor=bg_color,
    margin=dict(l=10, r=5, t=40, b=20),
    height=300,
    title="Performance vs Kernel Size",
)
fig.write_html(dst_path / "ablation-kernelsize.html", include_plotlyjs='cdn', full_html=False)
fig.show()
