In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# ^^^ pyforest auto-imports - don't write above this line
# Reload imported modules (including the local dataset/cfg/trainer/model/metrics
# modules imported below) before each cell runs, so edits to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from copy import deepcopy

import torch
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP
from tqdm.auto import tqdm

from dataset import LUDB
from cfg import TrainCfg, ModelCfg
from trainer import LUDBTrainer
from model import ECG_UNET_LUDB
from metrics import compute_metrics

In [None]:
# Root directory of the LUDB data. Set the LUDB_DB_DIR environment variable to
# use another location; otherwise the original author's local path is used.
# Note: this mutates the imported TrainCfg object, so every deepcopy made below
# inherits this db_dir.
TrainCfg.db_dir = os.environ.get("LUDB_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/LUDB/")

In [None]:
def _make_train_cfg(loss):
    """Return an independent copy of TrainCfg with use_single_lead disabled and the given loss name."""
    cfg_copy = deepcopy(TrainCfg)
    cfg_copy.use_single_lead = False
    cfg_copy.loss = loss
    return cfg_copy


# One training configuration per loss function being compared.
train_cfg_fl = _make_train_cfg("FocalLoss")
train_cfg_ce = _make_train_cfg("CrossEntropyLoss")

In [None]:
# Training splits, one per loss configuration (focal loss first, then cross-entropy).
# NOTE(review): the *_ce datasets are not used by any later cell shown here.
ds_train_fl, ds_train_ce = [
    LUDB(split_cfg, training=True, lazy=False)
    for split_cfg in (train_cfg_fl, train_cfg_ce)
]

In [None]:
# NOTE(review): commented-out cell. The training datasets are built with
# lazy=False above, which presumably loads all records already — confirm in
# dataset.LUDB, then either delete this cell or re-enable it.
# ds_train_fl._load_all_data()
# ds_train_ce._load_all_data()

In [None]:
# Validation splits, one per loss configuration (focal loss first, then cross-entropy).
# NOTE(review): only ds_val_fl is used by the later cells shown here.
ds_val_fl, ds_val_ce = [
    LUDB(split_cfg, training=False, lazy=False)
    for split_cfg in (train_cfg_fl, train_cfg_ce)
]

In [None]:
# NOTE(review): commented-out cell. The validation datasets are built with
# lazy=False above, which presumably loads all records already — confirm in
# dataset.LUDB, then either delete this cell or re-enable it.
# ds_val_fl._load_all_data()
# ds_val_ce._load_all_data()

## dry run: no augmentation, no preprocessing

In [None]:
# Independent copy, so any change to model_config does not affect the imported ModelCfg.
model_config = deepcopy(ModelCfg)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `torch` must be imported explicitly; relying on pyforest's lazy import
# breaks Restart & Run All.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the segmentation model; the first positional argument is the
# configured number of leads, the second the full model configuration.
model = ECG_UNET_LUDB(
    model_config.n_leads,
    model_config,
)

In [None]:
# Display the model's `module_size_` attribute (defined in model.ECG_UNET_LUDB;
# presumably a parameter count — confirm).
model.module_size_

In [None]:
# Spread batches over all visible GPUs with DataParallel. The isinstance guard
# keeps the cell idempotent: re-running it must not nest DP(DP(model)).
# (DistributedDataParallel would additionally require an initialised
# torch.distributed process group, so it is not used here.)
if torch.cuda.device_count() > 1 and not isinstance(model, DP):
    model = DP(model)
model.to(device=device)

In [None]:
# Trainer for the focal-loss configuration. It is created with lazy=True, and the
# dataloaders are then set up explicitly in the next cell.
trainer = LUDBTrainer(
    model=model, model_config=model_config, train_config=train_cfg_fl,
    device=device, lazy=True,
)

In [None]:
# Attach the focal-loss train/validation datasets to the trainer (it was created
# with lazy=True, presumably deferring this step — confirm in trainer.LUDBTrainer).
# NOTE(review): this calls a private method of LUDBTrainer.
trainer._setup_dataloaders(ds_train_fl, ds_val_fl)

In [None]:
# Run training. The return value is kept in `bmd` (presumably the best model's
# state dict — confirm in LUDBTrainer.train).
bmd = trainer.train()

## Load the best checkpoint and plot a prediction

In [None]:
# Best checkpoint written by the training run above (filename records epoch and metric).
BEST_CKPT_PATH = "./checkpoints/BestModel_ECG_UNET_LUDB_epoch111_12-16_22-23_metric_0.97.pth.tar"

# Replace the in-memory `model` with the one restored from disk; the second
# value returned by from_checkpoint is discarded.
model, _ = ECG_UNET_LUDB.from_checkpoint(BEST_CKPT_PATH)

In [None]:
# Display the architecture of the model loaded from the checkpoint.
model

In [None]:
# Run inference on validation record 0 and keep only the second returned value.
# NOTE(review): binding `_` shadows IPython's last-output variable `_`.
_, mask = model.inference(ds_val_fl.signals[0])

In [None]:
# Inspect the mask returned by ECG_UNET_LUDB.inference (its type/shape is defined there).
mask

In [None]:
# Overlay signals[0][0] of validation record 0 (black, left axis; presumably the
# first lead) with mask[0] of the prediction (red, right axis).
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(ds_val_fl.signals[0][0], color="black")
# plot(y) uses the array index as x.
ax.set_xlabel("Sample index")
ax.set_ylabel("Signal (signals[0][0])")
ax2 = ax.twinx()
ax2.plot(mask[0], color="red")
ax2.set_ylabel("Predicted mask (mask[0])")
plt.show()

## Training statistics: validation f1 score and training loss per epoch

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm

# Styling shared by the figures below.
sns.set()
# Read after sns.set() so the seaborn palette's colour cycle is used.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Raw string: "\h" is not a valid escape sequence, so a plain literal emits a
# DeprecationWarning (SyntaxWarning since Python 3.12). The value is unchanged.
markers = ["*", "v", "p", "d", "s", r"$\heartsuit$", "+", "x", ]
marker_size = 9
# Font sizes are applied after sns.set() so they are not overridden by it.
plt.rcParams.update({
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'axes.labelsize': 32,
    'legend.fontsize': 20,
})

In [None]:
# Per-epoch results log of the training run (filename encodes model, optimizer, LR and batch size).
RESULTS_CSV = "./results/TorchECG_12-16_21-52_ECG_UNET_LUDB_adamw_amsgrad_LR_0.0001_BS_32.csv"
df_res = pd.read_csv(RESULTS_CSV)

In [None]:
# Display the results log; the figure cell below uses its `part`, `epoch`,
# `f1_score` and `loss` columns.
df_res

In [None]:
# Validation f1 score (left axis) vs. training loss (right axis) per epoch,
# exported as PDF and SVG.
fig, ax = plt.subplots(figsize=(16, 12))

line_width = 2.5

df_val = df_res[df_res.part == "val"]
# Train rows without a loss value are dropped; matplotlib would otherwise
# break the loss curve at every NaN.
df_train = df_res[df_res.part == "train"].dropna(subset=["loss"])

lns1 = ax.plot(
    df_val.epoch, df_val.f1_score,
    marker=markers[0], linewidth=line_width, color=colors[0], markersize=marker_size, label="f1 score",
)
ax.set_xlabel("Epochs (n.u.)")
ax.set_ylabel("f1 score (n.u.)")
ax.set_ylim(-0.1, 1)

ax2 = ax.twinx()
lns2 = ax2.plot(
    df_train.epoch, df_train.loss,
    marker=markers[1], linewidth=line_width, color=colors[1], markersize=marker_size, label="Loss",
)
ax2.set_ylabel("Loss (n.u.)")
ax2.set_ylim(-0.03, 0.3)
ax2.set_yticks(np.arange(0, 0.35, 0.06))

# One legend covering the curves of both y-axes.
lns = lns1 + lns2
labs = [line.get_label() for line in lns]
ax.legend(lns, labs, loc="lower right", fontsize=28)

# Save through the explicit figure handle rather than pyplot's implicit "current figure".
for ext in ("pdf", "svg"):
    fig.savefig(
        f"./results/ludb-unet-score-loss.{ext}",
        dpi=1200, bbox_inches="tight", transparent=False,
    )