In [None]:
import torch
# ^^^ pyforest auto-imports - don't write above this line
%load_ext autoreload
%autoreload 2
from scipy.io import loadmat, savemat

In [None]:
import wfdb
import pandas as pd

In [None]:
from utils.scoring_metrics import (
    RefInfo, load_ans,
    score, ue_calculate, ur_calculate,
    compute_challenge_metric, gen_endpoint_score_mask,
)
from utils.scoring_metrics_test import _load_af_episodes

# from database_reader.cpsc_databases import CPSC2021
# from data_reader import CINC2021Reader
from data_reader import CPSC2021Reader
from dataset import CPSC2021
from cfg import TrainCfg
from utils.aux_metrics import compute_main_task_metric, compute_rr_metric
from trainer import train, evaluate, _set_task

from utils.misc import list_sum

In [None]:
import os

# Root directory of the CPSC2021 database. Set the CPSC2021_DB_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
db_dir = os.environ.get("CPSC2021_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/CPSC2021/")

# data generator

In [None]:
from dataset import CPSC2021
from cfg import TrainCfg

In [None]:
# Build the two dataset splits for the "rr_lstm" task, training split first.
# NOTE(review): `training=True/False` presumably selects the train/val split.
ds_train, ds_val = (
    CPSC2021(TrainCfg, task="rr_lstm", training=is_training)
    for is_training in (True, False)
)

In [None]:
# Number of items in the training dataset.
len(ds_train)

In [None]:
# Scan every training segment and collect those whose (signal, label) shapes
# differ from ((2, 6000), (750, 1)); a running counter is printed in place.
# NOTE(review): confirm these expected shapes match what `ds_train[idx]`
# returns for the dataset's task (it is built with task="rr_lstm").
err_list = []
for idx, seg in enumerate(ds_train.segments):
    sig, lb = ds_train[idx]
    if (sig.shape, lb.shape) != ((2, 6000), (750, 1)):
        print(f"\nsegment {seg} has sig.shape = {sig.shape}, lb.shape = {lb.shape}\n")
        err_list.append(seg)
    print(f"{idx+1}/{len(ds_train)}", end="\r")

In [None]:
# Same shape check over the validation segments. Offenders are appended to the
# `err_list` created by the training-set check cell, which must run first.
# Segments already recorded are not appended again, so re-running this cell
# cannot create duplicates (a duplicate would later make the deletion loop try
# to remove the same file twice).
for idx, seg in enumerate(ds_val.segments):
    sig, lb = ds_val[idx]
    if sig.shape != (2, 6000) or lb.shape != (750, 1):
        print("\n"+f"segment {seg} has sig.shape = {sig.shape}, lb.shape = {lb.shape}"+"\n")
        if seg not in err_list:
            err_list.append(seg)
    print(f"{idx+1}/{len(ds_val)}", end="\r")

In [None]:
# Number of segments flagged by the shape checks above.
len(err_list)

In [None]:
# Shape of the stored "ecg" array of the most recently flagged segment.
# Guarded so the cell does not raise IndexError when nothing was flagged.
(
    loadmat(ds_train._get_seg_data_path(err_list[-1]))["ecg"].shape
    if err_list
    else "no segments flagged"
)

In [None]:
import os
from contextlib import suppress

# Delete the data and annotation files of every flagged segment.
# `os` is imported explicitly rather than relying on pyforest's lazy import.
# `dict.fromkeys` drops repeated segments while keeping order, and files that
# are already gone are skipped, so a partially completed run can be re-executed.
# NOTE(review): validation segments are resolved through `ds_train`'s path
# helpers as well — presumably both splits share one segment directory; confirm.
segs_to_remove = list(dict.fromkeys(err_list))
for idx, seg in enumerate(segs_to_remove):
    for path in (ds_train._get_seg_data_path(seg), ds_train._get_seg_ann_path(seg)):
        with suppress(FileNotFoundError):
            os.remove(path)
    print(f"{idx+1}/{len(segs_to_remove)}", end="\r")

# Plan

- R peak detection
- rr-lstm
- U-net
- sequence labelling

## R peak detection

In [None]:
from model import (
    ECG_SEQ_LAB_NET_CPSC2021,
    ECG_UNET_CPSC2021,
    ECG_SUBTRACT_UNET_CPSC2021,
    RR_LSTM_CPSC2021,
    _qrs_detection_post_process,
)
from trainer import train
from utils.misc import init_logger, dict_to_str

In [None]:
from cfg import ModelCfg, TrainCfg
from copy import deepcopy
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

In [None]:
from cfg import ModelCfg
# Build the QRS-detection model from a deep copy of its task config (so the
# shared ModelCfg is not mutated), selecting the "seq_lab" variant.
task = "qrs_detection"  # or "main"
model_cfg = deepcopy(ModelCfg[task])
model_cfg.model_name = "seq_lab"
model = ECG_SEQ_LAB_NET_CPSC2021(model_cfg)

In [None]:
# Wrap the model in DataParallel and move it to CUDA. The isinstance guard keeps
# the cell idempotent: re-running it would otherwise nest the wrapper
# (DP(DP(model))) and change the module hierarchy / state_dict key prefixes.
if not isinstance(model, DP):
    model = DP(model)
model.to(torch.device("cuda"))

In [None]:
# Private copy of the training config, specialised for QRS detection.
# `_set_task`'s return value is unused, so it presumably updates the config in place.
device = torch.device("cuda")
train_config = deepcopy(TrainCfg)
_set_task("qrs_detection", train_config)

In [None]:
# NOTE(review): this inspects the "main" task's `reduction` although this
# section configures "qrs_detection" — confirm which sub-config was intended.
train_config.main.reduction

In [None]:
# Create the training logger and record device, torch version and full config.
logger = init_logger(log_dir=train_config.log_dir, verbose=2)
stars = "*" * 20
logger.info(f"\n{stars}   Start Training   {stars}\n")
logger.info(f"Using device {device}")
logger.info(f"Using torch of version {torch.__version__}")
logger.info(f"with configuration\n{dict_to_str(train_config)}")

In [None]:
# Train the QRS-detection model. The return value is kept as `best_model`,
# consistent with the other training sections, instead of being discarded.
best_model = train(
    model=model,
    model_config=model_cfg,
    config=train_config,
    device=device,
    logger=logger,
    debug=train_config.debug,
)

## rr-lstm 

In [None]:
from model import (
    ECG_SEQ_LAB_NET_CPSC2021,
    ECG_UNET_CPSC2021,
    ECG_SUBTRACT_UNET_CPSC2021,
    RR_LSTM_CPSC2021,
    _qrs_detection_post_process,
)
from trainer import train, evaluate
from utils.misc import init_logger, dict_to_str

In [None]:
from cfg import ModelCfg, TrainCfg
from copy import deepcopy
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

In [None]:
# Build the RR_LSTM_CPSC2021 model from a deep copy of its task config
# (so the shared ModelCfg is not mutated).
task = "rr_lstm"  # or "main"
model_cfg = deepcopy(ModelCfg[task])
model = RR_LSTM_CPSC2021(model_cfg)

In [None]:
# Display the (copied) "rr_lstm" model config.
model_cfg

In [None]:
# Private copy of the training config, specialised for the "rr_lstm" task.
# `_set_task`'s return value is unused, so it presumably updates the config in place.
device = torch.device("cuda")
train_config = deepcopy(TrainCfg)
_set_task("rr_lstm", train_config)

In [None]:
# Single-device training: unlike the other sections, this model is not wrapped
# in DataParallel; it is only moved to `device`.
model.to(device)

In [None]:
# Create the training logger and record device, torch version and full config.
logger = init_logger(log_dir=train_config.log_dir, verbose=2)
stars = "*" * 20
logger.info(f"\n{stars}   Start Training   {stars}\n")
logger.info(f"Using device {device}")
logger.info(f"Using torch of version {torch.__version__}")
logger.info(f"with configuration\n{dict_to_str(train_config)}")

In [None]:
# Train the RR-LSTM model and keep the value returned by `train`.
# NOTE(review): debug is hard-coded to True here, whereas the QRS-detection
# section passes train_config.debug — confirm which is intended.
best_model = train(
    model=model,
    model_config=model_cfg,
    config=train_config,
    device=device,
    logger=logger,
    debug=True,
)

## main_task

In [None]:
from model import (
    ECG_SEQ_LAB_NET_CPSC2021,
    ECG_UNET_CPSC2021,
    ECG_SUBTRACT_UNET_CPSC2021,
    RR_LSTM_CPSC2021,
    _qrs_detection_post_process,
    _main_task_post_process
)
from trainer import train
from utils.misc import init_logger, dict_to_str

In [None]:
from cfg import ModelCfg, TrainCfg
from copy import deepcopy
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

In [None]:
from cfg import ModelCfg
task = "main"  # other tasks used in this notebook: "qrs_detection", "rr_lstm"
# Backbone for the main task; pick the name and the matching class is looked up
# below (replaces toggling commented-out lines).
main_model_name = "unet"  # or "seq_lab"
main_model_classes = {
    "seq_lab": ECG_SEQ_LAB_NET_CPSC2021,
    "unet": ECG_UNET_CPSC2021,
}
# Deep copy so the shared ModelCfg is not mutated.
model_cfg = deepcopy(ModelCfg[task])
model_cfg.model_name = main_model_name
model = main_model_classes[main_model_name](model_cfg)

In [None]:
# Display the (copied) "main" model config.
model_cfg

In [None]:
# Wrap the model in DataParallel and move it to CUDA. The isinstance guard keeps
# the cell idempotent: re-running it would otherwise nest the wrapper
# (DP(DP(model))) and change the module hierarchy / state_dict key prefixes.
if not isinstance(model, DP):
    model = DP(model)
model.to(torch.device("cuda"))

In [None]:
# Private copy of the training config, specialised for the "main" task.
# `_set_task`'s return value is unused, so it presumably updates the config in place.
device = torch.device("cuda")
train_config = deepcopy(TrainCfg)
_set_task("main", train_config)

In [None]:
# Override the main-task training config for the "unet" model: set the model
# name and reduction, and clear the cnn/rnn/attn component names to None.
for attr_name, attr_value in (
    ("model_name", "unet"),
    ("reduction", 1),
    ("cnn_name", None),
    ("rnn_name", None),
    ("attn_name", None),
):
    setattr(train_config.main, attr_name, attr_value)

In [None]:
# Create the training logger and record device, torch version and full config.
logger = init_logger(log_dir=train_config.log_dir, verbose=2)
stars = "*" * 20
logger.info(f"\n{stars}   Start Training   {stars}\n")
logger.info(f"Using device {device}")
logger.info(f"Using torch of version {torch.__version__}")
logger.info(f"with configuration\n{dict_to_str(train_config)}")

In [None]:
# Train the main-task model and keep the value returned by `train`.
# NOTE(review): debug is hard-coded to True here, whereas the QRS-detection
# section passes train_config.debug — confirm which is intended.
best_model = train(
    model=model,
    model_config=model_cfg,
    config=train_config,
    device=device,
    logger=logger,
    debug=True,
)

## Misc 

In [None]:
from entry_2021 import *
from test_entry import run_test

In [None]:
# Relative path of the sample record passed to `challenge_entry` below.
sample_path = "./working_dir/sample_data/data_98_1"

In [None]:
# Run the challenge entry point on the sample record.
# NOTE(review): `challenge_entry` presumably comes from the star import of entry_2021.
out = challenge_entry(sample_path)

In [None]:
# Display the challenge entry output.
out

In [None]:
# Type of the first element of the first entry in out['predict_endpoints'].
type(out['predict_endpoints'][0][0])

In [None]:
# Shape of the same sample record ("data_98_1") as loaded by the dataset's reader.
ds_val.reader.load_data("data_98_1").shape

In [None]:
# AF episodes of "data_98_1" as returned by the reader's `load_af_episodes`,
# for comparison with the predicted endpoints above.
ds_val.reader.load_af_episodes("data_98_1")