In [None]:
import torch
# ^^^ pyforest auto-imports - don't write above this line
import gc
import os
import sys
import warnings
from copy import deepcopy
warnings.simplefilter(action='ignore', category=FutureWarning)
from tqdm.auto import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP

# Re-import edited modules automatically before every cell runs. This way, changes to
# the project modules imported in the next cell (presumably the local cfg / trainer /
# model / dataset files) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from cfg import TrainCfg, ModelCfg
from trainer import CPSC2019Trainer, _MODEL_MAP
from model import ECG_SEQ_LAB_NET_CPSC2019, ECG_UNET_CPSC2019, ECG_SUBTRACT_UNET_CPSC2019
from dataset import CPSC2019

In [None]:
# Location of the CPSC2019 database on disk. The default is a machine-specific path.
# Point the notebook elsewhere by setting the CPSC2019_DB_DIR environment variable
# rather than editing this cell.
TrainCfg.db_dir = os.environ.get(
    "CPSC2019_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/CPSC2019/"
)

In [None]:
# Build the training and validation splits once, up front. Every model trained below
# reuses these two objects (`lazy=False` presumably reads the records eagerly).
ds_train, ds_val = (
    CPSC2019(TrainCfg, training=True, lazy=False),
    CPSC2019(TrainCfg, training=False, lazy=False),
)

## Train CNN

In [None]:
# Per-run copies of the global configs. Setting `model_name` here therefore does not
# mutate `TrainCfg`, and the model config is not shared with other runs.
train_config = deepcopy(TrainCfg)
train_config.model_name = "seq_lab_cnn"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Prefer the GPU when one is visible. To pin one particular card instead, use
# something like torch.device("cuda:7").
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# NOTE(review): the CNN and CRNN runs both instantiate ECG_SEQ_LAB_NET_CPSC2019 directly,
# whereas the U-Net run looks its class up in `_MODEL_MAP`. Presumably `model_config`
# selects CNN vs CRNN here; confirm.
model = ECG_SEQ_LAB_NET_CPSC2019(n_leads=train_config.n_leads, config=model_config)

In [None]:
# Display the size reported by the model's `module_size_` attribute (defined in the
# project's model code; its units are not visible here). This is read before the
# DataParallel wrap in the next cell, which presumably would not expose it.
model.module_size_

In [None]:
# With more than one visible GPU, let DataParallel split each batch across them.
if torch.cuda.device_count() > 1:
    model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression.
# `Module.to` returns the module itself, so displaying it would store the model in
# Jupyter's output history (`_`, `Out[...]`). That hidden reference survives the later
# `del ... model` and keeps the model (and its device memory) alive. It would also
# dump the full module repr into the notebook.
model = model.to(device=device)

In [None]:
# Hand the (possibly DataParallel-wrapped) model and both configs to the project trainer.
# With lazy=True the dataloaders are presumably not built yet; the next cell attaches
# the datasets.
trainer = CPSC2019Trainer(
    device=device,
    lazy=True,
    model=model,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Give the trainer the training and validation datasets loaded at the top of the notebook.
# NOTE(review): this calls a private (underscore-prefixed) CPSC2019Trainer method.
# Presumably it builds the dataloaders that `lazy=True` deferred; confirm.
trainer._setup_dataloaders(ds_train, ds_val)

In [None]:
# Run the full training loop. The result is assigned rather than displayed.
# NOTE(review): `bmd` is presumably the best model's state dict ("best model dict");
# confirm against CPSC2019Trainer.train.
bmd = trainer.train()

In [None]:
# Release this run's objects before the next model is built. `del` only unbinds the
# names: run the cyclic garbage collector so objects caught in reference cycles are
# actually freed. Then return the CUDA caching allocator's unused blocks to the driver.
# Any other live reference (e.g. a displayed cell output) still keeps its object alive.
del bmd, trainer, model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Train CRNN

In [None]:
# Per-run copies of the global configs. Setting `model_name` here therefore does not
# mutate `TrainCfg`, and the model config is not shared with other runs.
train_config = deepcopy(TrainCfg)
train_config.model_name = "seq_lab_crnn"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Same class as the CNN run. NOTE(review): presumably the "seq_lab_crnn" `model_config`
# is what adds the recurrent part; confirm.
model = ECG_SEQ_LAB_NET_CPSC2019(n_leads=train_config.n_leads, config=model_config)

In [None]:
# Display the size reported by the model's `module_size_` attribute (defined in the
# project's model code; its units are not visible here). This is read before the
# DataParallel wrap in the next cell, which presumably would not expose it.
model.module_size_

In [None]:
# With more than one visible GPU, let DataParallel split each batch across them.
if torch.cuda.device_count() > 1:
    model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression.
# `Module.to` returns the module itself, so displaying it would store the model in
# Jupyter's output history (`_`, `Out[...]`). That hidden reference survives the later
# `del ... model` and keeps the model (and its device memory) alive. It would also
# dump the full module repr into the notebook.
model = model.to(device=device)

In [None]:
# Hand the (possibly DataParallel-wrapped) model and both configs to the project trainer.
# With lazy=True the dataloaders are presumably not built yet; the next cell attaches
# the datasets.
trainer = CPSC2019Trainer(
    device=device,
    lazy=True,
    model=model,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Give the trainer the training and validation datasets loaded at the top of the notebook.
# NOTE(review): this calls a private (underscore-prefixed) CPSC2019Trainer method.
# Presumably it builds the dataloaders that `lazy=True` deferred; confirm.
trainer._setup_dataloaders(ds_train, ds_val)

In [None]:
# Run the full training loop. The result is assigned rather than displayed.
# NOTE(review): `bmd` is presumably the best model's state dict ("best model dict");
# confirm against CPSC2019Trainer.train.
bmd = trainer.train()

In [None]:
# Release this run's objects before the next model is built. `del` only unbinds the
# names: run the cyclic garbage collector so objects caught in reference cycles are
# actually freed. Then return the CUDA caching allocator's unused blocks to the driver.
# Any other live reference (e.g. a displayed cell output) still keeps its object alive.
del bmd, trainer, model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Train U-Net

In [None]:
# Per-run copies of the global configs. Setting `model_name` here therefore does not
# mutate `TrainCfg`, and the model config is not shared with other runs.
train_config = deepcopy(TrainCfg)
train_config.model_name = "unet"
model_config = deepcopy(ModelCfg[train_config.model_name])

# Prefer the GPU when one is visible. To pin one particular card instead, use
# something like torch.device("cuda:7").
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Look the model class up by name in the trainer module's `_MODEL_MAP` registry, then
# instantiate it with the same arguments the other runs use.
model = _MODEL_MAP[train_config.model_name](
    n_leads=train_config.n_leads, config=model_config
)

In [None]:
# Display the size reported by the model's `module_size_` attribute (defined in the
# project's model code; its units are not visible here). This is read before the
# DataParallel wrap in the next cell, which presumably would not expose it.
model.module_size_

In [None]:
# With more than one visible GPU, let DataParallel split each batch across them.
if torch.cuda.device_count() > 1:
    model = DP(model)

# Assign instead of leaving `model.to(...)` as the cell's last expression.
# `Module.to` returns the module itself, so displaying it would store the model in
# Jupyter's output history (`_`, `Out[...]`). That hidden reference survives the later
# `del ... model` and keeps the model (and its device memory) alive. It would also
# dump the full module repr into the notebook.
model = model.to(device=device)

In [None]:
# Hand the (possibly DataParallel-wrapped) model and both configs to the project trainer.
# With lazy=True the dataloaders are presumably not built yet; the next cell attaches
# the datasets.
trainer = CPSC2019Trainer(
    device=device,
    lazy=True,
    model=model,
    model_config=model_config,
    train_config=train_config,
)

In [None]:
# Give the trainer the training and validation datasets loaded at the top of the notebook.
# NOTE(review): this calls a private (underscore-prefixed) CPSC2019Trainer method.
# Presumably it builds the dataloaders that `lazy=True` deferred; confirm.
trainer._setup_dataloaders(ds_train, ds_val)

In [None]:
# Run the full training loop. The result is assigned rather than displayed.
# NOTE(review): `bmd` is presumably the best model's state dict ("best model dict");
# confirm against CPSC2019Trainer.train.
bmd = trainer.train()

In [None]:
# Release this run's objects now that the last model is trained. `del` only unbinds
# the names: run the cyclic garbage collector so objects caught in reference cycles
# are actually freed. Then return the CUDA caching allocator's unused blocks to the
# driver. Any other live reference (e.g. a displayed cell output) still keeps its
# object alive.
del bmd, trainer, model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Clear GPU memory

In [None]:
# Return memory still cached by PyTorch's CUDA allocator to the driver, e.g. so it
# shows up as free in nvidia-smi. Emptying the cache has nothing to do with autograd,
# so no `torch.no_grad()` context is needed. Collect reference cycles first so tensors
# that are merely unreachable are actually freed before the cache is emptied.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()