## Setup The Environment

In [1]:
%load_ext autoreload
%autoreload 2
import logging
import os
import sys
import yaml
import torch
import thop
from fvcore.nn import FlopCountAnalysis
from io import StringIO

logging.basicConfig(level=logging.INFO)


def ensure_on_sys_path(path: str) -> None:
    """Append ``path`` to ``sys.path`` unless it is already present.

    Notebook setup cells are routinely re-run; checking first keeps this cell
    idempotent instead of appending a duplicate entry on every execution.
    """
    if path not in sys.path:
        sys.path.append(path)


# Make the project's ``ecg`` package importable. On Colab the repository is read
# from Google Drive; locally it is expected in the parent of the current working
# directory (normally the folder containing this notebook).
try:
    from google.colab import drive
except ImportError:
    logging.info("Local machine detected")
    ensure_on_sys_path(os.path.realpath(".."))
else:
    logging.info("Colab detected")
    drive.mount("/content/drive")
    ensure_on_sys_path("/content/drive/MyDrive/ecg-reconstruction/src")

from ecg.trainer import Trainer, TrainerConfig
from ecg.reconstructor.transformer.transformer import UFormer, NaiveTransformerEncoder
from ecg.reconstructor.lstm.lstm import LSTM, CNNLSTM
from ecg.reconstructor.cnn.cnn import StackedCNN
from ecg.reconstructor.transformer.fastformer import Fastformer, UFastformer, FastformerPlus, FastformerStuff
from ecg.reconstructor.unet.unet import UNet
from ecg.util.path import resolve_path
from ecg.util.tree import deep_merge

# Dataset used for training and validation (the commented-out alternative in
# earlier revisions of this cell was "ptb-xl").
dataset_name = "code15%"

# Reconstructor classes processed by the training cell below. Other
# architectures imported above can be added here: StackedCNN, LSTM, UNet,
# NaiveTransformerEncoder, FastformerPlus, CNNLSTM, UFormer, UFastformer.
training_list = [Fastformer, FastformerStuff]

# HDF5 files (relative paths) for each dataset split.
dataset_splits = {
    "train": {"hdf5_filename": f"{dataset_name}/train.hdf5"},
    "eval": {"hdf5_filename": f"{dataset_name}/validation.hdf5"},
}

# Settings shared by every run; the training cell combines this with each
# model's tuned config via deep_merge.
base_config: TrainerConfig = {
    "in_leads": [0, 1, 8],  # input lead indices
    "out_leads": [6, 7, 9, 10, 11],  # lead indices to reconstruct
    "max_epochs": 32,
    "dataset": dataset_splits,
}

INFO:root:Local machine detected


## Train a Model With a New Configuration

In [2]:
%%capture captured_output
# For every architecture in training_list: build its config from the tuned
# hyper-parameters plus base_config, construct a Trainer, and log model size and
# compute (MACs). Flip RUN_TRAINING to True to actually train each model.
device = "cpu"  # device the reconstructor is moved to before profiling
num_workers = 6  # DataLoader worker processes; overrides the tuned config
RUN_TRAINING = False

for MODEL_TYPE in training_list:
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. It is
    # needed for tags such as !!python/tuple in these files, so only load
    # trusted, repository-controlled configs with it.
    with open(
        resolve_path("src/best_configs") / MODEL_TYPE.__name__ / "tuned_config.yaml",
        "r", encoding="utf-8",
    ) as fp:
        best_config = yaml.load(fp, Loader=yaml.Loader)

    config = deep_merge(best_config, base_config)
    # setdefault: tuned configs without a dataloader section no longer raise KeyError.
    config.setdefault("dataloader", {}).setdefault("common", {})["num_workers"] = num_workers
    config["reconstructor"]["type"] = MODEL_TYPE

    config_stream = StringIO()
    yaml.dump(config, config_stream, yaml.Dumper, indent=4)
    logging.info("Config:\n%s", config_stream.getvalue())

    trainer = Trainer(config)
    trainer.reconstructor = trainer.reconstructor.to(device)
    total_params = sum(param.numel() for param in trainer.reconstructor.parameters())

    # Profile on one evaluation sample; [None, ...] prepends an axis of size 1
    # (presumably the batch axis). Kept separate from `device` so the setting
    # above is not overwritten by the resolved torch.device.
    model_device = next(iter(trainer.reconstructor.parameters())).device
    dummy_input = torch.from_numpy(trainer.eval_dataset[0]["input"][None, ...]).to(model_device)
    with torch.no_grad():
        macs, params = thop.profile(trainer.reconstructor, (dummy_input,))
    # fvcore alternative: FlopCountAnalysis(trainer.reconstructor, dummy_input).total()
    macs_g = macs / 1e9
    params_m = params / 1e6
    logging.info("MACs (G): %f", macs_g)
    logging.info("Params (M): %f", params_m)
    logging.info("Number of parameters: %d", total_params)

    if RUN_TRAINING:
        trainer.fit()


INFO:root:Config:
accumulate_grad_batches: 8
dataloader:
    common:
        batch_size: 128
        num_workers: 6
dataset:
    common:
        feature_scaling: false
        filter_args:
            N: 3
            Wn: !!python/tuple
            - 0.5
            - 60
            btype: bandpass
        filter_type: butter
        include_filtered_signal: false
        include_labels: {}
        include_original_signal: false
        mean_normalization: true
        predicate: null
        signal_dtype: float32
    eval:
        hdf5_filename: code15%/validation.hdf5
    train:
        hdf5_filename: code15%/train.hdf5
in_leads:
- 0
- 1
- 8
lr_scheduler:
    args:
        factor: 0.5986269041609817
        patience: 2
    type: ReduceLROnPlateau
max_epochs: 32
optimizer:
    args:
        betas:
        - 0.9091494051250636
        - 0.964802538848444
        lr: 0.0023091577420568253
        weight_decay: 0.0002987140973394854
    type: AdamW
out_leads:
- 6
- 7
- 9
- 10
- 11
recons

In [None]:
import tensorboard
%load_ext tensorboard

## Resume Training

In [None]:
%%capture captured_output

# Resume an interrupted run. The trainer config is read back from the run
# directory itself, so MODEL_TYPE only selects the checkpoint folder and must
# name the architecture that produced the run.
MODEL_TYPE = FastformerPlus
RUN_NAME = "20230718-2035-code-02"  # run sub-directory to resume
resume_dir = f"../checkpoints/{MODEL_TYPE.__name__}/{RUN_NAME}"

# NOTE(review): yaml.Loader can construct arbitrary Python objects; it is needed
# for tags such as !!python/tuple, so only load self-written, trusted configs.
with open(f"{resume_dir}/trainer_config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.load(fp, Loader=yaml.Loader)

trainer = Trainer(config)
total_params = sum(param.numel() for param in trainer.reconstructor.parameters())
logging.info("Number of parameters: %d", total_params)

# Profile MACs/params on a single evaluation sample before resuming;
# [None, ...] prepends an axis of size 1.
device = next(iter(trainer.reconstructor.parameters())).device
dummy_input = torch.from_numpy(trainer.eval_dataset[0]["input"][None, ...]).to(device)
with torch.no_grad():
    macs, params = thop.profile(trainer.reconstructor, (dummy_input,))
macs_g = macs / 1e9
params_m = params / 1e6
logging.info("MACs (G): %f", macs_g)
logging.info("Params (M): %f", params_m)

trainer.resume(checkpoint_dir=resume_dir)


: 

In [None]:
import re

# Checkpoint filenames embed a number as "=<number>.pth" (e.g. "...=1.23.pth");
# presumably a metric or epoch value — confirm against the Trainer's naming.
# This cell sanity-checks the regex used to extract it.
CHECKPOINT_NUMBER_PATTERN = re.compile(r"=([0-9.]+)\.pth")


def parse_checkpoint_metric(filename: str) -> str:
    """Return the numeric text between the first "=" and ".pth" in ``filename``.

    Args:
        filename: checkpoint file name (or any string containing one).

    Returns:
        The captured digits/dots as a string, e.g. "1.23".

    Raises:
        ValueError: if ``filename`` has no "=<number>.pth" segment.
    """
    match = CHECKPOINT_NUMBER_PATTERN.search(filename)
    if match is None:
        raise ValueError(f"no '=<number>.pth' segment in {filename!r}")
    return match.group(1)


parse_checkpoint_metric("xjjx=1.23.pthxxxdsaf")

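## Test the Model
This is a quick check: it evaluates the best checkpoint of the latest run on the test split and logs the loss, RMSE, and Pearson correlation. For more detailed analysis, see the [testing notebook](../notebooks/demo_testing_and_visualize.ipynb).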

In [None]:
# Evaluate a saved checkpoint of MODEL_TYPE on the test split.
MODEL_TYPE = UNet

# The model's checkpoint folder has a "latest" text file naming a run
# sub-directory (presumably the most recent run).
model_root = resolve_path("src/checkpoints") / MODEL_TYPE.__name__
latest_run = (model_root / "latest").read_text().strip()
checkpoint_dir = model_root / latest_run

# NOTE(review): yaml.Loader can construct arbitrary Python objects; only use it
# on trusted, self-written trainer configs.
with open(checkpoint_dir / "trainer_config.yaml", encoding="utf-8") as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

# Redirect the evaluation dataset from the validation split to the test split.
config["dataset"]["eval"]["hdf5_filename"] = f"{dataset_name}/test.hdf5"

trainer = Trainer(config)
# The run directory's "best" text file names the checkpoint to load.
best_checkpoint = checkpoint_dir / (checkpoint_dir / "best").read_text().strip()
trainer.load_checkpoint(best_checkpoint)
trainer.test()

for message, metric_name in (
    ("Loss: %f", "average_loss"),
    ("RMSE: %f", "rmse"),
    ("PearsonR: %f", "pearson_r"),
):
    logging.info(message, getattr(trainer.metrics, metric_name).get_average())