In [None]:
import torch
from easydict import EasyDict as ED

# ^^^ pyforest auto-imports - don't write above this line
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

from cfg import ModelCfg, ModelCfg_ns, TrainCfg, TrainCfg_ns
from dataset import CINC2020
from model import ECG_CRNN_CINC2020
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm
from trainer import CINC2020Trainer

In [None]:
# Turn off the `__DEBUG__` class-level flag on the model, trainer and dataset classes
# (presumably this suppresses their debug printing — confirm in the project modules).
for _debug_cls in (ECG_CRNN_CINC2020, CINC2020Trainer, CINC2020):
    _debug_cls.__DEBUG__ = False
del _debug_cls

In [None]:
import os

# Root directory of the CinC2021 data, shared by both config objects.
# The default is a machine-specific path; set the CINC2021_DB_DIR environment
# variable to point elsewhere instead of editing the notebook.
DB_DIR = os.environ.get("CINC2021_DB_DIR", "/home/wenhao/Jupyter/wenhao/data/CinC2021/")
TrainCfg_ns.db_dir = DB_DIR
TrainCfg.db_dir = DB_DIR

In [None]:
# Build the training split first, then the validation split, from the same config.
# `lazy=False` presumably loads records eagerly — TODO confirm in dataset.py.
ds_train, ds_val = (
    CINC2020(TrainCfg_ns, training=is_training, lazy=False)
    for is_training in (True, False)
)

## Experiment: `resnet_nature_comm_bottle_neck_se` backbone, single linear classifier layer, AsymmetricLoss, one-cycle LR schedule (lr 1e-4 → 2e-3)

In [None]:
# Training configuration: copy the namespace defaults, then pick the backbone
# (CNN only — the RNN and attention stages are set to "none").
train_config = deepcopy(TrainCfg_ns)
train_config.cnn_name, train_config.rnn_name, train_config.attn_name = (
    "resnet_nature_comm_bottle_neck_se",
    "none",
    "none",
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

train_config.n_leads = len(train_config.leads)

# Class list: the per-tranche list when tranches are selected for training,
# otherwise the full class list from the config.
tranches = train_config.tranches_for_training
classes = train_config.tranche_classes[tranches] if tranches else train_config.classes

# Model configuration: mirror the backbone choices made in `train_config`.
model_config = deepcopy(ModelCfg_ns)
for _part in ("cnn", "rnn", "attn"):
    getattr(model_config, _part).name = getattr(train_config, f"{_part}_name")
del _part

# Classifier head. `out_channels` lists only hidden layer widths; the final linear
# layer (out channels == n_classes) is implicit, so an empty list means one linear layer.
model_config.clf = ED(
    out_channels=[],
    bias=True,
    dropouts=0.0,
    activation="mish",  # for a single layer `SeqLin`, activation is ignored
)

In [None]:
# Build the model from the tranche-aware `classes` list computed in the config cell,
# not `train_config.classes`. Otherwise, when `tranches_for_training` is set, the
# model's class list (and output size) would disagree with `classes`.
model = ECG_CRNN_CINC2020(
    classes=classes,
    n_leads=train_config.n_leads,
    config=model_config,
)

In [None]:
# `module_size_` is a project-defined attribute of the model — presumably its
# parameter count (TODO confirm in model.py). Displayed as the cell output.
model_size = model.module_size_
model_size

In [None]:
# Move the model to the target device before (optionally) wrapping it: DataParallel
# requires the wrapped module's parameters/buffers to be on device_ids[0] when it runs.
# `Module.to` returns the module itself; assigning it (rather than ending the cell
# with `model.to(...)`) avoids rendering the full module repr as cell output.
model = model.to(device=device)
if torch.cuda.device_count() > 1:
    # Single-process multi-GPU data parallelism. DistributedDataParallel (imported as
    # DDP) would additionally need an initialized process group, which this notebook
    # does not set up.
    model = DP(model)

In [None]:
# Trainer arguments collected in one place. `lazy=True` presumably defers dataloader
# creation; the dataloaders are attached explicitly via `_setup_dataloaders` later
# in the notebook (TODO confirm in trainer.py).
trainer_kwargs = dict(
    model=model,
    model_config=model_config,
    train_config=train_config,
    device=device,
    lazy=True,
)
trainer = CINC2020Trainer(**trainer_kwargs)

In [None]:
# Number of classes the model is trained on (tranche-aware list from the config cell).
n_classes = len(classes)
n_classes

In [None]:
# Attach the training and validation datasets to the trainer. This is needed because the
# trainer was constructed with `lazy=True` (presumably deferring dataloader setup — confirm).
# NOTE(review): this calls a private method of `CINC2020Trainer`.
trainer._setup_dataloaders(ds_train, ds_val)

In [None]:
# Run training. The return value is captured instead of being left as the cell's last
# expression, so a potentially large object isn't rendered inline. It is possibly the best
# model state dict — TODO confirm against `CINC2020Trainer.train`.
train_result = trainer.train()