In [1]:
import torch
import numpy as np
# ^^^ pyforest auto-imports - don't write above this line
sys.path.insert(0, str(Path("../../").resolve()))

%load_ext autoreload
%autoreload 2

In [2]:
from data_reader import CINC2022Reader, CINC2016Reader, EPHNOGRAMReader
from dataset import CinC2022Dataset
from models import (
    CRNN_CINC2022,
    SEQ_LAB_NET_CINC2022,
    UNET_CINC2022,
    Wav2Vec2_CINC2022,
    HFWav2Vec2_CINC2022,
)
from cfg import TrainCfg, ModelCfg
from trainer import CINC2022Trainer, _MODEL_MAP, _set_task, collate_fn
from utils.plot import plot_spectrogram

from tqdm.auto import tqdm
import torchaudio
from copy import deepcopy

from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP
from torch.utils.data import DataLoader

CRNN_CINC2022.__DEBUG__ = False
Wav2Vec2_CINC2022.__DEBUG__ = False
HFWav2Vec2_CINC2022.__DEBUG__ = False
CinC2022Dataset.__DEBUG__ = False

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import os

# Data directory handed to `train_config.db_dir` further down. Set the
# CINC2022_DB_DIR environment variable to point at your own copy of the data
# without editing the notebook; the fallback is the original hard-coded
# location and should be replaced with your data directory otherwise.
db_dir = os.environ.get("CINC2022_DB_DIR", "/data1/Jupyter-Data/CinC2022/")

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep torch's default floating point dtype and the numpy dtype used in this
# notebook in line with the precision configured for the model.
if ModelCfg.torch_dtype == torch.float64:
    # `torch.set_default_tensor_type` is deprecated in recent PyTorch releases;
    # `torch.set_default_dtype` is the supported way to make float64 the default.
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
# Task to train; the alternative value is "classification".
task = "multi_task"

# Work on a deep copy so that the edits below do not modify TrainCfg itself.
train_config = deepcopy(TrainCfg)
train_config.db_dir = db_dir
train_config.debug = True

# Optional overrides, disabled by default:
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
# train_config.n_epochs = 100
# train_config.batch_size = 24  # 16G (Tesla T4)
# train_config.log_step = 20
# train_config.max_lr = 1.5e-3
# train_config.early_stopping.patience = 20

# Architecture choices for the selected task.
train_config[task].model_name = "crnn"  # alternative: "wav2vec2_hf"
train_config[task].cnn_name = "tresnetF"  # alternative: "resnet_nature_comm_bottle_neck_se"
# train_config[task].rnn_name = "none"  # "none", "lstm"
# train_config[task].attn_name = "se"  # "none", "se", "gc", "nl"

_set_task(task, train_config)

# Deep copy of the model configuration of the selected task.
model_config = deepcopy(ModelCfg[task])

# Carry the architecture choices made in the training configuration over to
# the model configuration; only the components that the chosen architecture
# actually defines are updated.
model_config.model_name = train_config[task].model_name
arch_config = model_config[model_config.model_name]
for component in ("cnn", "rnn", "attn"):
    if component in arch_config:
        getattr(arch_config, component).name = getattr(
            train_config[task], f"{component}_name"
        )

# Further optional tweaks for the wav2vec2 model:
# model_config.wav2vec2.cnn.name = "resnet_nature_comm_bottle_neck_se"
# model_config.wav2vec2.encoder.name = "wav2vec2_nano"

In [8]:
# Look up the model class registered under the configured model name and set
# its class-level `__DEBUG__` attribute to False before instantiating it.
model_cls = _MODEL_MAP[model_config.model_name]
model_cls.__DEBUG__ = False

In [9]:
# Instantiate the network from the task-specific model configuration.
model = model_cls(config=model_config)
# When more than one CUDA device is visible, wrap the model in DataParallel.
if torch.cuda.device_count() > 1:
    model = DP(model)
    # model = DDP(model)
# Move the (possibly wrapped) model to the selected device; the trailing `;`
# suppresses the model repr that the call would otherwise display.
model.to(device=DEVICE);

<IPython.core.display.Javascript object>

In [10]:
# The model is only wrapped in DataParallel when several GPUs are available,
# and `.module` exists on the wrapper only. Unwrap conditionally so that this
# cell also works on single-GPU and CPU-only machines.
_net = model.module if isinstance(model, (DP, DDP)) else model
_net.module_size, _net.module_size_

(5130149, '19.6M')

In [11]:
# Display the architecture of the (possibly DataParallel-wrapped) model.
model

DataParallel(
  (module): CRNN_CINC2022(
    (cnn): ResNet(
      (input_stem): ResNetStem(
        (s2d): SpaceToDepth(
          (out_conv): Conv_Bn_Activation(
            (conv1d): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (ResNetBasicBlock_0_0): ResNetBasicBlock(
        (main_stream): Sequential(
          (cba_0): Conv_Bn_Activation(
            (conv1d): SeparableConv(
              (depthwise_conv): Conv1d(32, 32, kernel_size=(35,), stride=(1,), padding=(17,), groups=32, bias=False)
              (pointwise_conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
            )
            (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activation_LeakyReLU): LeakyReLU(negative_slope=0.01, inplace=True)
          )
          (cba_1): Conv_Bn_Activation(
            (con

In [12]:
# Build the training and the non-training dataset for the selected task.
# Both are created with `lazy=True`; the data are loaded explicitly in the
# following cells via `_load_all_data()`.
ds_train = CinC2022Dataset(train_config, task, training=True, lazy=True)
ds_test = CinC2022Dataset(train_config, task, training=False, lazy=True)

Reading the statistics from local file...
Reading the statistics of the records from local file...
Reading the statistics from local file...
Reading the statistics of the records from local file...


In [13]:
# Load the records of the training dataset, which was created with `lazy=True`.
ds_train._load_all_data()

Loading data:   0%|          | 0/2521 [00:00<?, ?records/s]

In [14]:
# Load the records of the test dataset, which was created with `lazy=True`.
ds_test._load_all_data()

Loading data:   0%|          | 0/641 [00:00<?, ?records/s]

In [15]:
# Create the trainer for the model built above. It is created with `lazy=True`
# and the datasets are attached afterwards through `_setup_dataloaders`
# (presumably `lazy` defers the dataloader setup - confirm in the trainer).
trainer = CINC2022Trainer(
    model=model,
    model_config=model_config,
    train_config=train_config,
    device=DEVICE,
    lazy=True,
)

log file path: /home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_mtl_cinc2022/log/TorchECG_10-10_16-21_task-multi_task_CRNN_CINC2022_adamw_amsgrad_LR_0.0005_BS_24.txt
level of c_handler is set INFO, level of f_handler is set DEBUG
TorchECG - INFO - training configurations are as follows:
{
    "debug": True,
    "final_model_name": None,
    "log_step": 20,
    "flooding_level": 0.0,
    "early_stopping": {
        "min_delta": 0.001,
        "patience": 30
    },
    "log_dir": /home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_mtl_cinc2022/log,
    "checkpoints": /home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_mtl_cinc2022/checkpoints,
    "model_dir": /home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_mtl_cinc2022/saved_models,
    "prefix": "TorchECG",
    "DTYPE": DTYPE(STR='float32', NP=dtype('float32'), TORCH=torch.float32, INT=32),
    "str_dtype": "float32",
    "np_dtype": <class 'numpy.float32'>,
    "dtype": torch.

In [16]:
# Hand the already loaded training and test datasets to the trainer.
trainer._setup_dataloaders(ds_train, ds_test)

In [17]:
# Run the training loop and keep its return value (by its name, presumably the
# state dict of the best model - confirm in the trainer).
# NOTE: if the run is interrupted, `best_state_dict` is not assigned.
best_state_dict = trainer.train()

TorchECG - INFO - 
Starting training:
------------------
Epochs:          60
Batch size:      24
Learning rate:   0.0005
Training size:   2545
Validation size: 643
Device:          cuda
Optimizer:       adamw_amsgrad
Dataset classes: ['Present', 'Unknown', 'Absent']
-----------------------------------------



Epoch 0/60:   0%|          | 0/2545 [00:00<?, ?signals/s]

TorchECG - INFO - Train epoch_0:
--------------------------------------------------------------------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 0 / Step 20: train/loss : 0.6994
Epoch 0 / Step 20: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 0 / Step 40: train/loss : 0.6535
Epoch 0 / Step 40: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 0 / Step 60: train/loss : 0.4824
Epoch 0 / Step 60: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 0 / Step 80: train/loss : 0.5565
Epoch 0 / Step 80: train/lr :   0.0001
--------------------------------------------------
Torc

TorchECG - INFO - 
----------------------------------------------
outcome scalar prediction:    [0.552, 0.448]
outcome binary prediction:    [1, 0]
outcome labels:               [1, 0]
outcome predicted classes:    ['Abnormal']
outcome label classes:        ['Abnormal']
----------------------------------------------

Computing AUROC and AUPRC...
Computing F-measure...
Computing accuracy...
Computing weighted accuracy...
Computing challenge cost...
Computing AUROC and AUPRC...
Computing F-measure...
Computing accuracy...
Computing weighted accuracy...
Computing challenge cost...
TorchECG - INFO - Val Metrics:
--------------------------------------------------
Epoch 0 / Step 107: val/murmur_auroc :              0.6306
Epoch 0 / Step 107: val/murmur_auprc :              0.5568
Epoch 0 / Step 107: val/murmur_f_measure :          0.4555
Epoch 0 / Step 107: val/murmur_accuracy :           0.8367
Epoch 0 / Step 107: val/murmur_weighted_accuracy :  0.5061
Epoch 0 / Step 107: val/murmur_cost : 

Epoch 1/60:   0%|          | 0/2545 [00:00<?, ?signals/s]

TorchECG - INFO - Train epoch_1:
--------------------------------------------------------------------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 1 / Step 120: train/loss : 0.5223
Epoch 1 / Step 120: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 1 / Step 140: train/loss : 0.4560
Epoch 1 / Step 140: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 1 / Step 160: train/loss : 0.4841
Epoch 1 / Step 160: train/lr :   0.0001
--------------------------------------------------
TorchECG - INFO - Train Metrics:
--------------------------------------------------
Epoch 1 / Step 180: train/loss : 0.5246
Epoch 1 / Step 180: train/lr :   0.0001
-----------------------------------------------

KeyboardInterrupt: 

## Inspect trained models

In [None]:
from models import Wav2Vec2_CINC2022, CRNN_CINC2022

%load_ext autoreload
%autoreload 2

In [None]:
# Restore a saved checkpoint. The result is indexable; its first element is
# used as the model in the following cells.
ckpt = CRNN_CINC2022.from_checkpoint(
    "./saved_models/BestModel_task-multi_task_CRNN_CINC2022_epoch41_08-11_02-38_metric_-16272.44.pth.tar"
    # replace with a saved model
)

In [None]:
# Display the configuration stored on the restored model.
ckpt[0].config

In [None]:
# The first element of the checkpoint result is the restored model.
best_model = ckpt[0]

In [None]:
# Run the inspection on the CPU.
best_model = best_model.to("cpu")

In [None]:
# Small shuffled loader over the training dataset, used only to draw a sample
# batch for inspecting the restored model.
dl = DataLoader(
    dataset=ds_train,
    collate_fn=collate_fn,
    # batching
    batch_size=4,
    shuffle=True,
    drop_last=False,
    # loading
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Draw exactly one batch. `next(iter(dl))` fails loudly (StopIteration) when
# the loader is empty, instead of silently leaving `labels` and `waveforms`
# undefined as a `for ... break` loop would.
batch = next(iter(dl))
# `labels` refers to the same mapping as `batch`; removing the "waveforms"
# entry leaves only the remaining entries in it.
labels = batch
waveforms = labels.pop("waveforms")

In [None]:
# Run the restored model on the sample batch, passing the remaining batch
# entries as the second argument, and display the result.
# NOTE(review): this runs with autograd enabled and in whatever train/eval mode
# the checkpoint loader left the model - consider `best_model.eval()` and
# `torch.no_grad()` if the outputs are only being inspected.
best_model(waveforms, labels)