In [None]:
import pandas as pd
import matplotlib.pyplot as plt
# ^^^ pyforest auto-imports - don't write above this line
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from tqdm.auto import tqdm
from copy import deepcopy
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

%load_ext autoreload
%autoreload 2

In [None]:
from cfg import TrainCfg, ModelCfg
from trainer import CPSC2019Trainer, _MODEL_MAP
from model import ECG_SEQ_LAB_NET_CPSC2019, ECG_UNET_CPSC2019, ECG_SUBTRACT_UNET_CPSC2019
from dataset import CPSC2019

In [None]:
import os

# Root directory of the CPSC2019 data. Set the CPSC2019_DB_DIR environment variable on
# other machines; otherwise the original author's path is used.
# Later cells deepcopy TrainCfg, so every run inherits this setting.
TrainCfg.db_dir = os.environ.get("CPSC2019_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/CPSC2019/")

In [None]:
# Build the training split and then the validation split from the shared TrainCfg.
# lazy=False presumably loads the records up front — confirm in dataset.py.
ds_train, ds_val = (
    CPSC2019(TrainCfg, training=is_train, lazy=False) for is_train in (True, False)
)

## Train CNN

In [None]:
# Work on copies so that selecting the model below leaves TrainCfg / ModelCfg untouched.
train_config = deepcopy(TrainCfg)
train_config.model_name = "seq_lab_cnn"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Default CUDA device when available, otherwise the CPU
# (to pin a specific card use e.g. torch.device("cuda:7")).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Instantiate the "seq_lab_cnn" variant from its copied model config.
model = ECG_SEQ_LAB_NET_CPSC2019(config=model_config, n_leads=train_config.n_leads)

In [None]:
# Size summary reported by the model's `module_size_` attribute.
cnn_module_size = model.module_size_
cnn_module_size

In [None]:
# To spread training over several GPUs, wrap the model first:
# if torch.cuda.device_count() > 1:
#     model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression. A displayed value
# is cached by IPython in `Out`/`_`, which would keep this model (and its GPU memory) alive
# after `del model` at the end of this section.
model = model.to(device=device)

In [None]:
# Bundle model, configs and device into a CPSC2019 trainer.
# NOTE(review): lazy=True presumably defers the trainer's own data setup, since the
# dataloaders are attached explicitly in the next cell — confirm in trainer.py.
trainer = CPSC2019Trainer(
    model=model, device=device,
    model_config=model_config, train_config=train_config,
    lazy=True,
)

In [None]:
# Attach the datasets built near the top of the notebook (train first, then val).
train_and_val = (ds_train, ds_val)
trainer._setup_dataloaders(*train_and_val)

In [None]:
# Run training for the CNN model. Do not bind the trainer (or its bound `train` method) to any
# extra name here, or the `del` in the next cell will not release it.
# NOTE(review): `bmd` presumably holds the best model's state — confirm in trainer.py.
bmd = trainer.train()

In [None]:
# Drop this run's references before the next architecture is trained.
del bmd
del trainer
del model

## Train CRNN

In [None]:
# Fresh copies for this run; selecting the model below must not modify TrainCfg / ModelCfg.
train_config = deepcopy(TrainCfg)
train_config.model_name = "seq_lab_crnn"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Default CUDA device when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Inspect the CRNN model config (a deep copy of ModelCfg["seq_lab_crnn"]) before building the model.
model_config

In [None]:
# Same network class as the CNN run; the "seq_lab_crnn" config selects the CRNN variant.
model = ECG_SEQ_LAB_NET_CPSC2019(config=model_config, n_leads=train_config.n_leads)

In [None]:
# Size summary reported by the model's `module_size_` attribute.
crnn_module_size = model.module_size_
crnn_module_size

In [None]:
# To spread training over several GPUs, wrap the model first:
# if torch.cuda.device_count() > 1:
#     model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression. A displayed value
# is cached by IPython in `Out`/`_`, which would keep this model (and its GPU memory) alive
# after `del model` at the end of this section.
model = model.to(device=device)

In [None]:
# Bundle model, configs and device into a CPSC2019 trainer.
# NOTE(review): lazy=True presumably defers the trainer's own data setup, since the
# dataloaders are attached explicitly in the next cell — confirm in trainer.py.
trainer = CPSC2019Trainer(
    model=model, device=device,
    model_config=model_config, train_config=train_config,
    lazy=True,
)

In [None]:
# Attach the datasets built near the top of the notebook (train first, then val).
train_and_val = (ds_train, ds_val)
trainer._setup_dataloaders(*train_and_val)

In [None]:
# Run training for the CRNN model. Do not bind the trainer (or its bound `train` method) to any
# extra name here, or the `del` in the next cell will not release it.
# NOTE(review): `bmd` presumably holds the best model's state — confirm in trainer.py.
bmd = trainer.train()

In [None]:
# Drop this run's references before the next architecture is trained.
del bmd
del trainer
del model

## Train U-Net

In [None]:
# Fresh copies for this run; selecting the model below must not modify TrainCfg / ModelCfg.
train_config = deepcopy(TrainCfg)
train_config.model_name = "unet"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Default CUDA device when available, otherwise the CPU
# (to pin a specific card use e.g. torch.device("cuda:7")).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Look up the model class for "unet" in the trainer module's name -> class map, then build it.
unet_cls = _MODEL_MAP[train_config.model_name]
model = unet_cls(n_leads=train_config.n_leads, config=model_config)

In [None]:
# Size summary reported by the model's `module_size_` attribute.
unet_module_size = model.module_size_
unet_module_size

In [None]:
# To spread training over several GPUs, wrap the model first:
# if torch.cuda.device_count() > 1:
#     model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression. A displayed value
# is cached by IPython in `Out`/`_`, which would keep this model (and its GPU memory) alive
# after `del model` at the end of this section.
model = model.to(device=device)

In [None]:
# Bundle model, configs and device into a CPSC2019 trainer.
# NOTE(review): lazy=True presumably defers the trainer's own data setup, since the
# dataloaders are attached explicitly in the next cell — confirm in trainer.py.
trainer = CPSC2019Trainer(
    model=model, device=device,
    model_config=model_config, train_config=train_config,
    lazy=True,
)

In [None]:
# Attach the datasets built near the top of the notebook (train first, then val).
train_and_val = (ds_train, ds_val)
trainer._setup_dataloaders(*train_and_val)

In [None]:
# Run training for the U-Net model. Do not bind the trainer (or its bound `train` method) to any
# extra name here, or the `del` in the next cell will not release it.
# NOTE(review): `bmd` presumably holds the best model's state — confirm in trainer.py.
bmd = trainer.train()

In [None]:
# Drop this run's references now that the last architecture has been trained.
del bmd
del trainer
del model

# Clear GPU memory

In [None]:
import gc

# `del` only unbinds the notebook names. First collect reference cycles so that tensors they
# keep alive are freed. Then release the unoccupied memory still cached by PyTorch's CUDA
# allocator. No autograd work happens here, so `torch.no_grad()` is unnecessary.
gc.collect()
torch.cuda.empty_cache()

# Gather results

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm
sns.set()
# Read the colour cycle after sns.set() so it reflects seaborn's palette.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Raw string: "\h" is not a valid escape sequence, so a plain literal triggers a
# DeprecationWarning (SyntaxWarning since Python 3.12). The resulting value is unchanged.
markers = ["+", "v", "x", "*", "p", "d", "s", r"$\heartsuit$"]
marker_size = 9
# Large fonts for the publication figure below.
plt.rcParams['xtick.labelsize']=28
plt.rcParams['ytick.labelsize']=28
plt.rcParams['axes.labelsize']=40
plt.rcParams['legend.fontsize']=24

In [None]:
from torch_ecg.utils.misc import MovingAverage

In [None]:
def ma(x):
    """Identity "smoother" applied to every curve before plotting.

    To smooth the curves instead, rebind ``ma = MovingAverage()``
    (``MovingAverage`` from ``torch_ecg.utils.misc``).
    """
    return x

In [None]:
# CSV logs of the three runs (presumably written by the trainers above). The file names
# appear to embed the run date/time, so update them after re-training.
_log_files = {
    "cnn": "./results/TorchECG_12-22_16-52_ECG_SEQ_LAB_NET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_multi_scopic.csv",
    "crnn": "./results/TorchECG_12-22_17-13_ECG_SEQ_LAB_NET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_multi_scopic.csv",
    "unet": "./results/TorchECG_12-22_17-48_ECG_UNET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_none.csv",
}
df_cnn, df_crnn, df_unet = (pd.read_csv(_log_files[key]) for key in ("cnn", "crnn", "unet"))

In [None]:
def _select_part(df, part):
    """Return the rows of ``df`` whose ``part`` column equals ``part`` and whose ``qrs_score`` is present."""
    rows = df.loc[df["part"] == part]
    return rows.dropna(subset=["qrs_score"])


df_cnn_train = _select_part(df_cnn, "train")
df_crnn_train = _select_part(df_crnn, "train")
df_unet_train = _select_part(df_unet, "train")
df_cnn_val = _select_part(df_cnn, "val")
df_crnn_val = _select_part(df_crnn, "val")
df_unet_val = _select_part(df_unet, "val")

In [None]:
fig, ax = plt.subplots(figsize=(16,12))

# (label prefix, train rows, val rows, style index). Each model keeps one colour/marker.
curves = [
    ("crnn", df_crnn_train, df_crnn_val, 0),
    ("cnn", df_cnn_train, df_cnn_val, 1),
    ("unet", df_unet_train, df_unet_val, 2),
]

# Plot all train curves (solid) first, then all val curves (dashed). With ncol=2 the legend
# then lists the train curves in its first column and the val curves in its second.
for part, ls in (("train", "-"), ("val", "--")):
    for name, df_train, df_val, idx in curves:
        df_part = df_train if part == "train" else df_val
        # Plot each split against its OWN epochs. Pairing val scores with the train epochs
        # is wrong when the two logs hold different epochs, and raises when their lengths differ.
        ax.plot(
            df_part.epoch.values, ma(df_part.qrs_score.values),
            marker=markers[idx], markersize=marker_size, linewidth=2,
            color=colors[idx], ls=ls, label=f"{name}-{part}",
        )
ax.set_ylim(0.6,1.05)
ax.legend(loc="best", ncol=2)
ax.set_xlabel("Epochs (n.u.)", fontsize=36)
ax.set_ylabel("QRS score (n.u.)", fontsize=36)

# Save the same figure as both SVG and PDF.
for ext in ("svg", "pdf"):
    fig.savefig(
        f"./results/cpsc2019_nn_compare.{ext}",
        dpi=1200, bbox_inches="tight", transparent=False,
    )