In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
# ^^^ pyforest auto-imports - don't write above this line
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm.auto import tqdm
from copy import deepcopy

from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

from dataset import LUDB
from cfg import TrainCfg, ModelCfg
from trainer import LUDBTrainer
from model import ECG_UNET_LUDB
from metrics import compute_metrics

from torch_ecg.utils import mask_to_intervals

In [None]:
import os

# Root directory of the LUDB database. Set the LUDB_DIR environment variable to point
# at a local copy; the author's original path is kept as the fallback default.
# This mutates the shared TrainCfg, so it must run before the deep copies below.
TrainCfg.db_dir = os.environ.get("LUDB_DIR", "/home/wenhao/Jupyter/wenhao/data/LUDB/")

In [None]:
# Two multi-lead training configurations that differ only in their loss function.
# Both are independent deep copies, so the edits below never touch TrainCfg itself.
train_cfg_fl, train_cfg_ce = deepcopy(TrainCfg), deepcopy(TrainCfg)
for _cfg, _loss_name in ((train_cfg_fl, "FocalLoss"), (train_cfg_ce, "CrossEntropyLoss")):
    _cfg.use_single_lead = False
    _cfg.loss = _loss_name

In [None]:
# Training split built from the focal-loss configuration.
# NOTE(review): lazy=False presumably loads every record at construction time, which would
# make the commented-out `_load_all_data()` calls unnecessary — confirm in dataset.py.
ds_train_fl = LUDB(
    train_cfg_fl,
    training=True,
    lazy=False,
)
# ds_train_ce = LUDB(train_cfg_ce, training=True, lazy=False)

In [None]:
# ds_train_fl._load_all_data()
# ds_train_ce._load_all_data()

In [None]:
# Validation split (training=False) from the same focal-loss configuration;
# it is reused later for the evaluation plots.
ds_val_fl = LUDB(
    train_cfg_fl,
    training=False,
    lazy=False,
)
# ds_val_ce = LUDB(train_cfg_ce, training=False, lazy=False)

In [None]:
# ds_val_fl._load_all_data()
# ds_val_ce._load_all_data()

In [None]:
# train_cfg_fl.keep_checkpoint_max = 0
# train_cfg_fl.monitor = None
# train_cfg_fl.n_epochs = 10

## Dry run: no augmentation, no preprocessing

In [None]:
# Private copy of the model hyper-parameters so later tweaks do not alter ModelCfg itself.
model_config = deepcopy(ModelCfg)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Segmentation U-Net for LUDB, built from the config's lead count and the full model config.
model = ECG_UNET_LUDB(
    model_config.n_leads,
    model_config,
)

In [None]:
# Display the model's `module_size_` attribute (presumably a parameter-count summary — see model.py).
model.module_size_

In [None]:
# With several visible GPUs, wrap the model in DataParallel so batches are split across them.
n_visible_gpus = torch.cuda.device_count()
use_data_parallel = n_visible_gpus > 1
if use_data_parallel:
    model = DP(model)  # model = DDP(model) is the (unused) multi-process alternative
# nn.Module.to returns the module itself, so this last line also displays the model.
model.to(device=device)

In [None]:
# Trainer for the focal-loss configuration.
# NOTE(review): lazy=True presumably defers dataset construction, which is why the
# datasets built above are attached explicitly in the next cell — confirm in trainer.py.
trainer = LUDBTrainer(
    model=model,
    device=device,
    model_config=model_config,
    train_config=train_cfg_fl,
    lazy=True,
)

In [None]:
# Attach the already-built train/val datasets to the trainer via its private helper
# (presumably required because the trainer was created with lazy=True).
trainer._setup_dataloaders(ds_train_fl, ds_val_fl)

In [None]:
# Run the training loop. The return value is kept as `bmd`
# (presumably the best model's state dict — confirm in trainer.py).
bmd = trainer.train()

## Evaluation and plots

In [None]:
# Reload a saved best checkpoint for evaluation (its filename records epoch 136, metric 0.97).
# This replaces the `model` trained above; the second returned value is discarded.
ckpt_path = "/home/wenhao/.cache/torch_ecg/saved_models/BestModel_ECG_UNET_LUDB_epoch136_09-08_10-36_metric_0.97.pth.tar"
model, _ = ECG_UNET_LUDB.from_checkpoint(ckpt_path)

In [None]:
# Display the loaded model's module structure.
model

In [None]:
# Run inference on the first validation record; the result's `.mask` is used by the plots below.
model_output = model.inference(ds_val_fl.signals[0])

In [None]:
# Inspect the predicted mask returned by inference.
model_output.mask

In [None]:
# Quick sanity check: row 0 of the first validation record (presumably lead I; black, left axis)
# against row 0 of the predicted mask (red, right axis), both on a sample-index x-axis.
fig, ax = plt.subplots(figsize=(20,6))
ax.plot(ds_val_fl.signals[0][0], color="black")
ax.set_xlabel("Sample index")
ax.set_ylabel("Signal")
ax2 = ax.twinx()
ax2.plot(model_output.mask[0], color="red")
ax2.set_ylabel("Predicted class")
plt.show()

In [None]:
# Inputs for the overlay figures below, all taken from the first validation record.
# NOTE(review): `values` is row 2 (captioned "Lead III" in the figures), while the label
# mask comes from row 0 of the labels — confirm this pairing is intentional.
values = ds_val_fl.signals[0][2]
# Class index per sample. The labels are treated as one-hot with classes on the last axis.
# argmax always yields exactly one index per sample, whereas np.where(label == 1)[1]
# would silently drop or duplicate samples for any row that is not strictly one-hot.
mask_labels = np.asarray(ds_val_fl.labels[0][0]).argmax(axis=-1)
mask_preds = model_output.mask[0]

In [None]:
# Mask class index -> wave name (class 0 is presumably background and is never plotted).
mapping = dict(enumerate(["P wave", "QRS", "T wave"], start=1))
# colors = ['#4477AA', '#EE6677', '#228833', '#CCBB44', '#66CCEE', '#AA3377', '#BBBBBB']  # bright
colors = ['#004488', '#DDAA33', '#BB5566']  # high-contrast
# Wave name -> fill colour. The insertion order (P, QRS, T) is also the legend order.
pallete = {
    wave: colors[idx]
    for wave, idx in (("P wave", 0), ("QRS", 2), ("T wave", 1))
}

In [None]:
# Convert both class-index masks into {wave name: intervals}. Each interval is used below
# as [start, end] sample indices; only the wave classes 1-3 are requested.
intervals_preds, intervals_labels = (
    {mapping[cls]: itvs for cls, itvs in mask_to_intervals(mask, vals=[1, 2, 3]).items()}
    for mask in (mask_preds, mask_labels)
)

In [None]:
# Font sizes for the delineation figures
# (larger poster alternative: ticks 36, axis labels 50, legend 40).
plt.rcParams.update({
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'axes.labelsize': 32,
    'legend.fontsize': 24,
})

In [None]:
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(20,8))
fs = ds_val_fl.reader.fs
split_y = 0.35  # axes fraction separating the label band (below) from the prediction band (above)

# ECG trace on a time axis in seconds.
ax.plot(np.arange(len(values)) / fs, values, color="black", lw=1.2)
ax.set_xlim(-150 / fs, 5150 / fs)
ax.set_ylim(-0.5, 0.9)

# Shade every wave interval. ymin/ymax are axes fractions, so predictions fill the band
# above split_y and labels the band below it, leaving a 0.02 gap between the two.
overlay_specs = (
    (intervals_preds, {"ymin": split_y + 0.02}, 0.4),
    (intervals_labels, {"ymax": split_y - 0.02}, 0.6),
)
for itv_dict, band_kw, band_alpha in overlay_specs:
    for wave, l_itvs in itv_dict.items():
        for itv in l_itvs:
            ax.axvspan(itv[0] / fs, itv[1] / fs, color=pallete[wave], alpha=band_alpha, **band_kw)
ax.axhline(0, color="red", linewidth=2, linestyle="dotted")

# Lead caption at the top left; rotated band captions just past the right edge.
ax.text(-110 / fs, 0.8, "Lead III", fontsize=28)
ax.text(5300 / fs, -0.48, "Label Mask", fontsize=28, rotation=90)
ax.text(5300 / fs, 0.16, "Predicted Mask", fontsize=28, rotation=90)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage (mV)")

# Legend built from proxy patches, laid out as a single row above the axes.
legend_elements = [Patch(facecolor=fill, label=wave_name, alpha=0.5) for wave_name, fill in pallete.items()]
ax.legend(
    handles=legend_elements,
    loc="lower center",
    bbox_to_anchor=(0.5, 0.99),
    ncol=len(pallete),
    fancybox=True,
);
ax.set_xticks(np.arange(0,10.5,0.5));
ax.grid(which="major", linestyle=":", linewidth="0.6", color="gray");

# plt.savefig("./images/ludb-unet-val-example-small.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# plt.savefig("./images/ludb-unet-val-example-small.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(120,12))
fs = ds_val_fl.reader.fs
split_y = 0.35  # axes fraction separating the label band (below) from the prediction band (above)

# Very wide trace on a seconds axis, with a taller y-range than the small figure.
ax.plot(np.arange(len(values)) / fs, values, color="black", lw=1.2)
ax.set_xlim(-150 / fs, 5150 / fs)
ax.set_ylim(-0.6, 1.6)

# ECG-paper-like grid: major 0.2 s x 0.5 mV (solid red), minor 0.04 s x 0.1 mV (dotted gray).
ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))
ax.yaxis.set_major_locator(plt.MultipleLocator(0.5))
ax.grid(which="major", linestyle="-", linewidth="0.4", color="red")
ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04))
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.grid(which="minor", linestyle=":", linewidth="0.2", color="gray")

# Shade every wave interval: predictions above split_y, labels below, in axes fractions.
overlay_specs = (
    (intervals_preds, {"ymin": split_y + 0.02}, 0.4),
    (intervals_labels, {"ymax": split_y - 0.02}, 0.6),
)
for itv_dict, band_kw, band_alpha in overlay_specs:
    for wave, l_itvs in itv_dict.items():
        for itv in l_itvs:
            ax.axvspan(itv[0] / fs, itv[1] / fs, color=pallete[wave], alpha=band_alpha, **band_kw)

# Captions and axis labels.
ax.text(-110 / fs, 1.2, "Lead III", fontsize=28)
ax.text(5200 / fs, -0.55, "Label Mask", fontsize=28, rotation=90)
ax.text(5200 / fs, 0.35, "Predicted Mask", fontsize=28, rotation=90)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage (mV)")

# Legend from proxy patches, placed inside the axes at the lower left.
legend_elements = [Patch(facecolor=fill, label=wave_name, alpha=0.5) for wave_name, fill in pallete.items()]
ax.legend(
    handles=legend_elements,
    loc="lower left",
    fancybox=True,
);

# plt.savefig("./images/ludb-unet-val-example-large.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# plt.savefig("./images/ludb-unet-val-example-large.svg", dpi=1200, bbox_inches="tight", transparent=False);

## Gather results from the training log

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm
# Switch to the seaborn theme for the training-curve figure and adopt its colour cycle.
# sns.set() applies seaborn's own context font sizes, so the custom sizes are set again after it.
sns.set()
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Raw string: "\h" is an invalid escape sequence in a normal string literal
# (DeprecationWarning, SyntaxWarning on Python 3.12+); the resulting value is unchanged.
markers = ["*", "v", "p", "d", "s", r"$\heartsuit$", "+", "x", ]
marker_size = 9
plt.rcParams['xtick.labelsize']=24
plt.rcParams['ytick.labelsize']=24
plt.rcParams['axes.labelsize']=32
plt.rcParams['legend.fontsize']=20

In [None]:
# Per-epoch metrics log (presumably written by the trainer); the next cells split it by its `part` column.
training_log_path = "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_unet_ludb/log/TorchECG_04-06_22-30_ECG_UNET_LUDB_adamw_amsgrad_LR_0.001_BS_32.csv"
df_res = pd.read_csv(training_log_path)

In [None]:
# Display the full training log (for long runs, `df_res.head()` would keep the output short).
df_res

In [None]:
line_width = 2.5
fig, ax = plt.subplots(figsize=(16,12))

# Split the log: validation rows hold the f1 score; training rows without a loss value are dropped.
df_val = df_res[df_res["part"].eq("val")]
df_train = df_res[df_res["part"].eq("train")].dropna(subset=["loss"])

# Left axis: validation f1 score per epoch.
lns1 = ax.plot(
    df_val["epoch"], df_val["f1_score"],
    marker=markers[0], linewidth=line_width, color=colors[0], markersize=marker_size, label="val-f1-score",
)
ax.set_xlabel("Epochs (n.u.)", fontsize=36)
ax.set_ylabel("f1 score (n.u.)", fontsize=36)
ax.set_ylim(-0.1,1)

# Right (twin) axis: training loss per epoch.
ax2 = ax.twinx()
lns2 = ax2.plot(
    df_train["epoch"], df_train["loss"],
    marker=markers[1], linewidth=line_width, color=colors[1], markersize=marker_size, label="train-loss",
)
ax2.set_ylabel("Loss (n.u.)", fontsize=36)
ax2.set_ylim(-0.03,0.3)
ax2.set_yticks(np.arange(0,0.35,0.06))

# A single legend covering the lines of both axes.
lns = [*lns1, *lns2]
labs = [line.get_label() for line in lns]
ax.legend(lns, labs, loc="lower right", fontsize=26)

# plt.savefig("./results/ludb-unet-score-loss.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# plt.savefig("./results/ludb-unet-score-loss.svg", dpi=1200, bbox_inches="tight", transparent=False);