## Calculate confusion matrices and performance metrics using pretrained weights and domain adaptation (without ensemble classification).

In [1]:
import import_ipynb
from ecg_utilities import *

import torch.nn.functional as Func

from pytorch_sklearn import NeuralNetwork
from pytorch_sklearn.callbacks import WeightCheckpoint, Verbose, LossPlot, EarlyStopping, Callback, CallbackInfo
from pytorch_sklearn.utils.func_utils import to_safe_tensor

importing Jupyter notebook from ecg_utilities.ipynb


In [2]:
# Both ID lists are headerless CSV files under ../files; each is flattened to a 1-D array.
patient_ids, valid_patients = (
    pd.read_csv(osj("..", "files", csv_name), header=None).to_numpy().reshape(-1)
    for csv_name in ("patient_ids.csv", "valid_patients.csv")
)

In [3]:
# Domain-adapted dataset folders ("default" and "trio" variants) read by load_N_channel_dataset.
DATASET_PATH, TRIO_PATH = (
    osj("..", "dataset_training", folder)
    for folder in ("dataset_domain_adapted", "dataset_beat_trios_domain_adapted")
)

In [4]:
# Pretrained networks are read from LOAD_PATH; optional result pickles go to SAVE_PATH.
LOAD_PATH, SAVE_PATH = osj("..", "pretrained", "nets"), osj("..", "savefolder")

In [5]:
# Hyper-parameter values recorded in the run config (first element of each list).
# NOTE(review): -1 presumably means "no fixed epoch limit" during training — confirm.
max_epochs, batch_sizes = [-1], [1024]

## What is collected?
- ### Per repeat:
    - #### Confusion matrix for each patient (34 in total).
    - #### Cumulative confusion matrix summed over all patients (1 in total).

In [6]:
def _find_weight_checkpoint(net, tracked="val_loss"):
    """Return the ``WeightCheckpoint`` callback of ``net`` whose ``tracked`` equals ``tracked``.

    Replaces the former hard-coded ``net.cbmanager.callbacks[1]``, which silently
    picked the wrong callback whenever the checkpoint was not registered second.

    If no checkpoint reports ``tracked``, the first ``WeightCheckpoint`` found is
    returned. NOTE(review): assumes ``tracked`` holds the monitored metric name
    (e.g. "val_loss") — confirm against how the networks were trained.

    Raises:
        ValueError: if ``net`` has no ``WeightCheckpoint`` callback at all.
    """
    checkpoints = [cb for cb in net.cbmanager.callbacks if isinstance(cb, WeightCheckpoint)]
    if not checkpoints:
        raise ValueError("Loaded network has no WeightCheckpoint callback to restore weights from.")
    for cb in checkpoints:
        if getattr(cb, "tracked", None) == tracked:
            return cb
    return checkpoints[0]


all_patient_cms = []  # per repeat: {patient_id: confusion matrix}
all_cms = []          # per repeat: confusion matrix summed over all patients
repeats = 10

for repeat in range(repeats):
    patient_cms = {}
    cm = torch.zeros(2, 2)

    for i, patient_id in enumerate(valid_patients):
        dataset = load_N_channel_dataset(patient_id, DATASET_PATH, TRIO_PATH)
        # Unpacking relies on the insertion order of the mapping returned by load_N_channel_dataset.
        train_X, train_y, train_ids, val_X, val_y, val_ids, test_X, test_y, test_ids = dataset.values()

        # Rebuild model/optimizer/loss so the saved network can be loaded into them.
        model = get_base_model(in_channels=train_X.shape[1])
        crit = nn.CrossEntropyLoss()
        optim = torch.optim.AdamW(params=model.parameters())

        net = NeuralNetwork.load_class(osj(LOAD_PATH, f"net_{repeat+1}_{patient_id}"), model, optim, crit)

        # Look the checkpoint up by type/metric instead of by callback position.
        weight_checkpoint_val_loss = _find_weight_checkpoint(net, "val_loss")
        net.load_weights(weight_checkpoint_val_loss)

        # Predicted class = argmax over dim 1 (assumes per-class scores on dim 1 — confirm).
        pred_y = net.predict(
            test_X,
            batch_size=batch_sizes[0],
            use_cuda=False,
            fits_gpu=False,
            decision_func=lambda scores: scores.argmax(dim=1),
        ).cpu()

        cur_cm = get_confusion_matrix(pred_y, test_y, pos_is_zero=False)
        patient_cms[patient_id] = cur_cm
        cm += cur_cm

        print_progress(i + 1, len(valid_patients), opt=[f"{patient_id}"])

    all_patient_cms.append(patient_cms)
    all_cms.append(cm)



In [7]:
# Run metadata; optionally pickled next to the confusion matrices in the save cell.
config = {
    "learning_rate": 0.001,
    "max_epochs": max_epochs[0],
    "batch_size": batch_sizes[0],
    "optimizer": type(optim).__name__,
    "loss": type(crit).__name__,
    "early_stopping": "true",
    "checkpoint_on": weight_checkpoint_val_loss.tracked,
    "dataset": "default+trio",
    "info": "Results replicated for GitHub, just DA.",
}

In [8]:
# Stack the per-repeat 2x2 matrices into one integer array of shape (repeats, 2, 2).
all_cms = np.stack(all_cms, axis=0).astype(np.int_)

In [10]:
# Metrics on the confusion matrix pooled over every repeat (and, within each, every patient).
get_performance_metrics(np.sum(all_cms, axis=0))

{'acc': 0.977506173222966,
 'rec': 0.9068441542132362,
 'spe': 0.9874778478769405,
 'pre': 0.910870607230253,
 'npv': 0.9868622394763314,
 'f1': 0.9088529211871019}

In [12]:
# Set to True to write the collected results to SAVE_PATH. Off by default
# (presumably so re-running the notebook does not overwrite earlier results).
SAVE_RESULTS = False

if SAVE_RESULTS:
    import os

    # open(..., "wb") raises FileNotFoundError if the target folder does not exist.
    os.makedirs(SAVE_PATH, exist_ok=True)
    for filename, obj in (
        ("cms.pkl", all_cms),
        ("config.pkl", config),
        ("patient_cms.pkl", all_patient_cms),
    ):
        with open(osj(SAVE_PATH, filename), "wb") as f:
            pickle.dump(obj, f)