In [6]:
import numpy as np
import plotly.graph_objects as go
from pathlib import Path

In [7]:
# Input root: the original .npy files are read from sub-folders of this path.
dir_orig_root = Path("NEK")

# Output roots for the generated CSV files and the rendered figures.
dir_csv_root = Path("out_csv")
dir_img_root = Path("out_img")

# Create the output folders up front; exist_ok keeps the cell re-runnable.
dir_csv_root.mkdir(exist_ok=True)
dir_img_root.mkdir(exist_ok=True)

In [8]:
# Which part of the data set to process; also used in the output file names.
part_idx = "0"
# Folder holding the .npy files of the selected part.
dir_orig = dir_orig_root.joinpath(part_idx)

In [9]:
# Destination CSV files for the selected part, one per split.
train_csv = dir_csv_root.joinpath(f"train_{part_idx}.csv")
test_csv = dir_csv_root.joinpath(f"test_{part_idx}.csv")

In [10]:
# Read both splits from disk: labels are cast to int, values to float32.
train_label = np.load(dir_orig.joinpath("train_label.npy")).astype(int)
train_value = np.load(dir_orig.joinpath("train.npy")).astype(np.float32)

test_label = np.load(dir_orig.joinpath("test_label.npy")).astype(int)
test_value = np.load(dir_orig.joinpath("test.npy")).astype(np.float32)

In [11]:
# Dump the training split to CSV: a header row, then one "value,label" line
# per sample (values and labels are paired positionally).
with train_csv.open("w") as out:
    out.write("value,label\n")
    out.writelines(
        f"{sample},{target}\n" for sample, target in zip(train_value, train_label)
    )

In [12]:
# Dump the test split to CSV: a header row, then one "value,label" line
# per sample (values and labels are paired positionally).
with test_csv.open("w") as out:
    out.write("value,label\n")
    out.writelines(
        f"{sample},{target}\n" for sample, target in zip(test_value, test_label)
    )

In [13]:
import itertools
import numpy as np


def consecutive_label_and_idx(iterable):
    """Collapse runs of equal consecutive items into (start, length) spans.

    Merging consecutive samples into one span means far fewer shapes have to
    be handed to plotly's `fig.add_vrect`, which speeds up plotting.

    Yields:
        ((start_index, run_length), key) for every run, in input order, where
        `key` is the repeated item and `start_index` is the position of the
        run's first element in `iterable`.
    """
    start = 0
    for key, run in itertools.groupby(iterable):
        # The group iterator can only be consumed once; count its elements.
        run_length = sum(1 for _ in run)
        yield ((start, run_length), key)
        start += run_length


def test_consecutive_label_and_idx():
    """Self-check: the runs of "1" in 1100011 are found at (0, 2) and (5, 2)."""
    encoded = "".join(np.array([1, 1, 0, 0, 0, 1, 1]).astype(str))
    # Keep only the (start, length) spans whose label is "1".
    idx_lens = [
        span for span, label in consecutive_label_and_idx(encoded) if label == "1"
    ]
    print(idx_lens)
    assert idx_lens == [(0, 2), (5, 2)]


# Run the self-check as soon as this cell executes; it prints the detected runs.
test_consecutive_label_and_idx()

[(0, 2), (5, 2)]


In [16]:
# Plot train and test together in one figure (stacked rows) for easy comparison.
from plotly.subplots import make_subplots


def _add_labelled_series(fig, values, labels, name, row, col=1):
    """Draw `values` as a line trace and shade every run of label 1 in red.

    Args:
        fig: figure created by `make_subplots`; modified in place.
        values: sequence of y-values; x is the sample index 0..len(values)-1.
        labels: sequence of integer labels aligned with `values`; runs equal
            to 1 are highlighted.
        name: legend name of the trace.
        row: 1-based subplot row receiving the trace and the shading.
        col: 1-based subplot column (defaults to the single column used here).
    """
    fig.add_trace(
        go.Scatter(x=list(range(len(values))), y=values, mode="lines", name=name),
        row=row,
        col=col,
    )
    # Add one rectangle per run of consecutive positive labels instead of one
    # per sample. The labels are grouped directly rather than through a joined
    # string, so labels that are not a single character are grouped correctly.
    for (start, length), label in consecutive_label_and_idx(labels):
        if label == 1:
            fig.add_vrect(
                x0=start,
                x1=start + length,
                line_width=0,
                fillcolor="red",
                opacity=0.2,
                row=row,
                col=col,
            )


fig = make_subplots(rows=2, cols=1)

# Row 1 shows the training split, row 2 the test split.
for row, (name, values, labels) in enumerate(
    [("train", train_value, train_label), ("test", test_value, test_label)],
    start=1,
):
    _add_labelled_series(fig, values, labels, name, row)

# NOTE (original author): to use update_layout's rangeslider, the number of
# rows should be doubled, because the slider also counts as a trace.
# fig.update_layout(
#     hovermode='x',
#     xaxis=dict(
#         rangeslider=dict(
#             visible=True
#         ),
#     )
# )
fig.write_image(dir_img_root / f"allinone_{part_idx}.pdf", width=1280, height=400)
fig.show()