In [None]:
import pandas as pd
import torch

# ^^^ pyforest auto-imports - don't write above this line
import os

# Select GPUs by PCI bus order and expose only devices 2-7 to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2,3,4,5,6,7",
    }
)

In [None]:
# Show how many CUDA devices torch reports; expected to reflect the
# CUDA_VISIBLE_DEVICES selection made in the previous cell.
torch.cuda.device_count()

In [None]:
import sys

# ^^^ pyforest auto-imports - don't write above this line
# Prepend two workspace directories to the module search path so they are
# searched first (presumably local checkouts of bib_lookup and torch_ecg).
# NOTE(review): absolute, machine-specific paths - adjust when running elsewhere.
sys.path[:0] = [
    "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/",
    "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/",
]

import os
import pickle
import time
from copy import deepcopy
from pathlib import Path
from typing import Dict, Union, Tuple, Sequence

import numpy as np
import torch
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.impute import SimpleImputer
from torch.nn.parallel import (  # noqa: F401
    DistributedDataParallel as DDP,
    DataParallel as DP,
)  # noqa: F401
from torch_ecg.cfg import CFG
from torch_ecg.utils.misc import str2bool
from torch_ecg._preprocessors import PreprocManager
from tqdm.auto import tqdm

from cfg import TrainCfg, ModelCfg, MLCfg
from dataset import CinC2023Dataset
from models import CRNN_CINC2023, ML_Classifier_CINC2023
from trainer import CINC2023Trainer, _set_task
from helper_code import find_data_folders
from utils.features import get_features, get_labels
from utils.misc import (
    load_challenge_metadata,
    load_challenge_eeg_data,
    find_eeg_recording_files,
)
from utils.sqi import compute_sqi  # noqa: F401

%load_ext autoreload
%autoreload 2

## DL model

In [None]:
# NOTE(review): not read by any visible cell of this notebook - confirm where it is consumed.
TEST_FLAG = False

# Key used to index the task-specific sections of TrainCfg / ModelCfg.
TASK = "classification"  # one of: "classification", "regression"

# Model architecture selected for this task.
TrainCfg[TASK].model_name = "crnn"

# CNN backbone name; alternatives noted by the author: "tresnetS", "resnet_nature_comm", "tresnetF", etc.
TrainCfg[TASK].cnn_name = "resnet_nature_comm_bottle_neck_se"

# Optional overrides, currently disabled (the values already in TrainCfg are used):
# TrainCfg[TASK].rnn_name = "none"  # "none", "lstm"
# TrainCfg[TASK].attn_name = "se"  # "none", "se", "gc", "nl"

# NOTE(review): not read by any visible cell of this notebook - confirm usage.
ENHANCED_ML_MODEL = True

# File names, presumably for the saved DL model and the two pickled ML models.
_ModelFilename = "final_model_main.pth.tar"
_ModelFilename_ml = "final_model_ml.pkl"
_ModelFilename_ml_min_guarantee = "final_model_ml_min_guarantee.pkl"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# When the model is configured for float64, make that torch's default
# floating-point dtype and use the matching NumPy dtype; otherwise use float32
# on the NumPy side and leave torch's default untouched.
# `torch.set_default_tensor_type` is deprecated (PyTorch >= 2.1);
# `torch.set_default_dtype` is the supported way to change the default float dtype.
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

# Switch off the class-level debug flag on the dataset, model and trainer classes.
for _component in (CinC2023Dataset, CRNN_CINC2023, CINC2023Trainer):
    _component.__DEBUG__ = False

# Split every "A-B" bipolar pair into two parallel lists:
# index 0 holds the first-named channels, index 1 the second-named ones.
EEG_BIPOLAR_CHANNELS = [
    [pair.split("-")[side] for pair in TrainCfg.eeg_bipolar_channels]
    for side in (0, 1)
]

In [None]:
# NOTE(review): this cell repeats the device/dtype setup already done at the end
# of the configuration cell above; one of the two can probably be dropped.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# When the model is configured for float64, make that torch's default
# floating-point dtype and use the matching NumPy dtype; otherwise use float32
# on the NumPy side and leave torch's default untouched.
# `torch.set_default_tensor_type` is deprecated (PyTorch >= 2.1);
# `torch.set_default_dtype` is the supported way to change the default float dtype.
if ModelCfg.torch_dtype == torch.float64:
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

In [None]:
TASK = "classification"

# Build this run's training configuration on a deep copy of TrainCfg, so the
# overrides below are applied to the copy rather than to TrainCfg itself.
train_config = deepcopy(TrainCfg)
# Settings that are left disabled in this notebook:
# train_config.db_dir = data_folder
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
train_config.debug = True

# Dataset location (absolute, machine-specific path).
train_config.db_dir = "/data1/Jupyter-Data/CinC2023-new/"

# Training schedule.
train_config.n_epochs = 65
train_config.batch_size = 36  # 16G (Tesla T4)
train_config.reload_data_every = 5
# train_config.log_step = 20
# train_config.max_lr = 1.5e-3
# Early-stopping patience: 55% of the epoch budget, truncated to an integer.
train_config.early_stopping.patience = int(0.55 * train_config.n_epochs)

# Augmentation overrides (all disabled, the copied defaults apply):
# train_config.classification.label_smooth = False
# train_config.classification.random_masking = False
# train_config.classification.stretch_compress = False  # stretch or compress in time axis
# train_config.classification.mixup = CFG(
#     prob=0.6,
#     alpha=0.3,
# )

# train_config[TASK].pop("normalize")

# Architecture choices for the selected task.
train_config[TASK].model_name = "crnn"
train_config[TASK].cnn_name = "resnet_nature_comm_bottle_neck_se"
# train_config[TASK].rnn_name = "none"  # "none", "lstm"
# train_config[TASK].attn_name = "se"  # "none", "se", "gc", "nl"

# Hand the task and the config to the trainer module's `_set_task` helper.
_set_task(TASK, train_config)

# Start from a deep copy of the task's model configuration.
model_config = deepcopy(ModelCfg[TASK])

# Mirror the architecture choices made in the training configuration.
model_config.model_name = train_config[TASK].model_name
selected_arch_config = model_config[model_config.model_name]
for component in ("cnn", "rnn", "attn"):
    # Only components that exist in the chosen architecture's config get a
    # name; the name is read from `<component>_name` in the training config.
    if component in selected_arch_config:
        getattr(selected_arch_config, component).name = getattr(
            train_config[TASK], f"{component}_name"
        )

In [None]:
# Instantiate the network; wrap it in DataParallel when more than one CUDA
# device is visible (DDP(model) would be the alternative wrapper), then move
# it to DEVICE.
model = CRNN_CINC2023(config=model_config)
model = (DP(model) if torch.cuda.device_count() > 1 else model).to(device=DEVICE)

In [None]:
# Print the model's two size attributes. DataParallel and
# DistributedDataParallel keep the wrapped network under `.module` and do not
# forward custom attributes, so unwrap either wrapper before reading them.
_unwrapped_model = model.module if isinstance(model, (DP, DDP)) else model
print(_unwrapped_model.module_size, _unwrapped_model.module_size_)

In [None]:
# Display the model's repr as the cell output.
model

In [None]:
# Create the training and test datasets from the same config and task.
# Both are built with lazy=True; the data is loaded explicitly in the next cells.
ds_train = CinC2023Dataset(train_config, TASK, training=True, lazy=True)
ds_test = CinC2023Dataset(train_config, TASK, training=False, lazy=True)

In [None]:
# The test dataset was created with lazy=True; explicitly trigger its data load.
# NOTE(review): calls a private method - confirm there is no public equivalent.
ds_test._load_all_data()

In [None]:
# The training dataset was created with lazy=True; explicitly trigger its data load.
# NOTE(review): calls a private method - confirm there is no public equivalent.
ds_train._load_all_data()

In [None]:
# ds_test.cache["waveforms"].shape

In [None]:
# ds_train.cache["cpc"]

In [None]:
# Length of the reader behind the test dataset
# (what one element represents is not visible here - TODO confirm).
len(ds_test.reader)

In [None]:
# Build the trainer around the model and both configurations.
# lazy=True presumably defers dataloader setup, which the next cell performs
# explicitly - TODO confirm against CINC2023Trainer.
trainer = CINC2023Trainer(
    model=model,
    model_config=model_config,
    train_config=train_config,
    device=DEVICE,
    lazy=True,
)

In [None]:
# Hand the already-constructed train/test datasets to the trainer.
# NOTE(review): calls a private method of the trainer.
trainer._setup_dataloaders(ds_train, ds_test)

In [None]:
# Run training and keep the return value
# (presumably the best model's state dict - confirm against CINC2023Trainer.train).
best_state_dict = trainer.train()

## ML model

In [None]:
# Verbosity for the cells below: >= 1 prints progress messages,
# >= 2 additionally enables the tqdm progress bar.
verbose: int = 2
# Folder passed to find_data_folders / load_challenge_metadata.
# NOTE(review): absolute path under /home/wenh06, while the sys.path entries
# near the top of the notebook use /home/wenhao - confirm which is correct.
data_folder: str = "/home/wenh06/Jupyter/wenhao/data/CinC2023/training/"

In [None]:
# Discover the patient folders under `data_folder`.
patient_ids = find_data_folders(data_folder)
num_patients = len(patient_ids)

# Nothing to work with: fail early.
if not num_patients:
    raise FileNotFoundError("No data was provided.")

if verbose >= 1:
    print(f"Found {num_patients} patients.")

In [None]:
if verbose >= 1:
    print("Extracting features and labels from the Challenge data...")

# Per-patient accumulators; stacked into arrays once the loop is done.
features, outcomes, cpcs = [], [], []

progress_bar = tqdm(
    patient_ids,
    desc="Extracting features and labels",
    total=num_patients,
    dynamic_ncols=True,
    mininterval=1.0,
    disable=verbose < 2,
)
for patient_id in progress_bar:
    # Metadata is the only input needed for both features and labels.
    patient_metadata = load_challenge_metadata(data_folder, patient_id)

    features.append(get_features(patient_metadata))

    current_labels = get_labels(patient_metadata)
    outcomes.append(current_labels["outcome"])
    cpcs.append(current_labels["cpc"])

# One row per patient in each stacked array.
features = np.vstack(features)
outcomes = np.vstack(outcomes)
cpcs = np.vstack(cpcs)

In [None]:
# Hyperparameters shared by the random-forest classifier and regressor.
n_estimators = 42  # trees per forest
max_leaf_nodes = 456  # cap on leaf nodes per tree
random_state = 789  # fixed seed for reproducibility
forest_params = dict(
    n_estimators=n_estimators,
    max_leaf_nodes=max_leaf_nodes,
    random_state=random_state,
)

# Fill in missing feature values (SimpleImputer's default strategy is the mean).
imputer = SimpleImputer().fit(features)
features = imputer.transform(features)

# Fit the outcome classifier and the CPC regressor on the imputed features.
outcome_model = RandomForestClassifier(**forest_params).fit(features, outcomes.ravel())
cpc_model = RandomForestRegressor(**forest_params).fit(features, cpcs.ravel())

# Bundle everything needed at inference time.
d = dict(imputer=imputer, outcome_model=outcome_model, cpc_model=cpc_model)

In [None]:
# with open("./tmp/final_model_ml.pkl", "wb") as f:
#     pickle.dump(d, f)