In [1]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import wandb
import pandas as pd
import os
import ast
import warnings
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap
from bachelors_thesis.modeling.siglabv2.siglabv2 import SigLabV2
from bachelors_thesis.modeling.datasets.sigloc_dataset import SigLocDataset
from bachelors_thesis.run import lead_sets
from bachelors_thesis.utils import count_parameters, confusion_matrix, apply_preprocessors, PRECORDIAL_LEAD_NAMES
from bachelors_thesis.data.load_ptbdata_new import PRECORDIAL_LEADS, LIMB_LEADS, AUGMENTED_LEADS, ALL_LEADS

# Ignore warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

[32m2025-05-16 18:56:15.218[0m | [1mINFO    [0m | [36mbachelors_thesis.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: C:\Users\natem\bachelors_thesis[0m


In [2]:
# Weights & Biases coordinates of the run whose checkpoint is analysed below.
project_path = "nateml-maastricht-university/bachelors-thesis"
# Alternative run/version pair, kept for reference:
#run_id = "83mabq4k"
#version = "v22"
run_id = "yh0by5uj"
version = "v30"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the run (for its config and name) and the matching model artifact.
api = wandb.Api()

run = api.run(f"{project_path}/{run_id}")
config = dict(run.config)

# The artifact and the checkpoint file are both named after the run.
run_name = run.name

artifact = api.artifact(f"{project_path}/{run_name}:{version}")
artifact_path = artifact.download()

# Convert config to omegaconf so nested keys can be read as attributes
cfg = OmegaConf.create(config)

# The checkpoint is consumed as a plain state dict below, so restrict
# unpickling to tensors/containers instead of arbitrary Python objects.
checkpoint_file = Path(artifact_path) / f"{run_name}.pth"
checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=True)

# Rebuild the architecture from the run config and restore the weights
model = SigLabV2(cfg.model).to(device)
model.load_state_dict(checkpoint)
model.eval()  # Put into evaluation mode

# Count number of parameters (the return value is shown as the cell output)
print(f"Loaded model from wandb: {run_name}")
count_parameters(model)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loaded model from wandb: 63_siglabv2_inception_gru_2_2
+----------------------------------------+------------+
|                Modules                 | Parameters |
+----------------------------------------+------------+
|             encoder.alpha              |     1      |
| encoder.cnn_encoder.0.branch1.0.weight |     32     |
|  encoder.cnn_encoder.0.branch1.0.bias  |     32     |
| encoder.cnn_encoder.0.branch1.1.weight |     32     |
|  encoder.cnn_encoder.0.branch1.1.bias  |     32     |
| encoder.cnn_encoder.0.branch2.0.weight |     1      |
|  encoder.cnn_encoder.0.branch2.0.bias  |     1      |
| encoder.cnn_encoder.0.branch2.1.weight |     1      |
|  encoder.cnn_encoder.0.branch2.1.bias  |     1      |
| encoder.cnn_encoder.0.branch2.3.weight |    160     |
|  encoder.cnn_encoder.0.branch2.3.bias  |     32     |
| encoder.cnn_encoder.0.branch2.4.weight |     32     |
|  encoder.cnn_encoder.0.branch2.4.bias  |     32     |
| encoder.cnn_encoder.0.branch3.0.weight |     1 

(1365845, 307017)

In [3]:
# Locate the dataset directory: the run config stores a project-relative path,
# and the split files live in a "precordial" or "all" sub-directory.
# NOTE(review): a missing `dataset.only_precordial` key selects the "all"
# directory here, but the PRECORDIAL_LEADS index table further down — confirm
# this asymmetry is intended.
only_precordial = OmegaConf.select(cfg, "dataset.only_precordial")
dataset_path = Path("../../" + cfg.dataset.path)
dataset_path = (dataset_path / ("precordial" if only_precordial else "all")).resolve()

val_data = np.load(dataset_path / "val.npy")
test_data = np.load(dataset_path / "test.npy")

# Run the configured preprocessors over the validation split.
# NOTE(review): test_data is not preprocessed in this cell.
val_data = apply_preprocessors(
    val_data,
    cfg.dataset.sampling_rate,
    cfg.preprocessor_group.preprocessors,
)

# Move to torch (float32, on `device`) and swap the last two axes.
val_data = torch.from_numpy(val_data).float().to(device).permute(0, 2, 1)
print(val_data.shape)

test_data = torch.from_numpy(test_data).float().to(device).permute(0, 2, 1)

# Build the dataset / dataloader restricted to the leads used by the run.
lead_filter = lead_sets[OmegaConf.select(cfg, "run.leads", default="precordial")]
dataset = SigLocDataset(val_data, filter_leads=lead_filter)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Reorder the lead axis of val_data so it follows `lead_filter`, matching the
# order of the labels in the dataset.
if only_precordial or only_precordial is None:
    lead_positions = [PRECORDIAL_LEADS.index(lead) for lead in lead_filter]
else:
    lead_positions = [ALL_LEADS.index(lead) for lead in lead_filter]
val_data = val_data[:, lead_positions, :]

# Per-recording metadata; `meta` is a working copy of the validation table
# with the stringified `scp_codes` column parsed back into Python objects.
meta_val = pd.read_csv(dataset_path / "meta_val.csv")
meta_test = pd.read_csv(dataset_path / "meta_test.csv")
meta = meta_val.copy()
meta['scp_codes'] = meta['scp_codes'].apply(ast.literal_eval)

# Minimum value a code must have to count as "present".
# NOTE(review): confirm this matches the scale of the values stored in
# `scp_codes`.
THRESHOLD = 0.5


def codes_above_threshold(code_dict, thr=THRESHOLD):
    """
    Return the set of keys in ``code_dict`` whose value is at least ``thr``.

    Parameters
    ----------
    code_dict : dict
        Mapping from code to a numeric value.
    thr : float, optional
        Inclusive lower bound; defaults to ``THRESHOLD``.

    Returns
    -------
    set
        The codes whose value satisfies ``value >= thr``.
    """
    selected = set()
    for code, prob in code_dict.items():
        if prob >= thr:
            selected.add(code)
    return selected

# Codes whose value reaches THRESHOLD, stored as a set per recording.
meta["present_codes"] = meta["scp_codes"].apply(codes_above_threshold, thr=THRESHOLD)

# Parse the stringified diagnostic superclass lists and keep them as sets.
meta["diagnostic_superclass"] = meta["diagnostic_superclass"].apply(
    lambda raw: set(ast.literal_eval(raw))
)

# Pre-allocate the output buffers: one (c, c) logit matrix and one length-c
# target vector per dataset item.
c = cfg.model.num_classes
n_items = len(dataset)
logits = np.zeros((n_items, c, c))
init_logits = np.zeros((n_items, c, c))
targets = np.zeros((n_items, c))

for idx, (signals, lead_labels) in enumerate(dataloader):
    signals = signals.to(device)
    lead_labels = lead_labels.to(device)

    # Rows of the output buffers that belong to this batch. For the final,
    # shorter batch the slice is clipped to the end of the arrays.
    batch_rows = slice(idx * batch_size, (idx + 1) * batch_size)

    # Forward pass without building the autograd graph
    with torch.no_grad():
        these_logits = model(signals)

    logits[batch_rows] = these_logits.cpu().numpy()
    targets[batch_rows] = lead_labels.cpu().numpy()

# Collapse axis 1 of the (n_items, c, c) logits into one index per column.
predictions = logits.argmax(axis=1)
print(predictions.shape)

torch.Size([2183, 6, 1000])
(2183, 6)


In [10]:
# Show the predicted index vector for the first validation item as a quick sanity check.
predictions[0]

array([0, 1, 2, 3, 4, 5])

In [24]:
def detect_swap(predictions, lead1, lead2, lead_order=None):
    """
    Flag the samples in which ``lead1`` and ``lead2`` are both predicted at
    their own index, i.e. the pair was NOT confused.

    Despite the function name, the returned mask is True where the pair is
    correct. Invert it (``~mask``) to obtain the samples in which at least
    one of the two leads was assigned a different index.

    Parameters
    ----------
    predictions : array of shape (n_samples, n_leads)
        Predicted index for every lead column of every sample.
    lead1, lead2 : str
        Names of the two leads to check.
    lead_order : sequence of str, optional
        Lead names in the column order of ``predictions``. Defaults to
        ``PRECORDIAL_LEADS``; pass the lead list the predictions were
        produced with if it differs from that order.

    Returns
    -------
    Boolean mask of length n_samples, True where both leads are predicted
    at their own position.

    Raises
    ------
    ValueError
        If ``lead1`` or ``lead2`` is not contained in ``lead_order``.
    """
    if lead_order is None:
        lead_order = PRECORDIAL_LEADS
    i1 = lead_order.index(lead1)
    i2 = lead_order.index(lead2)
    # A column is correct when the predicted index equals the column's own index.
    both_correct = (predictions[:, i1] == i1) & (predictions[:, i2] == i2)
    return both_correct

# Fraction of validation items for which V1 and V6 are both predicted at
# their own index (the mask returned by detect_swap is True for those items).
swaps = detect_swap(predictions, "V1", "V6")
swap_accuracy = swaps.mean()
print(f"Swap accuracy: {swap_accuracy:.4f}")

Swap accuracy: 0.9789


In [21]:
# Indices of the validation items where the detect_swap mask is False.
# NOTE(review): the saved output of this cell carries a lower execution count
# than the cell that defines `swaps`, and the number of indices listed does
# not match the accuracy printed above — re-run after changing the lead pair.
np.where(~swaps)

(array([  23,   34,   35,   36,   52,   78,   83,   92,  122,  185,  198,
         215,  239,  304,  387,  463,  468,  471,  483,  486,  524,  538,
         578,  580,  581,  632,  683,  686,  689,  717,  727,  749,  837,
         848,  859,  872,  881,  883,  905,  936,  948,  955,  956,  961,
        1004, 1056, 1058, 1074, 1077, 1198, 1206, 1215, 1218, 1220, 1240,
        1250, 1259, 1290, 1293, 1296, 1320, 1379, 1395, 1407, 1468, 1486,
        1510, 1546, 1550, 1595, 1614, 1632, 1647, 1753, 1754, 1789, 1862,
        1911, 2004, 2041, 2177]),)