In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# --- Load ECG leads and labels ---
# Single place to change when the data directory moves.
DATA_DIR = "../Data_processing/data_aspire_PAP"

# One `<lead>.pt` file per lead; the order here is the iteration order of
# `lead_file_paths` / `ecg_lead_tensors`.
LEAD_NAMES = [
    "LEAD_I", "LEAD_II", "LEAD_III",
    "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
    "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
]
lead_file_paths = {lead: f"{DATA_DIR}/{lead}.pt" for lead in LEAD_NAMES}
labels_file_path = f"{DATA_DIR}/labels.pt"

# map_location="cpu": tensors saved from a GPU session still load on a CPU-only
# machine, and `labels.numpy()` further down requires a CPU tensor. Batches are
# moved to the training device later, inside the train/eval loops.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu")
    for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Every lead must hold exactly one recording per label.
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, (
    f"labels has {len(labels)} entries but the first lead has {sample_count}"
)
for lead, tensor in ecg_lead_tensors.items():
    assert len(tensor) == sample_count, (
        f"{lead} has {len(tensor)} recordings, expected {sample_count}"
    )

# --- Dataset Class ---
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead recordings with one label per sample.

    Args:
        ecg_leads: Mapping from lead name to an indexable collection of
            recordings (one entry per sample).
        labels: Indexable collection of labels; its length defines the
            dataset length.
    """

    def __init__(self, ecg_leads, labels):
        self.ecg_leads = ecg_leads
        self.labels = labels

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(lead_data, label)`` for sample ``idx``.

        ``lead_data`` maps every lead name to that sample's recording with a
        new leading dimension inserted via ``unsqueeze(0)``.
        """
        sample = {}
        for lead_name, recordings in self.ecg_leads.items():
            sample[lead_name] = recordings[idx].unsqueeze(0)
        return sample, self.labels[idx]

# --- Create Balanced & Imbalanced Datasets ---

# Dedicated seeded generator: the sampled subsets are identical on every
# Restart & Run All and the global NumPy RNG state is left untouched.
SUBSAMPLE_SEED = 42
rng = np.random.RandomState(SUBSAMPLE_SEED)

labels_np = labels.numpy()
class0_indices = np.where(labels_np == 0)[0]
class1_indices = np.where(labels_np == 1)[0]

# Balanced dataset: the same number of samples from each class, limited by the
# smaller class.
min_class_count = min(len(class0_indices), len(class1_indices))
balanced_class0 = rng.choice(class0_indices, min_class_count, replace=False)
balanced_class1 = rng.choice(class1_indices, min_class_count, replace=False)
balanced_indices = np.concatenate([balanced_class0, balanced_class1])
rng.shuffle(balanced_indices)

balanced_leads = {lead: tensor[balanced_indices] for lead, tensor in ecg_lead_tensors.items()}
balanced_labels = labels[balanced_indices]
balanced_dataset = ECGMultiLeadDatasetWithLabels(balanced_leads, balanced_labels)

# Imbalanced dataset: every class-0 sample plus a random draw of class-1 samples
# numbering half the class-0 count (a 2:1 class-0 : class-1 ratio).
# Sampling is without replacement, so the request is capped at the number of
# class-1 samples that actually exist instead of letting `choice` raise.
requested_class1_count = len(class0_indices) // 2
imbal_class1_count = min(requested_class1_count, len(class1_indices))
if imbal_class1_count < requested_class1_count:
    print(
        f"Only {len(class1_indices)} class-1 samples available; "
        f"requested {requested_class1_count} for the imbalanced dataset."
    )
imbal_class1 = rng.choice(class1_indices, imbal_class1_count, replace=False)
imbal_indices = np.concatenate([class0_indices, imbal_class1])
rng.shuffle(imbal_indices)

imbalanced_leads = {lead: tensor[imbal_indices] for lead, tensor in ecg_lead_tensors.items()}
imbalanced_labels = labels[imbal_indices]
imbalanced_dataset = ECGMultiLeadDatasetWithLabels(imbalanced_leads, imbalanced_labels)

# --- Print Stats ---
def count_labels(labels_tensor):
    """Count how often each distinct value occurs in a label tensor.

    Args:
        labels_tensor: Tensor of labels. It is detached and copied to the CPU
            before counting, so GPU or grad-tracking tensors are accepted.

    Returns:
        dict mapping each distinct label to its number of occurrences. Keys
        and values are plain Python scalars, so the dict prints as
        ``{0: 10, 1: 5}`` regardless of the installed NumPy version.
    """
    values, counts = np.unique(labels_tensor.detach().cpu().numpy(), return_counts=True)
    return {value.item(): count.item() for value, count in zip(values, counts)}

# Report the size and the per-class counts of both datasets, one line each.
for report_label, report_value in (
    ("Balanced dataset size:", len(balanced_dataset)),
    ("Balanced label distribution:", count_labels(balanced_labels)),
    ("Imbalanced dataset size:", len(imbalanced_dataset)),
    ("Imbalanced label distribution:", count_labels(imbalanced_labels)),
):
    print(report_label, report_value)


In [86]:
import torch.nn as nn
class ECGLeadClassifier(nn.Module):
    """Classification head on top of frozen, pretrained per-lead encoders.

    Each selected lead is passed through its own encoder; the first value each
    encoder returns (``mu``) is used as that lead's feature vector. The
    per-lead features are concatenated and fed to a small MLP that produces
    class logits.

    Args:
        pretrained_mopoe: Object exposing ``encoders`` (indexable, one encoder
            per lead) and ``latent_dim`` (feature size produced per lead).
        num_classes: Number of output logits.
        use_12_leads: If True use the first 12 encoders / all 12 leads,
            otherwise only the first 6 encoders / the six limb leads.
    """

    def __init__(self, pretrained_mopoe, num_classes, use_12_leads=True):
        super().__init__()

        # Define lead names based on the mode.
        # NOTE(review): encoder i is paired with lead_names[i]. aVF precedes aVL
        # here, while the data-loading cell lists aVL before aVF - confirm this
        # order matches the one the encoders were pretrained with.
        self.use_12_leads = use_12_leads
        self.lead_names = (
            ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL",
             "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
            if use_12_leads
            else ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
        )

        # Select one encoder per lead, in lead_names order.
        encoder_indices = range(len(self.lead_names))
        self.lead_encoders = nn.ModuleList([pretrained_mopoe.encoders[i] for i in encoder_indices])

        self.feature_dim = pretrained_mopoe.latent_dim

        # Freeze encoder weights: only the classifier head is trainable.
        for encoder in self.lead_encoders:
            for param in encoder.parameters():
                param.requires_grad = False
        # Frozen feature extractors must also run in eval mode (see `train`).
        self.lead_encoders.eval()

        # Define classifier network
        self.classifier = nn.Sequential(
            nn.Linear(len(self.lead_names) * self.feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def train(self, mode=True):
        """Switch the classifier head's mode while keeping encoders in eval mode.

        ``requires_grad = False`` only stops weight updates. Without this
        override, ``model.train()`` would also put the frozen encoders into
        training mode, so any Dropout or BatchNorm layers inside them would
        keep perturbing the supposedly fixed features.
        """
        super().train(mode)
        self.lead_encoders.eval()
        return self

    def forward(self, lead_data):
        """Compute class logits for a batch.

        Args:
            lead_data: Mapping from lead name to that lead's input batch; must
                contain every name in ``self.lead_names`` (extra keys are
                ignored).

        Returns:
            Tensor of logits with ``num_classes`` columns.

        Raises:
            KeyError: If a required lead is absent from ``lead_data``.
        """
        missing = [name for name in self.lead_names if name not in lead_data]
        if missing:
            raise KeyError(f"lead_data is missing required leads: {missing}")

        lead_features = []
        for lead_name, encoder in zip(self.lead_names, self.lead_encoders):
            # Encoders return a pair; only the first element is used as features.
            mu, _ = encoder(lead_data[lead_name])
            lead_features.append(mu)

        combined_features = torch.cat(lead_features, dim=1)
        logits = self.classifier(combined_features)
        return logits

Balanced dataset

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn

# **Set device configuration**
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)
print(f"Using device: {device}")


# **Training Function**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train ``model`` for a fixed number of epochs, printing the mean loss per epoch.

    Args:
        model: Module called with a dict mapping lead name to an input batch.
        train_loader: Iterable yielding ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over ``train_loader``.
        use_12_leads: If True feed all 12 leads to the model, otherwise only
            the six limb leads.

    Returns:
        list of float: sample-weighted mean training loss of each epoch.
    """
    model.train()

    # Select lead names based on `use_12_leads`
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    if use_12_leads:
        lead_names += ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable that may be stale or undefined.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        samples_seen = 0
        for lead_data, labels in train_loader:
            labels = labels.to(model_device)

            # Filter lead_data to use only selected leads
            lead_data = {lead: lead_data[lead].to(model_device) for lead in lead_names}

            optimizer.zero_grad()
            outputs = model(lead_data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Average over the samples actually processed: stays correct when the
        # loader drops the last partial batch, and avoids dividing by zero
        # when the loader is empty.
        avg_loss = total_loss / samples_seen if samples_seen else float("nan")
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')
        epoch_losses.append(avg_loss)

    return epoch_losses


# **Evaluation Function**
def evaluate_model(model, data_loader, use_12_leads=True):
    """Evaluate a two-class model and return accuracy, ROC AUC, F1 and MCC.

    The model is switched to eval mode and left in that mode. Column 1 of the
    softmax output is used as the positive-class score for the AUC, so the
    model is expected to emit at least two logits per sample.

    Args:
        model: Module called with a dict mapping lead name to an input batch.
        data_loader: Iterable yielding ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name.
        use_12_leads: If True feed all 12 leads to the model, otherwise only
            the six limb leads.

    Returns:
        tuple: ``(accuracy, auc_score, f1, mcc)`` as computed by the
        scikit-learn metric functions over all batches.
    """
    model.eval()
    all_labels = []
    all_probs = []
    all_preds = []

    # Select lead names based on `use_12_leads`
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    if use_12_leads:
        lead_names += ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable that may be stale or undefined.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for lead_data, labels in data_loader:
            # Filter lead_data to use only selected leads
            lead_data = {lead: lead_data[lead].to(model_device) for lead in lead_names}

            logits = model(lead_data)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Labels are only needed on the host for the metrics, so they are
            # never transferred to the model's device.
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)

    return accuracy, auc_score, f1, mcc


# **Cross-Validation Training**
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# This experiment uses only the six limb leads; one switch keeps the model,
# the training loop and the evaluation in agreement.
USE_12_LEADS = False

# Define parameters
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# The checkpoint does not depend on the fold, so read it from disk once.
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)

# Fix key mismatch by removing '_orig_mod.' prefix
# (presumably left by saving a torch.compile-wrapped model - confirm).
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(balanced_labels)), balanced_labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Seed per fold so classifier initialisation, dropout and batch shuffling
    # are reproducible across Restart & Run All.
    torch.manual_seed(42 + fold)

    train_subset = Subset(balanced_dataset, train_ids)
    test_subset = Subset(balanced_dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Load pretrained MoPoE and initialize classifier
    # NOTE(review): `prior_expert` and `LSEMVAE` are not defined or imported in
    # the cells shown here; they must already be in scope when this cell runs.

    # Instantiate the prior distribution
    prior_dist = prior_expert(params['latent_dim'])

    # Create the MoPoE model (specific for 12-lead ECG data); a fresh instance
    # per fold so no state carries over between folds.
    pretrained_mopoe = LSEMVAE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )

    # Load the pretrained weights. strict=False tolerates key differences, so
    # report them: silently skipped keys would leave encoders randomly initialised.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(f'WARNING: checkpoint/model key mismatch - '
              f'{len(load_result.missing_keys)} missing, '
              f'{len(load_result.unexpected_keys)} unexpected')

    # Move the model to the device (GPU/CPU)
    pretrained_mopoe.to(device)

    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=USE_12_LEADS).to(device)

    # Only parameters that still require gradients (the classifier head) are optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=USE_12_LEADS)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=USE_12_LEADS)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')  # MCC Statistics

Imbalanced dataset

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn

# **Set device configuration**
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)
print(f"Using device: {device}")


# **Training Function**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train ``model`` for a fixed number of epochs, printing the mean loss per epoch.

    NOTE(review): this function is also defined, with the same signature, in
    the balanced-dataset cell; whichever cell ran last wins. Consider keeping
    a single definition.

    Args:
        model: Module called with a dict mapping lead name to an input batch.
        train_loader: Iterable yielding ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over ``train_loader``.
        use_12_leads: If True feed all 12 leads to the model, otherwise only
            the six limb leads.

    Returns:
        list of float: sample-weighted mean training loss of each epoch.
    """
    model.train()

    # Select lead names based on `use_12_leads`
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    if use_12_leads:
        lead_names += ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable that may be stale or undefined.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        samples_seen = 0
        for lead_data, labels in train_loader:
            labels = labels.to(model_device)

            # Filter lead_data to use only selected leads
            lead_data = {lead: lead_data[lead].to(model_device) for lead in lead_names}

            optimizer.zero_grad()
            outputs = model(lead_data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Average over the samples actually processed: stays correct when the
        # loader drops the last partial batch, and avoids dividing by zero
        # when the loader is empty.
        avg_loss = total_loss / samples_seen if samples_seen else float("nan")
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')
        epoch_losses.append(avg_loss)

    return epoch_losses


# **Evaluation Function**
def evaluate_model(model, data_loader, use_12_leads=True):
    """Evaluate a two-class model and return accuracy, ROC AUC, F1 and MCC.

    The model is switched to eval mode and left in that mode. Column 1 of the
    softmax output is used as the positive-class score for the AUC, so the
    model is expected to emit at least two logits per sample.

    NOTE(review): this function is also defined, with the same signature, in
    the balanced-dataset cell; whichever cell ran last wins. Consider keeping
    a single definition.

    Args:
        model: Module called with a dict mapping lead name to an input batch.
        data_loader: Iterable yielding ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name.
        use_12_leads: If True feed all 12 leads to the model, otherwise only
            the six limb leads.

    Returns:
        tuple: ``(accuracy, auc_score, f1, mcc)`` as computed by the
        scikit-learn metric functions over all batches.
    """
    model.eval()
    all_labels = []
    all_probs = []
    all_preds = []

    # Select lead names based on `use_12_leads`
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    if use_12_leads:
        lead_names += ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable that may be stale or undefined.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for lead_data, labels in data_loader:
            # Filter lead_data to use only selected leads
            lead_data = {lead: lead_data[lead].to(model_device) for lead in lead_names}

            logits = model(lead_data)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # Labels are only needed on the host for the metrics, so they are
            # never transferred to the model's device.
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)

    return accuracy, auc_score, f1, mcc


# **Cross-Validation Training**
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# This experiment uses only the six limb leads; one switch keeps the model,
# the training loop and the evaluation in agreement.
USE_12_LEADS = False

# Define parameters
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# The checkpoint does not depend on the fold, so read it from disk once.
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)

# Fix key mismatch by removing '_orig_mod.' prefix
# (presumably left by saving a torch.compile-wrapped model - confirm).
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(imbalanced_labels)), imbalanced_labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Seed per fold so classifier initialisation, dropout and batch shuffling
    # are reproducible across Restart & Run All.
    torch.manual_seed(42 + fold)

    train_subset = Subset(imbalanced_dataset, train_ids)
    test_subset = Subset(imbalanced_dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Load pretrained MoPoE and initialize classifier
    # NOTE(review): `prior_expert` and `LSEMVAE` are not defined or imported in
    # the cells shown here; they must already be in scope when this cell runs.

    # Instantiate the prior distribution
    prior_dist = prior_expert(params['latent_dim'])

    # Create the MoPoE model (specific for 12-lead ECG data); a fresh instance
    # per fold so no state carries over between folds.
    pretrained_mopoe = LSEMVAE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )

    # Load the pretrained weights. strict=False tolerates key differences, so
    # report them: silently skipped keys would leave encoders randomly initialised.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(f'WARNING: checkpoint/model key mismatch - '
              f'{len(load_result.missing_keys)} missing, '
              f'{len(load_result.unexpected_keys)} unexpected')

    # Move the model to the device (GPU/CPU)
    pretrained_mopoe.to(device)

    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=USE_12_LEADS).to(device)

    # Only parameters that still require gradients (the classifier head) are optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=USE_12_LEADS)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=USE_12_LEADS)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')  # MCC Statistics

Visualization

In [None]:
# If you only need files (no GUI window), uncomment the next two lines:
# import matplotlib
# matplotlib.use("Agg")  # set before importing pyplot

import matplotlib.pyplot as plt
import numpy as np

# ------------ Data ------------
metrics = ["Accuracy", "AUROC", "MCC"]

# Named `scenario_labels` rather than `labels` so the label tensor loaded in
# the first cell is not overwritten by a list of strings.
scenario_labels = ["Healthy 75% : PH 25%", "Balanced (50% : 50%)", "Healthy 25% : PH 75%"]

# NOTE(review): these values are hard-coded, presumably copied from earlier
# runs; they are not read from `fold_results`. Update them by hand when the
# experiments are re-run. Each list is ordered like `metrics`.
means = {
    "Healthy 75% : PH 25%": [0.7840, 0.8106, 0.4963],
    "Balanced (50% : 50%)":     [0.7165, 0.7935, 0.4356],
    "Healthy 25% : PH 75%": [0.8139, 0.8141, 0.4104],
}
stds = {
    "Healthy 75% : PH 25%": [0.0183, 0.0696, 0.0455],
    "Balanced (50% : 50%)":     [0.0643, 0.0527, 0.1272],
    "Healthy 25% : PH 75%": [0.0199, 0.0249, 0.0623],
}

# Every scenario needs exactly one mean and one std per metric.
for scenario in scenario_labels:
    assert len(means[scenario]) == len(metrics), f"means mismatch for {scenario}"
    assert len(stds[scenario]) == len(metrics), f"stds mismatch for {scenario}"

# Soft pastel palette, one colour per scenario.
palette = {
    "Healthy 75% : PH 25%": "#EFD9B4",  # soft sand
    "Balanced (50% : 50%)": "#D3BCCC",  # muted orchid
    "Healthy 25% : PH 75%": "#C9E4F3",  # pale ice blue
}


# ------------ Plot ------------
plt.ioff()  # non-interactive to save memory
fig, ax = plt.subplots(figsize=(7.5, 4), dpi=600)

x = np.arange(len(metrics))
bar_width = 0.20
capsize = 3
# One bar per scenario, placed left / centre / right of each metric tick.
offsets = [-bar_width, 0.0, bar_width]

for offset, scenario in zip(offsets, scenario_labels):
    ax.bar(
        x + offset,
        means[scenario],
        yerr=stds[scenario],
        width=bar_width,
        capsize=capsize,
        label=scenario,
        color=palette[scenario],
    )

# Axes formatting
ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=16, fontweight="bold")
ax.tick_params(axis='y', labelsize=18)
ax.set_title("PH Detection", fontweight="bold", fontsize=16)
ax.set_ylim(0.0, 1.0)

# Legend
ax.legend(frameon=False, ncol=1, loc="upper right", prop={'weight': 'bold', 'size': 11})

fig.tight_layout()

# Save high-quality copies before showing: with some GUI backends the figure
# is torn down once the window opened by plt.show() is closed.
fig.savefig("class_distribution_grouped_bars.png", bbox_inches="tight", dpi=300)
fig.savefig("class_distribution_grouped_bars.pdf", bbox_inches="tight")

plt.show()

plt.close(fig)
