In [None]:
import sys
# ^^^ pyforest auto-imports - don't write above this line
# Fall back to the author's local source checkouts when the packages are not installed.
try:
    import bib_lookup
except ModuleNotFoundError:
    # Only the search path is extended here; bib_lookup is not re-imported in this cell.
    # NOTE(review): presumably needed so that torch_ecg can import it later — confirm.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/")
try:
    from torch_ecg.utils.misc import MovingAverage, list_sum
except ModuleNotFoundError:
    # Retry the import once the local torch_ecg checkout is on the search path.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/")
    from torch_ecg.utils.misc import MovingAverage, list_sum

%load_ext autoreload
%autoreload 2

In [None]:
from tqdm.auto import tqdm
from copy import deepcopy

from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

from dataset import LUDB
from cfg import TrainCfg, ModelCfg
from trainer import LUDBTrainer
from model import ECG_UNET_LUDB
from metrics import compute_metrics

In [None]:
import os

# Root directory of the LUDB database. The default is the author's machine-specific path;
# set the LUDB_DB_DIR environment variable to run the notebook elsewhere.
TrainCfg.db_dir = os.environ.get("LUDB_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/LUDB/")

In [None]:
def _make_train_cfg(loss_name):
    """Return an independent copy of ``TrainCfg`` using all leads and the given loss."""
    cfg = deepcopy(TrainCfg)  # deep copy so the shared TrainCfg is left untouched
    cfg.use_single_lead = False
    cfg.loss = loss_name
    return cfg


# One training configuration per loss function being compared.
train_cfg_fl = _make_train_cfg("FocalLoss")
train_cfg_ce = _make_train_cfg("CrossEntropyLoss")

In [None]:
# Eagerly loaded (lazy=False) training splits, one per loss configuration.
ds_train_fl, ds_train_ce = (
    LUDB(cfg, training=True, lazy=False) for cfg in (train_cfg_fl, train_cfg_ce)
)

In [None]:
# ds_train_fl._load_all_data()
# ds_train_ce._load_all_data()

In [None]:
# Eagerly loaded (lazy=False) validation splits, one per loss configuration.
ds_val_fl, ds_val_ce = (
    LUDB(cfg, training=False, lazy=False) for cfg in (train_cfg_fl, train_cfg_ce)
)

In [None]:
# ds_val_fl._load_all_data()
# ds_val_ce._load_all_data()

## dry run: no augmentation, no preprocessing

In [None]:
# Private copy of the model config so tweaks made in this notebook do not touch ModelCfg.
model_config = deepcopy(ModelCfg)

# Run on the GPU whenever torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the U-Net from the model config (first argument presumably the number of input leads).
model = ECG_UNET_LUDB(model_config.n_leads, model_config)

In [None]:
# Inspect the size reported by the model's `module_size_` attribute.
model.module_size_

In [None]:
# Wrap in DataParallel only when more than one GPU is visible (DDP is imported but unused).
model = DP(model) if torch.cuda.device_count() > 1 else model
model.to(device=device)

In [None]:
# Trainer for the focal-loss configuration.
# NOTE(review): lazy=True presumably defers dataloader construction, which would explain
# the explicit `_setup_dataloaders` call in the following cell — confirm.
trainer = LUDBTrainer(
    model=model, model_config=model_config,
    train_config=train_cfg_fl, device=device, lazy=True,
)

In [None]:
# Attach the prebuilt focal-loss train/val datasets to the trainer via its private helper
# (presumably required because the trainer was created with lazy=True — confirm).
trainer._setup_dataloaders(ds_train_fl, ds_val_fl)

In [None]:
# Run training. NOTE(review): `bmd` is presumably the best model's state dict — confirm.
bmd = trainer.train()

## eval and plot

In [None]:
# Reload the best checkpoint saved by training; the second return value is not needed here.
checkpoint_path = (
    "/home/wenhao/.cache/torch_ecg/saved_models/"
    "BestModel_ECG_UNET_LUDB_epoch100_03-25_23-42_metric_0.99.pth.tar"
)
model, _ = ECG_UNET_LUDB.from_checkpoint(checkpoint_path)

In [None]:
# Display the architecture of the model reloaded from the checkpoint.
model

In [None]:
# Run inference on the first validation record.
# The model rebuilt by `ECG_UNET_LUDB.from_checkpoint` is a bare model without a `.module`
# attribute; only a DataParallel/DDP wrapper (from the training cells) needs unwrapping.
model_output = (
    model.module if isinstance(model, (DP, DDP)) else model
).inference(ds_val_fl.signals[0])

In [None]:
# Inspect the predicted mask returned by `inference`.
model_output.mask

In [None]:
# Overlay the validation signal (black, left axis) with the predicted mask (red, right axis).
# NOTE(review): assumes signals[0] is (leads, samples) and mask[0] is the first lead's
# label sequence, so index 0 selects the first lead of each — confirm.
signal_0 = ds_val_fl.signals[0]
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(signal_0[0], color="black")
ax2 = ax.twinx()
ax2.plot(model_output.mask[0], color="red")
plt.show()

## gather results

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm
import matplotlib.patches as patches

# Global plotting style shared by the result figures below.
sns.set()

# Large tick/label/legend fonts for the exported figures.
plt.rcParams['xtick.labelsize']=28
plt.rcParams['ytick.labelsize']=28
plt.rcParams['axes.labelsize']=40
plt.rcParams['legend.fontsize']=24

# Colour cycle of the active (seaborn) style; curves index into it by position.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# One marker per curve. The last entry is a raw string so the mathtext backslash is not
# read as the invalid escape "\h" (DeprecationWarning, SyntaxWarning on Python >= 3.12);
# the resulting string value is unchanged.
markers = ["p", "v", "s", "d", "x", "*", "+", r"$\heartsuit$"]
marker_size = 12

In [None]:
# Training log of one run (presumably written by the trainer; the path is machine-specific).
# Columns used below: part, step, epoch, lr, loss, f1_score, standard_deviation, mean_error.
df_res = pd.read_csv("/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_unet_ludb/results/TorchECG_04-06_22-30_ECG_UNET_LUDB_adamw_amsgrad_LR_0.001_BS_32.csv")

In [None]:
# Learning-rate trace from the training rows only; rows without an lr value are dropped.
# NOTE(review): `step` is divided by 1060/210, presumably to rescale optimizer steps onto
# the epoch axis of the figures below — confirm the origin of these constants.
df_lr = (
    df_res.loc[df_res.part == "train", ["step", "epoch", "lr"]]
    .dropna(subset=["lr"])
    .assign(step=lambda frame: frame["step"] / (1060 / 210))
)

In [None]:
# Figure: validation f1 score (left axis) and training loss (right axis) per epoch, with the
# learning-rate schedule overlaid as a dotted line rescaled onto the left axis.
import os

fig, ax = plt.subplots(figsize=(20,12))

line_width = 4

# plot every `spacing`-th point only, to keep the markers readable
spacing = 4

df_val = df_res[df_res.part=="val"]
df_train = df_res[df_res.part=="train"].dropna(subset=["loss"])
lns1 = ax.plot(
    df_val.epoch[::spacing], df_val.f1_score[::spacing],
    marker=markers[0], linewidth=line_width, color=colors[0], markersize=marker_size, label="val-f1-score",
)
ax.set_xlabel("Epochs (n.u.)", fontsize=36)
ax.set_ylabel("f1 score (n.u.)", fontsize=36)
ax.set_ylim(-0.1,1.1)
ax2 = ax.twinx()
lns2 = ax2.plot(
    df_train.epoch[::spacing], df_train.loss[::spacing],
    marker=markers[1], linewidth=line_width, color=colors[1], markersize=marker_size, label="train-loss",
)
ax2.set_ylabel("Loss (n.u.)", fontsize=36)
ax2.set_ylim(-0.03,0.33)
ax2.set_yticks(np.arange(0,0.35,0.06))

# lr normalised so its maximum lands at 1/0.96 (~1.04) on the f1 axis, inside the ylim
lr_line = ax.plot(
    df_lr.step.values[::spacing], (df_lr.lr.values/df_lr.lr.max()/0.96)[::spacing],
    linestyle=":", linewidth=6, color=colors[2],
)

# Both annotations come from the data, so they stay correct if another results CSV is loaded.
ax.text(110,1.03, f"max lr = {df_lr.lr.max():.3g}", fontsize=30)
ax.text(0,-0.05, f"start lr = {df_lr.lr.values[0]:.5f}", fontsize=30)

ax2.legend(lr_line, ["learning rate",], loc="upper left", fontsize=30);

lns = lns1+lns2
labs = [line.get_label() for line in lns]
ax.legend(lns, labs, loc="lower right", fontsize=30, bbox_to_anchor=(0.95, 0.21))

# shaded highlight over epochs 25-60
rect = patches.Rectangle((25, 0.15), 35, 0.8, linewidth=3, linestyle="dotted", edgecolor='b', facecolor='r', alpha=0.3)

# Add the patch to the Axes
ax.add_patch(rect);

# savefig does not create missing directories
os.makedirs("./results", exist_ok=True)
fig.savefig("./results/ludb-unet-score-loss.pdf", dpi=1200, bbox_inches="tight", transparent=False);
fig.savefig("./results/ludb-unet-score-loss.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Figure: validation standard deviation (left axis) and mean error (right axis) per epoch,
# with the learning-rate schedule overlaid as a dotted line rescaled onto the left axis.
import os

fig, ax = plt.subplots(figsize=(20,12))

line_width = 4

# plot every `spacing`-th point only, to keep the markers readable
spacing = 2

# re-derived here so this cell does not depend on state left by the previous figure cell
df_val = df_res[df_res.part=="val"]

lns1 = ax.plot(
    df_val.epoch[::spacing], df_val.standard_deviation[::spacing],
    marker=markers[0], linewidth=line_width, color=colors[0], markersize=marker_size, label="val-std",
)
ax.set_xlabel("Epochs (n.u.)", fontsize=36)
ax.set_ylabel("Standard Deviation (ms)", fontsize=36)
ax.set_ylim(10,60)
ax.set_yticks(np.arange(15,65,10))
ax2 = ax.twinx()
lns2 = ax2.plot(
    df_val.epoch[::spacing], df_val.mean_error[::spacing],
    marker=markers[1], linewidth=line_width, color=colors[1], markersize=marker_size, label="val-mean-error",
)
ax2.set_ylabel("Mean Error (ms)", fontsize=36)
ax2.set_ylim(-12.5,12.5)
ax2.set_yticks(np.arange(-10,15,5))

# lr normalised to [0, 1], scaled by 1/0.025 (=40) and shifted by 15, so it spans 15..55 on the std axis
lr_line = ax.plot(
    df_lr.step.values[::spacing], (df_lr.lr.values/df_lr.lr.max()/0.025 + 15)[::spacing],
    linestyle=":", linewidth=6, color=colors[2],
)

# Both annotations come from the data, so they stay correct if another results CSV is loaded.
ax.text(90,56, f"max lr = {df_lr.lr.max():.3g}", fontsize=30)
ax.text(0,12, f"start lr = {df_lr.lr.values[0]:.5f}", fontsize=30)

ax2.legend(lr_line, ["learning rate",], loc="upper left", fontsize=30);

lns = lns1+lns2
labs = [line.get_label() for line in lns]
ax.legend(lns, labs, loc="upper right", fontsize=30)

# shaded highlight over epochs 25-60
rect = patches.Rectangle((25, 14), 35, 36, linewidth=3, linestyle="dotted", edgecolor='b', facecolor='r', alpha=0.3)

# Add the patch to the Axes
ax.add_patch(rect);

# savefig does not create missing directories
os.makedirs("./results", exist_ok=True)
fig.savefig("./results/ludb-unet-me-std.pdf", dpi=1200, bbox_inches="tight", transparent=False);
fig.savefig("./results/ludb-unet-me-std.svg", dpi=1200, bbox_inches="tight", transparent=False);