In [None]:
import torch
import numpy as np
import wandb
import pandas as pd
import ast
import warnings
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from pathlib import Path
from src.modeling.siglabv2.siglabv2 import SigLabV2
from src.modeling.datasets.siglab_dataset import SigLabDataset
from src.run import lead_sets
from src.utils import count_parameters, confusion_matrix, apply_preprocessors
from src.data.load_ptbdata_new import PRECORDIAL_LEADS, ALL_LEADS
from src.evaluation import lead_level_accuracy, set_level_accuracy

# Ignore warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
project_path = "nateml-maastricht-university/bachelors-thesis"
#run_id = "kuq34vvz"
#version = "v24"
run_id = "yh0by5uj"  # replace with wandb run id of interest
version = "v30"  # model checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the torch model from wandb
api = wandb.Api()

run = api.run(f"{project_path}/{run_id}")
config = dict(run.config)

# Get the run name
run_name = run.name

artifact = api.artifact(f"{project_path}/{run_name}:{version}")
artifact_path = artifact.download()

# Convert config to omegaconf
cfg = OmegaConf.create(config)

# Load the model
checkpoint = torch.load(artifact_path + f"/{run_name}.pth", map_location=device)

# Cast the model to the correct type
model = SigLabV2(cfg.model).to(device)
model.load_state_dict(checkpoint)
model.eval()  # Put into evaluation mode

# Count number of parameters
print(f"Loaded model from wandb: {run_name}")
count_parameters(model)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loaded model from wandb: 63_siglabv2_inception_gru_2_2
+----------------------------------------+------------+
|                Modules                 | Parameters |
+----------------------------------------+------------+
|             encoder.alpha              |     1      |
| encoder.cnn_encoder.0.branch1.0.weight |     32     |
|  encoder.cnn_encoder.0.branch1.0.bias  |     32     |
| encoder.cnn_encoder.0.branch1.1.weight |     32     |
|  encoder.cnn_encoder.0.branch1.1.bias  |     32     |
| encoder.cnn_encoder.0.branch2.0.weight |     1      |
|  encoder.cnn_encoder.0.branch2.0.bias  |     1      |
| encoder.cnn_encoder.0.branch2.1.weight |     1      |
|  encoder.cnn_encoder.0.branch2.1.bias  |     1      |
| encoder.cnn_encoder.0.branch2.3.weight |    160     |
|  encoder.cnn_encoder.0.branch2.3.bias  |     32     |
| encoder.cnn_encoder.0.branch2.4.weight |     32     |
|  encoder.cnn_encoder.0.branch2.4.bias  |     32     |
| encoder.cnn_encoder.0.branch3.0.weight |     1 

(1365845, 307017)

In [34]:
dataset_path = Path("../../" + cfg.dataset.path)
if OmegaConf.select(cfg, "dataset.only_precordial"):
    dataset_path = dataset_path / "precordial"
else:
    dataset_path = dataset_path / "all"
dataset_path = dataset_path.resolve()

val_data = np.load(dataset_path / "val.npy")
test_data = np.load(dataset_path / "test.npy")

# Apply preprocessors
val_data = apply_preprocessors(val_data,
                               cfg.dataset.sampling_rate,
                               cfg.preprocessor_group.preprocessors)
test_data = apply_preprocessors(test_data,
                                 cfg.dataset.sampling_rate,
                                 cfg.preprocessor_group.preprocessors)
 
# Convert to torch tensor
val_data = torch.from_numpy(val_data).float().to(device)
val_data = val_data.permute(0, 2, 1)
print(val_data.shape)

test_data = torch.from_numpy(test_data).float().to(device)
test_data = test_data.permute(0, 2, 1)

# Create the dataset and dataloader
lead_filter = lead_sets[OmegaConf.select(cfg, "run.leads", default="precordial")]
dataset = SigLabDataset(test_data, filter_leads=lead_filter)
batch_size = 32
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False)

# I need to reorder val_data to match the order of the labels in the dataset
if OmegaConf.select(cfg, "dataset.only_precordial") or OmegaConf.select(cfg, "dataset.only_precordial") is None:
    val_data = val_data[:, [PRECORDIAL_LEADS.index(lead) for lead in lead_filter], :]
    test_data = test_data[:, [PRECORDIAL_LEADS.index(lead) for lead in lead_filter], :]
else:
    val_data = val_data[:, [ALL_LEADS.index(lead) for lead in lead_filter], :]
    test_data = test_data[:, [ALL_LEADS.index(lead) for lead in lead_filter], :]

# Load metadata
meta_val = pd.read_csv(dataset_path / "meta_val.csv")
meta_test = pd.read_csv(dataset_path / "meta_test.csv")
meta = meta_test.copy()
meta['scp_codes'] = meta['scp_codes'].apply(lambda x: ast.literal_eval(x))

THRESHOLD = 0.5
def codes_above_threshold(code_dict, thr=THRESHOLD):
    return {code for code, prob in code_dict.items() if prob >= thr}

meta["present_codes"] = meta["scp_codes"].apply(lambda x: codes_above_threshold(x, THRESHOLD))

# Treat diagnostic superclass as lists
meta["diagnostic_superclass"] = meta["diagnostic_superclass"].apply(lambda x: ast.literal_eval(x))
# Now convert to a set
meta["diagnostic_superclass"] = meta["diagnostic_superclass"].apply(lambda x: set(x))

c = cfg.model.num_classes
logits = np.zeros((len(dataset), c, c))
init_logits = np.zeros((len(dataset), c, c))
targets = np.zeros((len(dataset), c))

for idx, (signals, lead_labels) in enumerate(dataloader):
    signals = signals.to(device)
    lead_labels = lead_labels.to(device)

    # Make predictions
    with torch.no_grad():
        these_logits = model(signals)
        logits[(idx * batch_size):(idx * batch_size + batch_size)] = these_logits.cpu().numpy()
        targets[(idx * batch_size):(idx * batch_size + batch_size)] = lead_labels.cpu().numpy()

predictions = logits.argmax(axis=-1)
print(predictions.shape)

torch.Size([2183, 6, 1000])
(2198, 6)


In [35]:
# Calculate confusion matrices for each PTB-XL diagnostic class
# Get indices where 'NORM' is in the 'diagnostic_superclass' column which contains lists
norm_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'NORM' in x))[0]
cm_norm = confusion_matrix(predictions[norm_idx], targets[norm_idx])

mi_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'MI' in x))[0]
cm_mi = confusion_matrix(predictions[mi_idx], targets[mi_idx])

cd_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'CD' in x))[0]
cm_cd = confusion_matrix(predictions[cd_idx], targets[cd_idx])

mi_cd_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'MI' in x or 'CD' in x))[0]
cm_mi_cd = confusion_matrix(predictions[mi_cd_idx], targets[mi_cd_idx])

sttc_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'STTC' in x))[0]
cm_sttc = confusion_matrix(predictions[sttc_idx], targets[sttc_idx])

hyp_idx = np.where(meta['diagnostic_superclass'].apply(lambda x: 'HYP' in x))[0]
cm_hyp = confusion_matrix(predictions[hyp_idx], targets[hyp_idx])

In [37]:
lead_acc = lead_level_accuracy(predictions=predictions, targets=targets)
set_acc = set_level_accuracy(predictions=predictions, targets=targets)
print(f"Lead-level accuracy: {lead_acc:.4f}")
print(f"Set-level accuracy: {set_acc:.4f}")

norm_acc = cm_norm.diagonal().sum() / cm_norm.sum()
print(f"Normal class accuracy: {norm_acc:.4f}")

mi_acc = cm_mi.diagonal().sum() / cm_mi.sum()
print(f"MI class accuracy: {mi_acc:.4f}")

mi_cd_acc = cm_mi_cd.diagonal().sum() / cm_mi_cd.sum()
print(f"MI+CD class accuracy: {mi_cd_acc:.4f}")

Lead-level accuracy: 0.9810
Set-level accuracy: 0.9472
Normal class accuracy: 0.9922
MI class accuracy: 0.9718
MI+CD class accuracy: 0.9723


In [38]:
# Wilson score interval for lead-level accuracy
def wilson_score_interval(successes, trials, confidence=0.95):
    # Convert confidence level to z-score
    z = abs(np.percentile(np.random.normal(size=1000000), (1 - confidence) * 100 / 2))
    p = successes / trials
    denominator = 1 + z**2 / trials
    centre_adjusted_probability = p + z**2 / (2 * trials)
    adjusted_standard_deviation = np.sqrt((p * (1 - p) + z**2 / (4 * trials)) / trials)
    lower_bound = (centre_adjusted_probability - z * adjusted_standard_deviation) / denominator
    upper_bound = (centre_adjusted_probability + z * adjusted_standard_deviation) / denominator
    return lower_bound, upper_bound, z

lead_successes = np.sum(predictions.flatten() == targets.flatten(), axis=0)
lead_trials = len(predictions.flatten())

lower_bound, upper_bound, z = wilson_score_interval(lead_successes, lead_trials)
print(f"Wilson score interval for lead-level accuracy: [{lower_bound:.4f}, {upper_bound:.4f}], z-score: {z:.4f}")


set_successes = np.sum(np.all(predictions == targets, axis=1))
set_trials = len(predictions)

set_lower_bound, set_upper_bound, z = wilson_score_interval(set_successes, set_trials)
print(f"Wilson score interval for set-level accuracy: [{set_lower_bound:.4f}, {set_upper_bound:.4f}], z-score: {z:.4f}")

normal_successes = np.sum(predictions[norm_idx].flatten() == targets[norm_idx].flatten())
normal_trials = len(predictions[norm_idx].flatten())

normal_lower_bound, normal_upper_bound, z = wilson_score_interval(normal_successes, normal_trials)
print(f"Wilson score interval for normal class: [{normal_lower_bound:.4f}, {normal_upper_bound:.4f}], z-score: {z:.4f}")

mi_successes = np.sum(predictions[mi_idx].flatten() == targets[mi_idx].flatten())
mi_trials = len(predictions[mi_idx].flatten())

mi_lower_bound, mi_upper_bound, z = wilson_score_interval(mi_successes, mi_trials)
print(f"Wilson score interval for MI class: [{mi_lower_bound:.4f}, {mi_upper_bound:.4f}], z-score: {z:.4f}")

mi_cd_successes = np.sum(predictions[mi_cd_idx].flatten() == targets[mi_cd_idx].flatten())
mi_cd_trials = len(predictions[mi_cd_idx].flatten())
mi_cd_lower_bound, mi_cd_upper_bound, z = wilson_score_interval(mi_cd_successes, mi_cd_trials)
print(f"Wilson score interval for MI+CD class: [{mi_cd_lower_bound:.4f}, {mi_cd_upper_bound:.4f}], z-score: {z:.4f}")

Wilson score interval for lead-level accuracy: [0.9785, 0.9832], z-score: 1.9594
Wilson score interval for set-level accuracy: [0.9370, 0.9558], z-score: 1.9650
Wilson score interval for normal class: [0.9896, 0.9942], z-score: 1.9630
Wilson score interval for MI class: [0.9656, 0.9769], z-score: 1.9595
Wilson score interval for MI+CD class: [0.9674, 0.9764], z-score: 1.9620
