## Setting Up The Environment

In [1]:
%load_ext autoreload
%autoreload 2
import logging
import os
import sys
import yaml
from pathlib import Path
from typing import Any, Optional

import thop
import torch
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

logging.basicConfig(level=logging.INFO)

# `google.colab` is only importable inside a Colab runtime, so whether the import
# succeeds tells us where we are running and therefore which directory to put on
# sys.path so that the `ecg` package can be imported below.
try:
    from google.colab import drive
    _running_in_colab = True
except ImportError:
    _running_in_colab = False

if _running_in_colab:
    logging.info("Colab detected")
    # On Colab the project lives on the mounted Google Drive.
    drive.mount("/content/drive")
    _ecg_src_dir = "/content/drive/MyDrive/ecg-reconstruction/src"
else:
    logging.info("Local machine detected")
    # Locally, the parent of the current working directory is added
    # (assumed to contain the `ecg` package — the notebook's own folder layout).
    _ecg_src_dir = os.path.realpath("..")
sys.path.append(_ecg_src_dir)

from ecg.trainer import Trainer
from ecg.util.path import resolve_path
from ecg.util.visualize import visualize_model, visualize_model_all
from collections import OrderedDict

INFO:root:Local machine detected


## Save the current useful checkpoints to a sharing folder

In [24]:
import shutil
root_folder = "../checkpoints"
target_folder = "../sharing"


def copy_best_checkpoints(checkpoint_root, share_root) -> None:
    """Copy each run's best checkpoint into a shareable folder tree.

    For every run directory ``<checkpoint_root>/<model>/<run>/`` that contains a
    ``best`` file, copy three files into ``<share_root>/<model>/<run>/``:

    * the checkpoint file whose name is stored in ``best``;
    * the ``best`` pointer file itself;
    * ``trainer_config.yaml``.

    Runs that already have a folder under ``share_root`` are skipped, so
    re-running the cell is safe.

    Args:
        checkpoint_root: Directory holding one sub-folder per model.
        share_root: Destination directory; created if it does not exist.
    """
    source_root = Path(checkpoint_root)
    destination_root = Path(share_root)
    for model_dir in sorted(source_root.iterdir()):
        # Stray files next to the model folders are not runs.
        if not model_dir.is_dir():
            continue
        destination_model_dir = destination_root / model_dir.name
        destination_model_dir.mkdir(parents=True, exist_ok=True)
        for run_dir in sorted(p for p in model_dir.iterdir() if p.is_dir()):
            best_pointer = run_dir / "best"
            if not best_pointer.is_file():
                continue
            # `Path.name` is portable; splitting on "\\" only worked on Windows.
            destination_run_dir = destination_model_dir / run_dir.name
            if destination_run_dir.exists():
                continue
            sources = [
                run_dir / best_pointer.read_text().strip(),
                best_pointer,
                run_dir / "trainer_config.yaml",
            ]
            # Check before creating the destination, otherwise a failed copy would
            # leave an incomplete folder that every later run silently skips.
            missing = [str(source) for source in sources if not source.is_file()]
            if missing:
                logging.warning("Skipping %s: missing %s", run_dir, ", ".join(missing))
                continue
            destination_run_dir.mkdir()
            for source in sources:
                shutil.copy(source, destination_run_dir)


if Path(root_folder).is_dir():
    copy_best_checkpoints(root_folder, target_folder)
else:
    logging.warning("Checkpoint folder %s not found; nothing to share", root_folder)


## Testing The Model

The following two cells evaluate the checkpoints listed in `dir_list` on the test split. They collect the trained models, their metrics and datasets, which the visualization cells below use.

In [9]:
# Checkpoint run directories evaluated by the next cell. Each run directory must
# contain a `trainer_config.yaml` and a `best` file naming the checkpoint to load.
#
# Earlier evaluation sets, kept for provenance (paths relative to ../checkpoints);
# swap one into `_ACTIVE_RUNS` to re-run it:
#   "-code" runs (2023-07-16/17): StackedCNN/20230716-1120-code,
#       LSTM/20230716-1148-code, Unet/20230716-1240-code,
#       Fastformer/20230716-2316-code, UFormer/20230716-1921-code,
#       UFastformer/20230717-0020-code
#   "-ptb" runs (2023-07-16): StackedCNN/20230716-0255-ptb, LSTM/20230716-0258-ptb,
#       Unet/20230716-0302-ptb, Fastformer/20230716-0419-ptb,
#       UFormer/20230716-0347-ptb, UFastformer/20230716-0426-ptb
#   Input-lead comparisons (every run inside the folder):
#       LSTM/comparision_I_II_x-code15%/*, LSTM/comparison_I_II_x_ptb-xl/*
#   Ad-hoc: LSTM/20230716-0258-ptb, Unet/20230716-0302-ptb,
#       FastformerPlus/20230717-2038, FastformerPlus/20230717-2058,
#       CNNLSTM/20230718-0132
#   "-code" transformer runs: CNNLSTM/20230718-0338-code,
#       Fastformer/20230718-0828-code-01, FastFormerPlus/20230718-0210-code-01,
#       UFormer/20230718-0435-code-01
#   "-ptb" runs (later): Linear/ptb/20230717-2126, StackedCNN/20230719-0004-ptb,
#       LSTM/20230719-0008-ptb, CNNLSTM/20230719-1801-ptb, Unet/20230716-0302-ptb,
#       Fastformer/20230719-0424-ptb-01, FastformerPlus/20230719-0017-ptb-01,
#       FastformerZero/20230720-2054, FastformerStuff/20230721-2303-ptb,
#       UFastformer/20230719-0454-ptb-01
_CHECKPOINT_ROOT = "../checkpoints/"

# Active set: "-code" runs, one per architecture.
_ACTIVE_RUNS = (
    "Linear/code/20230718-0105",
    "StackedCNN/20230718-1651-code",
    "LSTM/20230718-1651-code",
    "CNNLSTM/20230719-1821-code",
    "Unet/20230716-1240-code",
    "Fastformer/20230718-0828-code-01",
    "FastformerPlus/20230718-2123-code-01",
    "UFastformer/20230719-0836-code-01",
)
dir_list = [_CHECKPOINT_ROOT + run for run in _ACTIVE_RUNS]


def _load_test_config(checkpoint_dir: Path) -> dict[str, Any]:
    """Load a run's trainer config and point its evaluation split at the test set.

    Args:
        checkpoint_dir: Run directory containing `trainer_config.yaml`.

    Returns:
        The config dict with these changes applied:

        * the eval HDF5 file is replaced by `test.hdf5` in the same directory;
        * the filtered signal is included in each sample;
        * batch settings are overridden.
    """
    with open(checkpoint_dir / "trainer_config.yaml", encoding="utf-8") as config_file:
        # The full Loader is needed because the config embeds Python objects
        # (e.g. the reconstructor class). It can execute arbitrary code, so only
        # load checkpoints from trusted sources.
        config = yaml.load(config_file, Loader=yaml.Loader)
    # Replace validation.hdf5 with test.hdf5.
    eval_hdf5 = Path(config["dataset"]["eval"]["hdf5_filename"])
    config["dataset"]["eval"]["hdf5_filename"] = eval_hdf5.parent / "test.hdf5"
    config["dataset"]["common"]["include_filtered_signal"] = True
    config["accumulate_grad_batches"] = 4
    config["dataloader"]["common"]["batch_size"] = 128

    # CODE-15% only: filter certain diseases
    # disease_label = ["1dAVb", "AF", "LBBB", "RBBB", "SB", "ST"][0]
    # config["dataset"]["common"]["predicate"] = f"lambda f: f['{disease_label}'][:]"
    return config


def _lead_wise_average(metric) -> list:
    """Return the per-output-lead value of a lead-wise metric after `trainer.test()`.

    NOTE(review): relies on the private `_metric_batches` / `_postprocess` members
    and only reads `_metric_batches[0]`. Confirm that this entry covers the whole
    test set rather than just the first batch.
    """
    return list(map(metric._postprocess, metric._metric_batches[0].mean(dim=0)))


def _out_lead_labels(dataset, n_leads: int) -> list[str]:
    """Return the names of the reconstructed (output) leads, used as column prefixes.

    Uses `dataset.lead_names[i] for i in dataset.out_leads`, the same lookup as the
    class-dependent-metrics cell. The output leads depend on which leads are
    inputs, so a fixed "V1..Vn" labelling can be wrong. If the dataset does not
    expose these attributes, falls back to the old "V1..Vn" labels and logs a
    warning.
    """
    try:
        return [str(dataset.lead_names[i]) for i in dataset.out_leads]
    except (AttributeError, IndexError, KeyError, TypeError):
        logging.warning("Could not read output lead names; falling back to V1..V%d", n_leads)
        return [f"V{i + 1}" for i in range(n_leads)]


def _profile_complexity(model, sample_input) -> tuple[float, float]:
    """Return (GMACs, millions of params) for one sample, as counted by thop.

    thop only counts module types it has rules for, so its parameter count can be
    lower than the true total logged as "Number of parameters".
    """
    device = next(iter(model.parameters())).device
    dummy_input = torch.from_numpy(sample_input[None, ...]).to(device)
    with torch.no_grad():
        macs, params = thop.profile(model, (dummy_input,))
    return macs / 1e9, params / 1e6


model_dict: dict[str, dict[str, Any]] = {}
for checkpoint_dir in map(Path, dir_list):
    config = _load_test_config(checkpoint_dir)

    trainer = Trainer(config)
    trainer.load_checkpoint(
        checkpoint_dir / (checkpoint_dir / "best").read_text().strip()
    )
    trainer.test()
    test_loss = trainer.metrics.average_loss.get_average()
    logging.info("Loss: %f", test_loss)

    # Lead-wise Metrics
    test_rmse = trainer.metrics.rmse.get_average()
    lead_wise_rmse = _lead_wise_average(trainer.metrics.rmse)
    logging.info("RMSE: %f", test_rmse)
    logging.info("Lead-wise RMSE: %s", lead_wise_rmse)

    test_pearson_r = trainer.metrics.pearson_r.get_average()
    lead_wise_pearson_r = _lead_wise_average(trainer.metrics.pearson_r)
    logging.info("PearsonR: %f", test_pearson_r)
    logging.info("Lead-wise PearsonR: %s", lead_wise_pearson_r)

    total_params = sum(param.numel() for param in trainer.reconstructor.parameters())
    logging.info("Number of parameters: %d", total_params)
    macs_g, params_m = _profile_complexity(
        trainer.reconstructor, trainer.eval_dataset[0]["input"]
    )
    logging.info("MACs (G): %f", macs_g)
    logging.info("Params (M): %f", params_m)

    lead_labels = _out_lead_labels(trainer.eval_dataset, len(config["out_leads"]))

    # Key results by architecture name. If two runs share an architecture, append
    # the run folder name so the earlier result is not silently overwritten.
    experiment = config["reconstructor"]["type"].__name__
    if experiment in model_dict:
        experiment = f"{experiment}/{checkpoint_dir.name}"
    model_dict[experiment] = {
        "result": OrderedDict({
            "test_loss": test_loss,
            "test_rmse": test_rmse,
            **{f"{lead}_rmse": value for lead, value in zip(lead_labels, lead_wise_rmse)},
            "test_pearson_r": test_pearson_r,
            **{
                f"{lead}_pearson_r": value
                for lead, value in zip(lead_labels, lead_wise_pearson_r)
            },
        }),
        "model": trainer.reconstructor,
        "complexity": (macs_g, params_m),
        "datasets": {
            "train": trainer.train_dataset,
            "test": trainer.eval_dataset,
        },
    }

Test 1318/1318 [0:00:09<0:00:00, 0.00665s/it, batch_loss=0.07207, average_loss=0.05697]
INFO:root:Loss: 0.056968
INFO:root:RMSE: 0.238680
INFO:root:Lead-wise RMSE: [0.19330306148273005, 0.2524898194866796, 0.25426975052845363, 0.25952683853267666, 0.2274109721346797]
INFO:root:PearsonR: 0.845633
INFO:root:Lead-wise PearsonR: [0.8222781419754028, 0.8228965997695923, 0.8640535473823547, 0.8586218357086182, 0.860312283039093]
INFO:root:Number of parameters: 20
INFO:root:MACs (G): 0.000044
INFO:root:Params (M): 0.000020


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.


Test 1318/1318 [0:00:08<0:00:00, 0.00618s/it, batch_loss=0.05879, average_loss=0.0351] 
INFO:root:Loss: 0.035102
INFO:root:RMSE: 0.187356
INFO:root:Lead-wise RMSE: [0.1512948509387418, 0.20748248498887314, 0.20281947523621421, 0.19046275511985739, 0.17933319239259285]
INFO:root:PearsonR: 0.906245
INFO:root:Lead-wise PearsonR: [0.8911484479904175, 0.875296950340271, 0.9156298637390137, 0.9291456341743469, 0.9200048446655273]
INFO:root:Number of parameters: 79685
INFO:root:MACs (G): 0.236269
INFO:root:Params (M): 0.079685


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.


Test 1318/1318 [0:00:19<0:00:00, 0.0145s/it, batch_loss=0.05781, average_loss=0.03358]
INFO:root:Loss: 0.033576
INFO:root:RMSE: 0.183238
INFO:root:Lead-wise RMSE: [0.14787205487322994, 0.20337113726346737, 0.19826700966563673, 0.185220592119541, 0.17617674945592232]
INFO:root:PearsonR: 0.911205
INFO:root:Lead-wise PearsonR: [0.8969014883041382, 0.881779134273529, 0.9203380346298218, 0.9331735968589783, 0.9238311052322388]
INFO:root:Number of parameters: 155693
INFO:root:MACs (G): 0.464112
INFO:root:Params (M): 0.155693


[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Test 1318/1318 [0:00:13<0:00:00, 0.00949s/it, batch_loss=0.05938, average_loss=0.03387]
INFO:root:Loss: 0.033867
INFO:root:RMSE: 0.184029
INFO:root:Lead-wise RMSE: [0.14841682515969437, 0.20404875222670318, 0.19989592332690506, 0.1858836881522114, 0.17651912211212223]
INFO:root:PearsonR: 0.910023
INFO:root:Lead-wise PearsonR: [0.8969941735267639, 0.8792061805725098, 0.9181492924690247, 0.9322951436042786, 0.9234697818756104]
INFO:root:Number of parameters: 727221
INFO:root:MACs (G): 2.138440
INFO:root:Params (M): 0.727221


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Test 1318/1318 [0:00:58<0:00:00, 0.0438s/it, batch_loss=0.05622, average_loss=0.03584]
INFO:root:Loss: 0.035845
INFO:root:RMSE: 0.189327
INFO:root:Lead-wise RMSE: [0.1526412870439801, 0.20862953864841002, 0.20560744865785402, 0.1931711398101988, 0.18113158279274647]
INFO:root:PearsonR: 0.904062
INFO:root:Lead-wise PearsonR: [0.8881211876869202, 0.8748304843902588, 0.9124782085418701, 0.9266301989555359, 0.918251097202301]
INFO:root:Number of parameters: 56277892
INFO:root:MACs (G): 24.254806
INFO:root:Params (M): 56.277892


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool1d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.


Test 1318/1318 [0:00:11<0:00:00, 0.00831s/it, batch_loss=0.0669, average_loss=0.03891] 
INFO:root:Loss: 0.038912
INFO:root:RMSE: 0.197260
INFO:root:Lead-wise RMSE: [0.1601260668527707, 0.21710258271476865, 0.2137121744874866, 0.20260048048168608, 0.18725568032784587]
INFO:root:PearsonR: 0.895019
INFO:root:Lead-wise PearsonR: [0.8751391172409058, 0.8650624752044678, 0.9060326814651489, 0.9177928566932678, 0.9110673069953918]
INFO:root:Number of parameters: 901557
INFO:root:MACs (G): 0.337011
INFO:root:Params (M): 0.115125


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.


Test 1318/1318 [0:00:19<0:00:00, 0.0142s/it, batch_loss=0.06048, average_loss=0.0341] 
INFO:root:Loss: 0.034096
INFO:root:RMSE: 0.184650
INFO:root:Lead-wise RMSE: [0.14915917665330478, 0.20550601461755588, 0.19969254606733564, 0.1864569980069159, 0.17707105668839557]
INFO:root:PearsonR: 0.910093
INFO:root:Lead-wise PearsonR: [0.8948764801025391, 0.8805539608001709, 0.9199216961860657, 0.9322670698165894, 0.9228460192680359]
INFO:root:Number of parameters: 1246072
INFO:root:MACs (G): 1.351063
INFO:root:Params (M): 0.459640


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.


Test 1318/1318 [0:00:30<0:00:00, 0.0225s/it, batch_loss=0.05902, average_loss=0.03706]
INFO:root:Loss: 0.037057
INFO:root:RMSE: 0.192501
INFO:root:Lead-wise RMSE: [0.15513940340341661, 0.21083901013363693, 0.20928369459946236, 0.1973010966069169, 0.18448539071987977]
INFO:root:PearsonR: 0.900701
INFO:root:Lead-wise PearsonR: [0.8850336670875549, 0.8712440133094788, 0.9100236892700195, 0.9230912923812866, 0.9141111969947815]
INFO:root:Number of parameters: 8311890
INFO:root:MACs (G): 6.284586
INFO:root:Params (M): 7.787602


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool1d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.


In [10]:
import pandas as pd
# One log section per evaluated model: its name, every metric (4 significant
# digits), then the thop complexity numbers.
for name, entry in model_dict.items():
    logging.info(name)
    metric_summary = ", ".join(
        f"{metric}={value:.4}" for metric, value in entry["result"].items()
    )
    logging.info(metric_summary)
    logging.info("MACs=%.05fG, Params=%.05fM", *entry["complexity"])


INFO:root:Linear
INFO:root:test_loss=0.05697, test_rmse=0.2387, V1_rmse=0.1933, V2_rmse=0.2525, V3_rmse=0.2543, V4_rmse=0.2595, V5_rmse=0.2274, test_pearson_r=0.8456, V1_pearson_r=0.8223, V2_pearson_r=0.8229, V3_pearson_r=0.8641, V4_pearson_r=0.8586, V5_pearson_r=0.8603
INFO:root:MACs=0.00004G, Params=0.00002M
INFO:root:StackedCNN
INFO:root:test_loss=0.0351, test_rmse=0.1874, V1_rmse=0.1513, V2_rmse=0.2075, V3_rmse=0.2028, V4_rmse=0.1905, V5_rmse=0.1793, test_pearson_r=0.9062, V1_pearson_r=0.8911, V2_pearson_r=0.8753, V3_pearson_r=0.9156, V4_pearson_r=0.9291, V5_pearson_r=0.92
INFO:root:MACs=0.23627G, Params=0.07969M
INFO:root:LSTM
INFO:root:test_loss=0.03358, test_rmse=0.1832, V1_rmse=0.1479, V2_rmse=0.2034, V3_rmse=0.1983, V4_rmse=0.1852, V5_rmse=0.1762, test_pearson_r=0.9112, V1_pearson_r=0.8969, V2_pearson_r=0.8818, V3_pearson_r=0.9203, V4_pearson_r=0.9332, V5_pearson_r=0.9238
INFO:root:MACs=0.46411G, Params=0.15569M
INFO:root:CNNLSTM
INFO:root:test_loss=0.03387, test_rmse=0.184, V

In [12]:
import pandas as pd
# One row per model, one column per metric; written next to the notebook.
results_by_model = {name: entry["result"] for name, entry in model_dict.items()}
result_table = pd.DataFrame.from_dict(results_by_model, orient="index")
result_table.to_csv("result_table.csv")

In [4]:
# Inspect one test sample (a dict of arrays such as 'input', 'target', 'filtered_signal').
lstm_test_dataset = model_dict["LSTM"]["datasets"]["test"]
lstm_test_dataset[0]

{'input': array([[-0.07424219, -0.07128301, -0.06712826, ..., -0.00672395,
         -0.00511282, -0.00659104],
        [-0.02162265, -0.00969927,  0.00129256, ...,  0.0077694 ,
          0.00274633, -0.003356  ],
        [ 0.02833819,  0.02760309,  0.02568165, ...,  0.03326277,
          0.0288905 ,  0.0243857 ]], dtype=float32),
 'target': array([[ 0.05996339,  0.06573509,  0.06969021, ...,  0.02132249,
          0.01781764,  0.01550785],
        [ 0.02214532,  0.02707442,  0.03044287, ...,  0.02512129,
          0.02249582,  0.02102113],
        [-0.03553151, -0.02592649, -0.01871811, ...,  0.00870049,
          0.00791088,  0.00853751],
        [-0.00707445, -0.00249877,  0.0001077 , ...,  0.01635305,
          0.01142398,  0.00660076],
        [-0.00681645, -0.00157745,  0.00206726, ...,  0.01519492,
          0.01107781,  0.00631183]], dtype=float32),
 'filtered_signal': array([[-0.07424219, -0.07128301, -0.06712826, ..., -0.00672395,
         -0.00511282, -0.00659104],
        [-

: 

## Class-Dependent Metrics

Evaluate each model separately on the test samples of every class label.


In [None]:
# from utils.general import compare_parameters
# from ecg.metric import AverageLoss, _LeadWiseMetricBase

from tqdm.auto import tqdm
# Per-class test metrics: every test sample is routed to an aggregator keyed by
# its `label`, and each aggregator's results are printed per experiment.
# NOTE(review): SimpleMetricAggregate is not imported or defined anywhere in this
# notebook (the candidate imports above are commented out); import it before
# running this cell.
experiments = list(model_dict.keys())
metric_results = {}
for experiment in experiments:
    print(experiment)
    dataset = model_dict[experiment]["datasets"]["test"]
    model = model_dict[experiment]["model"]
    # Use whatever device the model is on instead of hard-coding CUDA, so the cell
    # also runs on the CPU-only "local machine" path of the setup cell.
    device = next(iter(model.parameters())).device
    # Inference only: eval mode for dropout/batch-norm layers.
    model.eval()
    class_dependent_metrics = {}
    metric_config = {
        "lead_names": tuple(dataset.lead_names[i] for i in dataset.out_leads)
    }
    # no_grad: otherwise each per-sample forward pass builds an autograd graph.
    with torch.no_grad():
        for items in tqdm(dataset):
            # Assumes `label` is hashable (e.g. an int) — TODO confirm.
            class_id = items["label"]
            target = torch.as_tensor(items["target"][None], dtype=torch.float32, device=device)
            if class_id not in class_dependent_metrics:
                class_dependent_metrics[class_id] = SimpleMetricAggregate(**metric_config)
            output = model(
                torch.as_tensor(items["input"][None], dtype=torch.float32, device=device)
            )
            loss = model.loss_fn(output, target).item()
            class_dependent_metrics[class_id](loss, output, target)
    for class_id, aggregate in class_dependent_metrics.items():
        print(f"class {class_id}", end="==>")
        cur_test_results = aggregate.result
        for name in cur_test_results:
            if name != "loss":
                print(name, cur_test_results[name][0], end=" ")
        print()
    metric_results[experiment] = class_dependent_metrics


## Visualization

Visualize a single model.

In [11]:
@widgets.interact(experiment=list(model_dict.keys()))
def run_visualization_one(experiment):
    """Interactive viewer for a single model's reconstruction of one test sample.

    Choosing an experiment builds an inner widget with a sample-index slider and a
    channel selector; each change redraws the plot via `visualize_model`.
    """
    selected = model_dict[experiment]
    test_set = selected["datasets"]["test"]
    reconstructor = selected["model"]
    last_index = len(test_set) - 1

    # `interact` maps keyword names to function parameters, so the names
    # `dataset_index` and `channel` must match the inner function's signature.
    @widgets.interact(dataset_index=(0, last_index), channel=np.arange(12))
    def visualize_impl(dataset_index: int, channel: int) -> None:
        plt.figure(figsize=(8, 2), dpi=100)
        visualize_model(plt.gca(), reconstructor, test_set, dataset_index, channel)
        plt.show()
        # NOTE(review): original note "0.5 - 60 -> 3 - 40" — presumably a filter
        # band change (Hz?); confirm its meaning.

interactive(children=(Dropdown(description='experiment', options=('CNNLSTM',), value='CNNLSTM'), Output()), _d…

Visualize all models in a single interactive widget.

In [9]:
# Every model is drawn on the first model's test split (assumed to be the same
# data for all models — TODO confirm when mixing datasets).
visualized_dataset = next(iter(model_dict.values()))["datasets"]["test"]


@widgets.interact(dataset_index=(0, len(visualized_dataset) - 1), channel=np.arange(12))
def visualize_model_v1(dataset_index: int, channel: int) -> None:
    """Plot all models' reconstructions of one test sample/channel in a single figure.

    The figure height grows with the number of models (2.5 inches each).
    """
    experiment_names = list(model_dict.keys())
    plt.figure(figsize=(8, 2.5 * len(model_dict)), dpi=100)
    visualize_model_all(
        experiment_names,
        plt.gca(),
        model_dict,
        visualized_dataset,
        dataset_index,
        channel,
    )
    plt.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.8, wspace=0.4, hspace=0.4)
    plt.show()

# PTB-XL
# Unet: 1850
# FastformerPlus: 2

# Code
# FastformerPlus: 30965-6, 9527-9 (even reconstructs noise), 13587-9, 19283-10, 

interactive(children=(IntSlider(value=21085, description='dataset_index', max=42171), Dropdown(description='ch…