## Setup The Environment

In [4]:
%load_ext autoreload
%autoreload 2
import logging
import os
import sys
from pathlib import Path

import yaml
from tqdm.auto import tqdm

logging.basicConfig(level=logging.INFO)

try:
    from google.colab import drive
except ImportError:
    logging.info("Local machine detected")
    sys.path.append(os.path.realpath(".."))
else:
    logging.info("Colab detected")
    drive.mount("/content/drive")
    sys.path.append("/content/drive/MyDrive/ecg-reconstruction/src")

from ecg.trainer import Trainer, TrainerConfig
from ecg.reconstructor.transformer.transformer import UFormer
from ecg.reconstructor.lstm.lstm import LSTM
from ecg.reconstructor.linear.linear import Linear
from ecg.util.tree import deep_merge

INFO:root:Local machine detected


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Train a Model With a New Configuration

In [9]:
MODEL_TYPE = Linear
# dataset = "ptb-xl"
dataset = "code15%"

for i in tqdm(range(6, 12)):
    config: TrainerConfig = {
        "in_leads": [0, 1, i],
        "out_leads": [oidx for oidx in range(6, 12) if oidx != i],
        "max_epochs": 32,
        "accumulate_grad_batches": 8,
        "dataset": {
            "common": {
                "predicate": None,
                "signal_dtype": "float32",
                "filter_type": "butter",
                "filter_args": {"N": 3, "Wn": (0.5, 60), "btype": "bandpass"},
                # "filter_args": {"N": 3, "Wn": (0.05, 150), "btype": "bandpass"},
                # "mean_normalization": True,
                "mean_normalization": False,
                "feature_scaling": False,
                "include_original_signal": False,
                "include_filtered_signal": False, # This will be set to True in visulization
                "include_labels": {},
            },
            "train": {"hdf5_filename": f"{dataset}/train.hdf5"},
            "eval": {"hdf5_filename": f"{dataset}/validation.hdf5"},
        },
        "dataloader": {
            "common": {"num_workers": 6},
        },
        "reconstructor": {"type": MODEL_TYPE},
    }
    # with open(os.path.join(f"../best_configs/{MODEL_TYPE.__name__}/tuned_config.yaml"), 'r') as fp:
    #     best_config = yaml.safe_load(fp)

    config = deep_merge(config, MODEL_TYPE.default_config())
    config['reconstructor']["args"]['in_leads'] = config['in_leads']
    config['reconstructor']["args"]['out_leads'] = config['out_leads']
    config['dataloader']['common']["batch_size"] = 256
    config["accumulate_grad_batches"] = 1
    trainer = Trainer(config)
    trainer.fit()

  0%|          | 0/6 [00:00<?, ?it/s]

INFO:ecg.trainer:Epoch 1
Train 769/769 [0:00:42<0:00:00, 0.0545s/it, batch_loss=0.07923, average_loss=0.09844]
INFO:ecg.trainer:Loss=0.09844, RMSE=0.3137, Pearson-R=0.7595


KeyboardInterrupt: 

# Test the model
This is a simple test. For more complicated analysis, please refer to [`testing notebook`](./src/notebooks/demo_testing_and_visualize.ipynb)

In [8]:
checkpoint_dir = Path("../checkpoints/Linear")

# checkpoint_dir /= (checkpoint_dir / "latest").read_text().strip()
subfolders = [ Path(f.path) for f in os.scandir(checkpoint_dir) if f.is_dir() ]
for checkpoint in tqdm(subfolders):
    with open( checkpoint / "trainer_config.yaml", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)
    
    config['dataset']['eval']['hdf5_filename'] = 'code15%/test.hdf5'

    trainer = Trainer(config)
    trainer.load_checkpoint(checkpoint / (checkpoint / "best").read_text().strip())
    trainer.test()
    logging.info("Loss: %f", trainer.metrics.average_loss.get_average())
    logging.info("RMSE: %f", trainer.metrics.rmse.get_average())
    logging.info("PearsonR: %f", trainer.metrics.pearson_r.get_average())

  0%|          | 0/6 [00:00<?, ?it/s]

Test 165/165 [0:00:07<0:00:00, 0.0433s/it, batch_loss=0.1539, average_loss=0.09478] 
INFO:root:Loss: 0.094779
INFO:root:RMSE: 0.307862
INFO:root:PearsonR: 0.776552
Test 165/165 [0:00:07<0:00:00, 0.0427s/it, batch_loss=0.1051, average_loss=0.07003] 
INFO:root:Loss: 0.070032
INFO:root:RMSE: 0.264636
INFO:root:PearsonR: 0.823635
Test 165/165 [0:00:07<0:00:00, 0.04s/it, batch_loss=0.08826, average_loss=0.0628]   
INFO:root:Loss: 0.062797
INFO:root:RMSE: 0.250593
INFO:root:PearsonR: 0.835283
Test 165/165 [0:00:06<0:00:00, 0.036s/it, batch_loss=0.1068, average_loss=0.07961]  
INFO:root:Loss: 0.079608
INFO:root:RMSE: 0.282149
INFO:root:PearsonR: 0.758198
Test 165/165 [0:00:06<0:00:00, 0.0384s/it, batch_loss=0.1601, average_loss=0.1041]  
INFO:root:Loss: 0.104052
INFO:root:RMSE: 0.322570
INFO:root:PearsonR: 0.600485
Test 165/165 [0:00:06<0:00:00, 0.0364s/it, batch_loss=0.204, average_loss=0.1316]  
INFO:root:Loss: 0.131649
INFO:root:RMSE: 0.362834
INFO:root:PearsonR: 0.474759
