Task-1

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# File locations of the 12 per-lead ECG tensors and of the label tensor,
# all inside the data_aspire_PAP_1 directory.
# Key order: I, II, III, aVR, aVL, aVF, V1 ... V6.
lead_file_paths = {
    f"LEAD_{lead_suffix}": f"../Data_processing/data_aspire_PAP_1/LEAD_{lead_suffix}.pt"
    for lead_suffix in (
        "I", "II", "III", "aVR", "aVL", "aVF",
        "V1", "V2", "V3", "V4", "V5", "V6",
    )
}
labels_file_path = "../Data_processing/data_aspire_PAP_1/labels.pt"


# Read one tensor file per lead, then the label tensor.
ecg_lead_tensors = dict(
    (lead, torch.load(path)) for lead, path in lead_file_paths.items()
)
labels = torch.load(labels_file_path)

# Every lead must hold exactly as many samples as there are labels;
# the first lead in the mapping supplies the reference count.
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert sample_count == len(labels), "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert sample_count == len(tensor), "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG tensors with one label per sample.

    Parameters
    ----------
    ecg_leads : dict
        Maps a lead name to an indexable tensor whose first axis enumerates
        the samples.
    labels : indexable
        Label of each sample, indexed the same way as the lead tensors.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The first lead in the mapping determines the dataset size.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return ``(lead_data, label)`` for sample ``idx``.

        ``lead_data`` maps every lead name to that lead's sample with one
        extra leading axis of size 1 added by ``unsqueeze(0)``.
        """
        sample = {}
        for lead_name, lead_tensor in self.ecg_leads.items():
            sample[lead_name] = lead_tensor[idx].unsqueeze(0)
        return sample, self.labels[idx]

# Initialize the dataset and dataloader.
# Both `DataLoader(...)` here and the cross-validation cell further down read
# the name `dataset`, so it has to be bound in this cell; otherwise a fresh
# kernel raises NameError (or silently reuses a dataset from another task).
dataset = ECGMultiLeadDatasetWithLabels(ecg_lead_tensors, labels)
dataset_2 = dataset  # alias kept so code that still refers to `dataset_2` keeps working
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)


In [7]:
# **Classifier Model Class**
class ECGLeadClassifier(nn.Module):
    """Classification head on top of frozen per-lead encoders.

    Each selected lead is passed through its own encoder taken from
    ``pretrained_mopoe.encoders``. The first of the two values every encoder
    returns is concatenated across leads (along dim 1) and fed to a small MLP
    that produces the class logits.

    Args:
        pretrained_mopoe: object exposing ``encoders`` (indexable, one module
            per lead) and ``latent_dim``.
            NOTE(review): ``latent_dim`` is assumed to equal the size of each
            encoder's first output along dim 1 — confirm.
        num_classes: number of logits produced by the head.
        use_12_leads: when True all 12 leads are used, otherwise only the
            first six entries of the lead list.
    """

    def __init__(self, pretrained_mopoe, num_classes, use_12_leads=True):
        super().__init__()

        self.use_12_leads = use_12_leads

        # NOTE(review): position i of `lead_names` is paired with
        # `pretrained_mopoe.encoders[i]`, and aVF is listed before aVL here —
        # confirm this matches the lead order the encoders were trained with.
        first_six = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
        last_six = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
        self.lead_names = first_six + last_six if use_12_leads else first_six

        # One encoder per selected lead, picked by position.
        encoder_count = 12 if use_12_leads else 6
        self.lead_encoders = nn.ModuleList(
            [pretrained_mopoe.encoders[position] for position in range(encoder_count)]
        )

        self.feature_dim = pretrained_mopoe.latent_dim

        # The encoders are feature extractors only: exclude them from training.
        for frozen_param in self.lead_encoders.parameters():
            frozen_param.requires_grad = False

        # MLP head over the concatenated per-lead features.
        hidden_units = 128
        in_features = len(self.lead_names) * self.feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, lead_data):
        """Return logits for ``lead_data``, a dict mapping lead name to input tensor."""

        def encode(lead_name, encoder):
            # Encoders return a pair; only the first element is used.
            mu, _unused = encoder(lead_data[lead_name])
            return mu

        per_lead_features = [
            encode(lead_name, encoder)
            for lead_name, encoder in zip(self.lead_names, self.lead_encoders)
        ]
        return self.classifier(torch.cat(per_lead_features, dim=1))

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn
from captum.attr import IntegratedGradients
import pandas as pd
import neurokit2 as nk

import warnings

# Silence every warning for the rest of the session.
# NOTE(review): this also hides warnings raised by the libraries used below —
# consider narrowing the filter.
warnings.filterwarnings(action="ignore")


# -------------------- Setup --------------------
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def set_seed(seed_value):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA).

    Also sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed(123)

# -------------------- Training Function --------------------
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    After every epoch the sample-weighted mean loss is printed.

    Args:
        model: module called with a dict ``{lead_name: tensor}``; its output is
            passed to ``criterion`` together with the labels.
        train_loader: iterable of ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name. Its ``dataset``
            attribute supplies the denominator of the mean loss.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer; zeroed and stepped once per batch.
        epochs: number of passes over the loader.
        use_12_leads: True moves all 12 leads to the device, False only the
            first six. NOTE(review): must agree with the lead set ``model``
            was built with.

    Relies on the module-level ``device``.
    """
    # NOTE(review): train() puts every sub-module in training mode, including
    # sub-modules whose parameters are frozen — confirm that is intended.
    model.train()

    first_six = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    last_six = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = first_six + last_six if use_12_leads else first_six

    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for batch_leads, batch_labels in train_loader:
            batch_labels = batch_labels.to(device)
            inputs = {name: batch_leads[name].to(device) for name in lead_names}
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running_loss += batch_loss.item() * batch_labels.size(0)
        mean_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_number}/{epochs}, Loss: {mean_loss:.4f}')

# -------------------- Evaluation Function --------------------
def evaluate_model(model, data_loader, use_12_leads=True):
    """Score ``model`` on ``data_loader`` and return ``(accuracy, auc, f1, mcc)``.

    Predictions are the argmax of the softmax over dim 1; the AUC uses the
    softmax probability in column 1.
    NOTE(review): column 1 as the positive class assumes a two-class problem
    with labels 0/1 — confirm.

    Args:
        model: module called with a dict ``{lead_name: tensor}``, returning logits.
        data_loader: iterable of ``(lead_data, labels)`` batches.
        use_12_leads: True moves all 12 leads to the device, False only the
            first six.

    Relies on the module-level ``device``.
    """
    model.eval()

    first_six = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    last_six = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = first_six + last_six if use_12_leads else first_six

    all_labels, all_probs, all_preds = [], [], []
    with torch.no_grad():
        for batch_leads, batch_labels in data_loader:
            batch_labels = batch_labels.to(device)
            inputs = {name: batch_leads[name].to(device) for name in lead_names}
            class_probs = torch.softmax(model(inputs), dim=1)
            predicted = torch.argmax(class_probs, dim=1)
            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(class_probs[:, 1].cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    return (
        accuracy_score(all_labels, all_preds),
        roc_auc_score(all_labels, all_probs),
        f1_score(all_labels, all_preds),
        matthews_corrcoef(all_labels, all_preds),
    )

# -------------------- IGAR Metric Function --------------------
def compute_igar_for_fold(model, val_dataset, device, threshold=0.7):
    """Compute a per-lead IGAR score for every sample of ``val_dataset``.

    For each sample, Integrated Gradients attributions are computed with
    respect to the model's own predicted class. Each lead's attribution is
    min-max normalised, and its IGAR is the fraction of positions whose
    normalised value is at least ``threshold``.

    Args:
        model: classifier exposing ``lead_names`` and accepting a dict
            ``{lead_name: tensor}``.
        val_dataset: indexable dataset yielding ``(lead_data, label)``.
        device: device the inputs are moved to.
        threshold: cut-off applied to the normalised attributions.

    Returns:
        DataFrame with one row per (sample, lead) and the columns
        ``sample_idx``, ``lead``, ``IGAR`` and ``IGAR_percent`` (the lead's
        share of the sample's summed IGAR, in percent; 0 when the sum is 0).
    """
    model.eval()

    def model_wrapper(*inputs):
        # Captum passes positional tensors; rebuild the dict the model expects.
        return model(dict(zip(model.lead_names, inputs)))

    def as_attribution_input(lead_tensor):
        # 2-D samples get a leading axis added before being moved to `device`.
        if lead_tensor.dim() == 2:
            lead_tensor = lead_tensor.unsqueeze(0)
        lead_tensor = lead_tensor.to(device)
        lead_tensor.requires_grad_(True)
        return lead_tensor

    rows = []
    for sample_idx in range(len(val_dataset)):
        lead_data_sample, _label = val_dataset[sample_idx]
        inputs_tuple = tuple(
            as_attribution_input(lead_data_sample[lead]) for lead in model.lead_names
        )

        # Attribute with respect to the class the model itself predicts.
        with torch.no_grad():
            logits = model_wrapper(*inputs_tuple)
        predicted_label = torch.argmax(logits, dim=1).item()

        attributions, _delta = IntegratedGradients(model_wrapper).attribute(
            inputs=inputs_tuple,
            target=predicted_label,
            return_convergence_delta=True,
        )

        lead_igar_scores = {}
        for lead, lead_attribution in zip(model.lead_names, attributions):
            attr = lead_attribution.detach().cpu().numpy().squeeze()
            lowest, highest = attr.min(), attr.max()
            norm_attr = (attr - lowest) / (highest - lowest + 1e-10)
            # NOTE(review): taking index array [0] and len(attr) assumes `attr`
            # is 1-D after squeeze() — confirm.
            important_count = len(np.where(norm_attr >= threshold)[0])
            lead_igar_scores[lead] = important_count / len(attr)

        total_igar = sum(lead_igar_scores.values())
        rows.extend(
            {
                "sample_idx": sample_idx,
                "lead": lead,
                "IGAR": igar,
                "IGAR_percent": (igar / total_igar * 100) if total_igar > 0 else 0,
            }
            for lead, igar in lead_igar_scores.items()
        )

    return pd.DataFrame(rows)

# ----------- Peak-level IGAR Calculation (per sample) -----------
def compute_peak_igar_for_sample(
    model, lead_data_sample, device, threshold=0.6, sampling_rate=500
):
    """Relate important attribution points of one sample to its ECG wave peaks.

    Integrated Gradients attributions are computed for the model's predicted
    class. For every lead in ``model.lead_names`` the points whose min-max
    normalised attribution is at least ``threshold`` are marked important;
    NeuroKit2 locates R peaks and delineates P/Q/S/T peaks, and the important
    points lying within ``int(0.04 * sampling_rate)`` samples of each peak
    type are counted.

    Args:
        model: classifier exposing ``lead_names`` and accepting a dict
            ``{lead_name: tensor}``.
        lead_data_sample: dict ``{lead_name: tensor}``; it is MODIFIED IN
            PLACE (entries are replaced by device tensors that require grad).
        device: device the tensors are moved to.
        threshold: cut-off on the normalised attributions.
        sampling_rate: sampling rate handed to NeuroKit2 and used for the window.

    Returns:
        dict mapping lead name to a DataFrame with the columns ``Peak``,
        ``Peak_IGAR``, ``ImportantPoints`` and ``PercentOfAllImportant``,
        sorted by ``PercentOfAllImportant`` in descending order. Leads without
        R peaks, or for which NeuroKit2 raises, are reported via ``print`` and
        left out.

    NOTE(review): windows of different peak types can overlap, so an important
    point may be counted for several types; ``PercentOfAllImportant`` is
    relative to the sum of the per-type counts.
    """
    import numpy as np
    import pandas as pd
    import neurokit2 as nk
    from captum.attr import IntegratedGradients

    def safe_peaks(peaks):
        # Keep integers and non-NaN floats, cast to int; None becomes [].
        if peaks is None:
            return []
        kept = []
        for value in peaks:
            is_integer = isinstance(value, (int, np.integer))
            is_valid_float = isinstance(value, float) and not np.isnan(value)
            if is_integer or is_valid_float:
                kept.append(int(value))
        return kept

    def model_wrapper(*inputs):
        # Captum passes positional tensors; rebuild the dict the model expects.
        return model(dict(zip(model.lead_names, inputs)))

    # Prepare data for model (every entry of the dict, in place).
    for lead in list(lead_data_sample):
        prepared = lead_data_sample[lead]
        if prepared.dim() == 2:
            prepared = prepared.unsqueeze(0)
        prepared = prepared.to(device)
        prepared.requires_grad_(True)
        lead_data_sample[lead] = prepared

    inputs_tuple = tuple(lead_data_sample[lead] for lead in model.lead_names)

    model.eval()
    with torch.no_grad():
        logits = model_wrapper(*inputs_tuple)
    predicted_label = torch.argmax(logits, dim=1).item()

    attributions, _delta = IntegratedGradients(model_wrapper).attribute(
        inputs=inputs_tuple,
        target=predicted_label,
        return_convergence_delta=True,
    )

    # Half-width, in samples, of the neighbourhood inspected around each peak.
    window_seconds = 0.04
    window = int(window_seconds * sampling_rate)

    peak_igar_results = {}
    for lead, lead_attribution in zip(model.lead_names, attributions):
        ecg_signal = lead_data_sample[lead].detach().cpu().numpy().squeeze()
        attr = lead_attribution.detach().cpu().numpy().squeeze()
        lowest, highest = attr.min(), attr.max()
        norm_attr = (attr - lowest) / (highest - lowest + 1e-10)
        important_indices = np.where(norm_attr >= threshold)[0]

        # Best-effort delineation: a failing lead is reported and skipped.
        try:
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            rpeaks_dict = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)[1]
            r_locs = rpeaks_dict.get('ECG_R_Peaks', [])
            has_r_peaks = isinstance(r_locs, (list, np.ndarray)) and len(r_locs) != 0
            if not has_r_peaks:
                print(f"No R peaks detected for {lead}. Skipping.")
                continue
            _, waves_peak = nk.ecg_delineate(
                ecg_cleaned,
                rpeaks=r_locs,
                sampling_rate=sampling_rate,
                method="dwt"
            )
            p_locs = safe_peaks(waves_peak.get('ECG_P_Peaks', []))
            q_locs = safe_peaks(waves_peak.get('ECG_Q_Peaks', []))
            s_locs = safe_peaks(waves_peak.get('ECG_S_Peaks', []))
            t_locs = safe_peaks(waves_peak.get('ECG_T_Peaks', []))
        except Exception as e:
            print(f"NeuroKit2 delineation failed for {lead}: {e}")
            continue

        peak_indices = {'P': p_locs, 'Q': q_locs, 'R': r_locs, 'S': s_locs, 'T': t_locs}
        signal_length = len(ecg_signal)
        peak_import_points = {}
        peak_igar = {}
        for peak_type, locs in peak_indices.items():
            # Union of the clipped neighbourhoods of all peaks of this type.
            window_indices = set()
            for peak_position in locs:
                first = max(0, peak_position - window)
                stop = min(signal_length, peak_position + window + 1)
                window_indices.update(range(first, stop))
            hit_count = sum(1 for ix in important_indices if ix in window_indices)
            peak_import_points[peak_type] = hit_count
            peak_igar[peak_type] = hit_count / signal_length

        total_imp_points = sum(peak_import_points.values())
        peak_percent = {}
        for peak_type, hit_count in peak_import_points.items():
            peak_percent[peak_type] = (
                hit_count / total_imp_points * 100 if total_imp_points > 0 else 0
            )

        result_df = pd.DataFrame({
            'Peak': list(peak_igar.keys()),
            'Peak_IGAR': list(peak_igar.values()),
            'ImportantPoints': list(peak_import_points.values()),
            'PercentOfAllImportant': list(peak_percent.values())
        })
        peak_igar_results[lead] = result_df.sort_values('PercentOfAllImportant', ascending=False)

    return peak_igar_results


# 5-fold stratified cross-validation: train a fresh classifier head per fold,
# score it on the held-out fold, then compute lead-level and peak-level IGAR
# on that fold's test samples.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []        # one (accuracy, auc, f1, mcc) tuple per fold
all_fold_igar = []       # per-fold DataFrames from compute_igar_for_fold
all_peak_igar_rows = []  # per-(sample, lead) DataFrames from compute_peak_igar_for_sample

# ---- Pretrained checkpoint: identical for every fold, so read it once ----
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)
# Strip the "_orig_mod." key prefix (presumably left by a compiled-model
# wrapper at save time — confirm) so keys match the plain module.
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# NOTE(review): `dataset` must be the ECGMultiLeadDatasetWithLabels built from
# the tensors/labels loaded for this task; the construction cell binds the
# name `dataset_2`, so make sure `dataset` refers to that same object.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}\n' + '-'*30)
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
    # ---- Model setup (fresh model per fold) ----
    prior_dist = prior_expert(params['latent_dim'])
    pretrained_mopoe = LSEMVAE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )
    # strict=False tolerates key mismatches; report them so that encoders left
    # at their random initialisation do not go unnoticed.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(
            "WARNING: checkpoint/model key mismatch - "
            f"missing: {list(load_result.missing_keys)}, "
            f"unexpected: {list(load_result.unexpected_keys)}"
        )
    pretrained_mopoe.to(device)
    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=False).to(device)
    # Only parameters that still require grad (the classifier head) are optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=False)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=False)
    fold_results.append((accuracy, auc_score, f1, mcc))
    print(f"Computing IGAR for fold {fold} ...")
    igar_df = compute_igar_for_fold(model, test_subset, device, threshold=0.7)
    igar_df['fold'] = fold
    all_fold_igar.append(igar_df)
    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')
    for sample_idx in range(len(test_subset)):
        lead_data_sample, label = test_subset[sample_idx]
        # dict(...) passes a shallow copy because the callee rewrites the dict's entries.
        peak_igar_results = compute_peak_igar_for_sample(
            model, dict(lead_data_sample), device, threshold=0.6, sampling_rate=500
        )
        for lead, df in peak_igar_results.items():
            df['fold'] = fold
            df['sample_idx'] = sample_idx
            df['lead'] = lead
            all_peak_igar_rows.append(df)

# Mean and (population) standard deviation of each metric across the folds.
accuracies, aucs, f1s, mccs = zip(*fold_results)
for metric_label, fold_values in (
    ("Accuracy", accuracies),
    ("AUC", aucs),
    ("F1", f1s),
    ("MCC", mccs),
):
    print(f'Mean {metric_label}: {np.mean(fold_values):.4f}, STD: {np.std(fold_values):.4f}')

# -------- IGAR per-lead summary (mean only, no std) --------
# Stack the per-fold tables, then average both IGAR columns per lead.
igar_all = pd.concat(all_fold_igar, ignore_index=True)
igar_summary = (
    igar_all.groupby("lead")
    .agg(
        IGAR_mean=pd.NamedAgg(column="IGAR", aggfunc="mean"),
        IGAR_percent_mean=pd.NamedAgg(column="IGAR_percent", aggfunc="mean"),
    )
    .reset_index()
)
print("Aggregated IGAR metrics across all folds:")
print(igar_summary)

# -------- Per-lead, per-peak IGAR summary (mean only, no std) --------
if not all_peak_igar_rows:
    print("No peak IGAR results to aggregate.")
else:
    # Stack all per-sample tables, turn infinities into NaN and drop the raw
    # count column, which the summary does not use.
    peak_igar_all = (
        pd.concat(all_peak_igar_rows, ignore_index=True)
        .replace([np.inf, -np.inf], np.nan)
        .drop(columns=["ImportantPoints"], errors='ignore')
    )
    # NaN-ignoring mean of both measures for every (lead, peak type) pair.
    peak_igar_summary = (
        peak_igar_all.groupby(["lead", "Peak"])
        .agg(
            Peak_IGAR_mean=("Peak_IGAR", lambda x: np.nanmean(x)),
            PercentOfAllImportant_mean=("PercentOfAllImportant", lambda x: np.nanmean(x)),
        )
        .reset_index()
    )
    print("Aggregated IGAR and Percent for each peak and lead (mean only):")
    print(peak_igar_summary)

Task-2

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader


# File locations of the 12 per-lead ECG tensors and of the label tensor,
# all inside the data_aspire_PAWP_1 directory.
# Key order: I, II, III, aVR, aVL, aVF, V1 ... V6.
lead_file_paths = {
    f"LEAD_{lead_suffix}": f"../Data_processing/data_aspire_PAWP_1/LEAD_{lead_suffix}.pt"
    for lead_suffix in (
        "I", "II", "III", "aVR", "aVL", "aVF",
        "V1", "V2", "V3", "V4", "V5", "V6",
    )
}
labels_file_path = "../Data_processing/data_aspire_PAWP_1/labels.pt"


# Read one tensor file per lead, then the label tensor.
ecg_lead_tensors = dict(
    (lead, torch.load(path)) for lead, path in lead_file_paths.items()
)
labels = torch.load(labels_file_path)

# Every lead must hold exactly as many samples as there are labels;
# the first lead in the mapping supplies the reference count.
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert sample_count == len(labels), "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert sample_count == len(tensor), "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG tensors with one label per sample.

    Parameters
    ----------
    ecg_leads : dict
        Maps a lead name to an indexable tensor whose first axis enumerates
        the samples.
    labels : indexable
        Label of each sample, indexed the same way as the lead tensors.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The first lead in the mapping determines the dataset size.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return ``(lead_data, label)`` for sample ``idx``.

        ``lead_data`` maps every lead name to that lead's sample with one
        extra leading axis of size 1 added by ``unsqueeze(0)``.
        """
        sample = {}
        for lead_name, lead_tensor in self.ecg_leads.items():
            sample[lead_name] = lead_tensor[idx].unsqueeze(0)
        return sample, self.labels[idx]

# Initialize the dataset and dataloader.
# Both `DataLoader(...)` here and the cross-validation cell further down read
# the name `dataset`, so it has to be bound in this cell; otherwise a fresh
# kernel raises NameError (or silently reuses a dataset from another task).
dataset = ECGMultiLeadDatasetWithLabels(ecg_lead_tensors, labels)
dataset_2 = dataset  # alias kept so code that still refers to `dataset_2` keeps working
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn
from captum.attr import IntegratedGradients
import pandas as pd
import neurokit2 as nk

import warnings

# Silence every warning for the rest of the session.
# NOTE(review): this also hides warnings raised by the libraries used below —
# consider narrowing the filter.
warnings.filterwarnings(action="ignore")


# -------------------- Setup --------------------
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def set_seed(seed_value):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA).

    Also sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed(123)

# -------------------- Training Function --------------------
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    After every epoch the sample-weighted mean loss is printed.

    Args:
        model: module called with a dict ``{lead_name: tensor}``; its output is
            passed to ``criterion`` together with the labels.
        train_loader: iterable of ``(lead_data, labels)`` batches, where
            ``lead_data`` is a dict keyed by lead name. Its ``dataset``
            attribute supplies the denominator of the mean loss.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer; zeroed and stepped once per batch.
        epochs: number of passes over the loader.
        use_12_leads: True moves all 12 leads to the device, False only the
            first six. NOTE(review): must agree with the lead set ``model``
            was built with.

    Relies on the module-level ``device``.
    """
    # NOTE(review): train() puts every sub-module in training mode, including
    # sub-modules whose parameters are frozen — confirm that is intended.
    model.train()

    first_six = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    last_six = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = first_six + last_six if use_12_leads else first_six

    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for batch_leads, batch_labels in train_loader:
            batch_labels = batch_labels.to(device)
            inputs = {name: batch_leads[name].to(device) for name in lead_names}
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running_loss += batch_loss.item() * batch_labels.size(0)
        mean_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_number}/{epochs}, Loss: {mean_loss:.4f}')

# -------------------- Evaluation Function --------------------
def evaluate_model(model, data_loader, use_12_leads=True):
    """Score ``model`` on ``data_loader`` and return ``(accuracy, auc, f1, mcc)``.

    Predictions are the argmax of the softmax over dim 1; the AUC uses the
    softmax probability in column 1.
    NOTE(review): column 1 as the positive class assumes a two-class problem
    with labels 0/1 — confirm.

    Args:
        model: module called with a dict ``{lead_name: tensor}``, returning logits.
        data_loader: iterable of ``(lead_data, labels)`` batches.
        use_12_leads: True moves all 12 leads to the device, False only the
            first six.

    Relies on the module-level ``device``.
    """
    model.eval()

    first_six = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    last_six = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = first_six + last_six if use_12_leads else first_six

    all_labels, all_probs, all_preds = [], [], []
    with torch.no_grad():
        for batch_leads, batch_labels in data_loader:
            batch_labels = batch_labels.to(device)
            inputs = {name: batch_leads[name].to(device) for name in lead_names}
            class_probs = torch.softmax(model(inputs), dim=1)
            predicted = torch.argmax(class_probs, dim=1)
            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(class_probs[:, 1].cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    return (
        accuracy_score(all_labels, all_preds),
        roc_auc_score(all_labels, all_probs),
        f1_score(all_labels, all_preds),
        matthews_corrcoef(all_labels, all_preds),
    )

# -------------------- IGAR Metric Function --------------------
def compute_igar_for_fold(model, val_dataset, device, threshold=0.7):
    """Compute a per-lead IGAR score for every sample of ``val_dataset``.

    For each sample, Integrated Gradients attributions are computed with
    respect to the model's own predicted class. Each lead's attribution is
    min-max normalised, and its IGAR is the fraction of positions whose
    normalised value is at least ``threshold``.

    Args:
        model: classifier exposing ``lead_names`` and accepting a dict
            ``{lead_name: tensor}``.
        val_dataset: indexable dataset yielding ``(lead_data, label)``.
        device: device the inputs are moved to.
        threshold: cut-off applied to the normalised attributions.

    Returns:
        DataFrame with one row per (sample, lead) and the columns
        ``sample_idx``, ``lead``, ``IGAR`` and ``IGAR_percent`` (the lead's
        share of the sample's summed IGAR, in percent; 0 when the sum is 0).
    """
    model.eval()

    def model_wrapper(*inputs):
        # Captum passes positional tensors; rebuild the dict the model expects.
        return model(dict(zip(model.lead_names, inputs)))

    def as_attribution_input(lead_tensor):
        # 2-D samples get a leading axis added before being moved to `device`.
        if lead_tensor.dim() == 2:
            lead_tensor = lead_tensor.unsqueeze(0)
        lead_tensor = lead_tensor.to(device)
        lead_tensor.requires_grad_(True)
        return lead_tensor

    rows = []
    for sample_idx in range(len(val_dataset)):
        lead_data_sample, _label = val_dataset[sample_idx]
        inputs_tuple = tuple(
            as_attribution_input(lead_data_sample[lead]) for lead in model.lead_names
        )

        # Attribute with respect to the class the model itself predicts.
        with torch.no_grad():
            logits = model_wrapper(*inputs_tuple)
        predicted_label = torch.argmax(logits, dim=1).item()

        attributions, _delta = IntegratedGradients(model_wrapper).attribute(
            inputs=inputs_tuple,
            target=predicted_label,
            return_convergence_delta=True,
        )

        lead_igar_scores = {}
        for lead, lead_attribution in zip(model.lead_names, attributions):
            attr = lead_attribution.detach().cpu().numpy().squeeze()
            lowest, highest = attr.min(), attr.max()
            norm_attr = (attr - lowest) / (highest - lowest + 1e-10)
            # NOTE(review): taking index array [0] and len(attr) assumes `attr`
            # is 1-D after squeeze() — confirm.
            important_count = len(np.where(norm_attr >= threshold)[0])
            lead_igar_scores[lead] = important_count / len(attr)

        total_igar = sum(lead_igar_scores.values())
        rows.extend(
            {
                "sample_idx": sample_idx,
                "lead": lead,
                "IGAR": igar,
                "IGAR_percent": (igar / total_igar * 100) if total_igar > 0 else 0,
            }
            for lead, igar in lead_igar_scores.items()
        )

    return pd.DataFrame(rows)

# ----------- Peak-level IGAR Calculation (per sample) -----------
def compute_peak_igar_for_sample(
    model, lead_data_sample, device, threshold=0.6, sampling_rate=500
):
    """Relate important attribution points of one sample to its ECG wave peaks.

    Integrated Gradients attributions are computed for the model's predicted
    class. For every lead in ``model.lead_names`` the points whose min-max
    normalised attribution is at least ``threshold`` are marked important;
    NeuroKit2 locates R peaks and delineates P/Q/S/T peaks, and the important
    points lying within ``int(0.04 * sampling_rate)`` samples of each peak
    type are counted.

    Args:
        model: classifier exposing ``lead_names`` and accepting a dict
            ``{lead_name: tensor}``.
        lead_data_sample: dict ``{lead_name: tensor}``; it is MODIFIED IN
            PLACE (entries are replaced by device tensors that require grad).
        device: device the tensors are moved to.
        threshold: cut-off on the normalised attributions.
        sampling_rate: sampling rate handed to NeuroKit2 and used for the window.

    Returns:
        dict mapping lead name to a DataFrame with the columns ``Peak``,
        ``Peak_IGAR``, ``ImportantPoints`` and ``PercentOfAllImportant``,
        sorted by ``PercentOfAllImportant`` in descending order. Leads without
        R peaks, or for which NeuroKit2 raises, are reported via ``print`` and
        left out.

    NOTE(review): windows of different peak types can overlap, so an important
    point may be counted for several types; ``PercentOfAllImportant`` is
    relative to the sum of the per-type counts.
    """
    import numpy as np
    import pandas as pd
    import neurokit2 as nk
    from captum.attr import IntegratedGradients

    def safe_peaks(peaks):
        # Keep integers and non-NaN floats, cast to int; None becomes [].
        if peaks is None:
            return []
        kept = []
        for value in peaks:
            is_integer = isinstance(value, (int, np.integer))
            is_valid_float = isinstance(value, float) and not np.isnan(value)
            if is_integer or is_valid_float:
                kept.append(int(value))
        return kept

    def model_wrapper(*inputs):
        # Captum passes positional tensors; rebuild the dict the model expects.
        return model(dict(zip(model.lead_names, inputs)))

    # Prepare data for model (every entry of the dict, in place).
    for lead in list(lead_data_sample):
        prepared = lead_data_sample[lead]
        if prepared.dim() == 2:
            prepared = prepared.unsqueeze(0)
        prepared = prepared.to(device)
        prepared.requires_grad_(True)
        lead_data_sample[lead] = prepared

    inputs_tuple = tuple(lead_data_sample[lead] for lead in model.lead_names)

    model.eval()
    with torch.no_grad():
        logits = model_wrapper(*inputs_tuple)
    predicted_label = torch.argmax(logits, dim=1).item()

    attributions, _delta = IntegratedGradients(model_wrapper).attribute(
        inputs=inputs_tuple,
        target=predicted_label,
        return_convergence_delta=True,
    )

    # Half-width, in samples, of the neighbourhood inspected around each peak.
    window_seconds = 0.04
    window = int(window_seconds * sampling_rate)

    peak_igar_results = {}
    for lead, lead_attribution in zip(model.lead_names, attributions):
        ecg_signal = lead_data_sample[lead].detach().cpu().numpy().squeeze()
        attr = lead_attribution.detach().cpu().numpy().squeeze()
        lowest, highest = attr.min(), attr.max()
        norm_attr = (attr - lowest) / (highest - lowest + 1e-10)
        important_indices = np.where(norm_attr >= threshold)[0]

        # Best-effort delineation: a failing lead is reported and skipped.
        try:
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            rpeaks_dict = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)[1]
            r_locs = rpeaks_dict.get('ECG_R_Peaks', [])
            has_r_peaks = isinstance(r_locs, (list, np.ndarray)) and len(r_locs) != 0
            if not has_r_peaks:
                print(f"No R peaks detected for {lead}. Skipping.")
                continue
            _, waves_peak = nk.ecg_delineate(
                ecg_cleaned,
                rpeaks=r_locs,
                sampling_rate=sampling_rate,
                method="dwt"
            )
            p_locs = safe_peaks(waves_peak.get('ECG_P_Peaks', []))
            q_locs = safe_peaks(waves_peak.get('ECG_Q_Peaks', []))
            s_locs = safe_peaks(waves_peak.get('ECG_S_Peaks', []))
            t_locs = safe_peaks(waves_peak.get('ECG_T_Peaks', []))
        except Exception as e:
            print(f"NeuroKit2 delineation failed for {lead}: {e}")
            continue

        peak_indices = {'P': p_locs, 'Q': q_locs, 'R': r_locs, 'S': s_locs, 'T': t_locs}
        signal_length = len(ecg_signal)
        peak_import_points = {}
        peak_igar = {}
        for peak_type, locs in peak_indices.items():
            # Union of the clipped neighbourhoods of all peaks of this type.
            window_indices = set()
            for peak_position in locs:
                first = max(0, peak_position - window)
                stop = min(signal_length, peak_position + window + 1)
                window_indices.update(range(first, stop))
            hit_count = sum(1 for ix in important_indices if ix in window_indices)
            peak_import_points[peak_type] = hit_count
            peak_igar[peak_type] = hit_count / signal_length

        total_imp_points = sum(peak_import_points.values())
        peak_percent = {}
        for peak_type, hit_count in peak_import_points.items():
            peak_percent[peak_type] = (
                hit_count / total_imp_points * 100 if total_imp_points > 0 else 0
            )

        result_df = pd.DataFrame({
            'Peak': list(peak_igar.keys()),
            'Peak_IGAR': list(peak_igar.values()),
            'ImportantPoints': list(peak_import_points.values()),
            'PercentOfAllImportant': list(peak_percent.values())
        })
        peak_igar_results[lead] = result_df.sort_values('PercentOfAllImportant', ascending=False)

    return peak_igar_results

# 5-fold stratified cross-validation: train a fresh classifier head per fold,
# score it on the held-out fold, then compute lead-level and peak-level IGAR
# on that fold's test samples.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []        # one (accuracy, auc, f1, mcc) tuple per fold
all_fold_igar = []       # per-fold DataFrames from compute_igar_for_fold
all_peak_igar_rows = []  # per-(sample, lead) DataFrames from compute_peak_igar_for_sample

# ---- Pretrained checkpoint: identical for every fold, so read it once ----
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)
# Strip the "_orig_mod." key prefix (presumably left by a compiled-model
# wrapper at save time — confirm) so keys match the plain module.
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# NOTE(review): `dataset` must be the ECGMultiLeadDatasetWithLabels built from
# the tensors/labels loaded for this task; the construction cell binds the
# name `dataset_2`, so make sure `dataset` refers to that same object.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}\n' + '-'*30)
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
    # ---- Model setup (fresh model per fold) ----
    prior_dist = prior_expert(params['latent_dim'])
    pretrained_mopoe = LSEMVAE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )
    # strict=False tolerates key mismatches; report them so that encoders left
    # at their random initialisation do not go unnoticed.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(
            "WARNING: checkpoint/model key mismatch - "
            f"missing: {list(load_result.missing_keys)}, "
            f"unexpected: {list(load_result.unexpected_keys)}"
        )
    pretrained_mopoe.to(device)
    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=False).to(device)
    # Only parameters that still require grad (the classifier head) are optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=False)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=False)
    fold_results.append((accuracy, auc_score, f1, mcc))
    print(f"Computing IGAR for fold {fold} ...")
    igar_df = compute_igar_for_fold(model, test_subset, device, threshold=0.7)
    igar_df['fold'] = fold
    all_fold_igar.append(igar_df)
    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

    for sample_idx in range(len(test_subset)):
        lead_data_sample, label = test_subset[sample_idx]
        # dict(...) passes a shallow copy because the callee rewrites the dict's entries.
        peak_igar_results = compute_peak_igar_for_sample(
            model, dict(lead_data_sample), device, threshold=0.6, sampling_rate=500
        )
        for lead, df in peak_igar_results.items():
            df['fold'] = fold
            df['sample_idx'] = sample_idx
            df['lead'] = lead
            all_peak_igar_rows.append(df)

# Mean and (population) standard deviation of each metric across the folds.
accuracies, aucs, f1s, mccs = zip(*fold_results)
for metric_label, fold_values in (
    ("Accuracy", accuracies),
    ("AUC", aucs),
    ("F1", f1s),
    ("MCC", mccs),
):
    print(f'Mean {metric_label}: {np.mean(fold_values):.4f}, STD: {np.std(fold_values):.4f}')

# -------- IGAR per-lead summary (mean only, no std) --------
# Stack the per-fold tables, then average both IGAR columns per lead.
igar_all = pd.concat(all_fold_igar, ignore_index=True)
igar_summary = (
    igar_all.groupby("lead")
    .agg(
        IGAR_mean=pd.NamedAgg(column="IGAR", aggfunc="mean"),
        IGAR_percent_mean=pd.NamedAgg(column="IGAR_percent", aggfunc="mean"),
    )
    .reset_index()
)
print("Aggregated IGAR metrics across all folds:")
print(igar_summary)

# -------- Per-lead, per-peak IGAR summary (mean only, no std) --------
if not all_peak_igar_rows:
    print("No peak IGAR results to aggregate.")
else:
    # Stack all per-sample tables, turn infinities into NaN and drop the raw
    # count column, which the summary does not use.
    peak_igar_all = (
        pd.concat(all_peak_igar_rows, ignore_index=True)
        .replace([np.inf, -np.inf], np.nan)
        .drop(columns=["ImportantPoints"], errors='ignore')
    )
    # NaN-ignoring mean of both measures for every (lead, peak type) pair.
    peak_igar_summary = (
        peak_igar_all.groupby(["lead", "Peak"])
        .agg(
            Peak_IGAR_mean=("Peak_IGAR", lambda x: np.nanmean(x)),
            PercentOfAllImportant_mean=("PercentOfAllImportant", lambda x: np.nanmean(x)),
        )
        .reset_index()
    )
    print("Aggregated IGAR and Percent for each peak and lead (mean only):")
    print(peak_igar_summary)

Task-3

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
# File paths
# All twelve lead tensors and the label tensor are read from one directory;
# edit this single constant to point the notebook at another copy of the data.
UKB_DATA_DIR = "D:/ukbiobank/ECG_PAWP_UKB_Final"

# Lead names double as file stems (<lead>.pt). The list order becomes the dict
# insertion order, and the first entry is the lead used to read sample_count.
UKB_LEAD_NAMES = [
    "LEAD_I", "LEAD_II", "LEAD_III",
    "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
    "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
]
lead_file_paths = {lead: f"{UKB_DATA_DIR}/{lead}.pt" for lead in UKB_LEAD_NAMES}
labels_file_path = f"{UKB_DATA_DIR}/labels.pt"

# Load all lead tensors and labels
ecg_lead_tensors = {lead: torch.load(path) for lead, path in lead_file_paths.items()}
labels = torch.load(labels_file_path)

# Ensure all leads have the same number of samples as the labels
sample_count = len(next(iter(ecg_lead_tensors.values())))
print(len(labels))
print(sample_count)
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG recordings with one label per sample.

    Args:
        ecg_leads: mapping from lead name to an indexable container whose
            ``[idx]`` element supports ``.unsqueeze(0)``.
        labels: container indexable by the same sample index.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The sample count is read from whichever lead comes first in the mapping.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return ``(lead_data, label)`` for sample ``idx``.

        ``lead_data`` maps each lead name to that lead's sample with a new
        leading dimension of size 1.
        """
        lead_data = {}
        for lead_name, recordings in self.ecg_leads.items():
            lead_data[lead_name] = recordings[idx].unsqueeze(0)
        return lead_data, self.labels[idx]

# Wrap the loaded lead tensors and labels in the dataset. No DataLoader is
# built here; loaders are created per fold inside the cross-validation loop.
dataset = ECGMultiLeadDatasetWithLabels(ecg_lead_tensors, labels)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn
from captum.attr import IntegratedGradients
import pandas as pd
import neurokit2 as nk

import warnings
warnings.filterwarnings("ignore")


# -------------------- Setup --------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.

    Args:
        seed_value: seed passed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)
    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(123)

# -------------------- Training Function --------------------
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train ``model`` in place and print the sample-weighted mean loss per epoch.

    Relies on the module-level ``device``.

    Args:
        model: callable taking a ``{lead_name: tensor}`` dict and returning the
            outputs passed to ``criterion``.
        train_loader: yields ``(lead_data, labels)`` batches, where ``lead_data``
            is a dict keyed by lead name; its ``.dataset`` length is the loss
            denominator.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        epochs: number of passes over ``train_loader``.
        use_12_leads: when False only the six limb/augmented leads are fed to
            the model; when True the six V leads are added.
    """
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    model.train()
    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for lead_batch, targets in train_loader:
            targets = targets.to(device)
            inputs = {name: lead_batch[name].to(device) for name in lead_names}

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            running_loss += loss.item() * targets.size(0)
        mean_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_number}/{epochs}, Loss: {mean_loss:.4f}')

# -------------------- Evaluation Function --------------------
def evaluate_model(model, data_loader, use_12_leads=True):
    """Score ``model`` on ``data_loader`` and return classification metrics.

    Relies on the module-level ``device``. Predictions are the argmax of the
    softmax over dim 1; the probability in column 1 is the score used for AUC.

    Args:
        model: callable taking a ``{lead_name: tensor}`` dict and returning logits.
        data_loader: yields ``(lead_data, labels)`` batches.
        use_12_leads: when False only the six limb/augmented leads are fed to
            the model; when True the six V leads are added.

    Returns:
        Tuple ``(accuracy, auc, f1, mcc)``.
    """
    model.eval()
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    if use_12_leads:
        lead_names += ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    y_true, y_score, y_pred = [], [], []
    with torch.no_grad():
        for lead_batch, targets in data_loader:
            targets = targets.to(device)
            inputs = {name: lead_batch[name].to(device) for name in lead_names}
            class_probs = torch.softmax(model(inputs), dim=1)
            hard_preds = torch.argmax(class_probs, dim=1)
            y_true.extend(targets.cpu().numpy())
            y_score.extend(class_probs[:, 1].cpu().numpy())
            y_pred.extend(hard_preds.cpu().numpy())

    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_score),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )

# -------------------- IGAR Metric Function --------------------
def compute_igar_for_fold(model, val_dataset, device, threshold=0.7):
    """Compute a per-lead IGAR score for every sample of ``val_dataset``.

    For each sample, Integrated Gradients is run against the model's own
    predicted class. Each lead's attribution is min-max normalised, and IGAR is
    the fraction of positions whose normalised attribution is >= ``threshold``.
    ``IGAR_percent`` is that lead's share of the summed IGAR over all leads of
    the sample (0 when the sum is 0).

    Args:
        model: classifier exposing ``lead_names`` and accepting a
            ``{lead_name: tensor}`` dict.
        val_dataset: indexable; ``val_dataset[i]`` gives ``(lead_data, label)``.
        device: device the input tensors are moved to.
        threshold: cut-off applied to the normalised attribution.

    Returns:
        DataFrame with columns ``sample_idx``, ``lead``, ``IGAR``, ``IGAR_percent``.
    """
    model.eval()

    def model_wrapper(*lead_tensors):
        # Captum passes inputs positionally; rebuild the dict the model expects.
        return model(dict(zip(model.lead_names, lead_tensors)))

    rows = []
    for sample_idx in range(len(val_dataset)):
        lead_data_sample, _label = val_dataset[sample_idx]

        prepared = []
        for lead in model.lead_names:
            lead_tensor = lead_data_sample[lead]
            if lead_tensor.dim() == 2:
                lead_tensor = lead_tensor.unsqueeze(0)
            prepared.append(lead_tensor.to(device).requires_grad_(True))
        inputs_tuple = tuple(prepared)

        with torch.no_grad():
            logits = model_wrapper(*inputs_tuple)
        predicted_label = torch.argmax(logits, dim=1).item()

        attributions, _ = IntegratedGradients(model_wrapper).attribute(
            inputs=inputs_tuple,
            target=predicted_label,
            return_convergence_delta=True,
        )

        lead_igar_scores = {}
        for lead, lead_attr in zip(model.lead_names, attributions):
            attr = lead_attr.detach().cpu().numpy().squeeze()
            norm_attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-10)
            n_important = int(np.count_nonzero(norm_attr >= threshold))
            lead_igar_scores[lead] = n_important / len(attr)

        total_igar = sum(lead_igar_scores.values())
        rows.extend(
            {
                "sample_idx": sample_idx,
                "lead": lead,
                "IGAR": igar,
                "IGAR_percent": (igar / total_igar * 100) if total_igar > 0 else 0,
            }
            for lead, igar in lead_igar_scores.items()
        )
    return pd.DataFrame(rows)

# ----------- Peak-level IGAR Calculation (per sample) -----------
def compute_peak_igar_for_sample(
    model, lead_data_sample, device, threshold=0.6, sampling_rate=500
):
    """Distribute one sample's important attribution points over P/Q/R/S/T peaks.

    Integrated Gradients is run against the model's own predicted class. For
    each lead the attribution is min-max normalised, and positions with a
    normalised value >= ``threshold`` are "important". NeuroKit2 locates the
    peaks; an important point is credited to a peak type when it lies within
    +/- 0.04 s of any peak of that type.

    The caller's ``lead_data_sample`` mapping and its tensors are left
    untouched: inputs are detached and collected in a local dict.

    Args:
        model: classifier exposing ``lead_names`` and accepting a
            ``{lead_name: tensor}`` dict.
        lead_data_sample: mapping with at least the leads in ``model.lead_names``.
        device: device the input tensors are moved to.
        threshold: cut-off applied to the normalised attribution.
        sampling_rate: sampling rate in Hz, passed to NeuroKit2 and used to
            convert the 0.04 s window into samples.

    Returns:
        Dict mapping lead name to a DataFrame with columns ``Peak``,
        ``Peak_IGAR``, ``ImportantPoints`` and ``PercentOfAllImportant``,
        sorted by ``PercentOfAllImportant`` descending. Leads for which no R
        peak is found or delineation raises are reported on stdout and omitted.
    """
    import numpy as np
    import pandas as pd
    import neurokit2 as nk
    from captum.attr import IntegratedGradients

    # Prepare model inputs in a local dict. detach() gives a fresh tensor
    # object, so requires_grad_ below never flips the flag on the caller's data.
    prepared = {}
    for lead in model.lead_names:
        tensor = lead_data_sample[lead].detach()
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0)
        prepared[lead] = tensor.to(device).requires_grad_(True)

    def model_wrapper(*inputs):
        # Captum passes inputs positionally; rebuild the dict the model expects.
        lead_data = {lead: tensor for lead, tensor in zip(model.lead_names, inputs)}
        return model(lead_data)

    inputs_tuple = tuple(prepared[lead] for lead in model.lead_names)
    model.eval()
    with torch.no_grad():
        logits = model_wrapper(*inputs_tuple)
    predicted_label = torch.argmax(logits, dim=1).item()
    integrated_gradients = IntegratedGradients(model_wrapper)
    attributions, _ = integrated_gradients.attribute(
        inputs=inputs_tuple,
        target=predicted_label,
        return_convergence_delta=True
    )

    def safe_peaks(peaks):
        # Keep only real indices; delineation marks missing peaks with NaN.
        if peaks is None:
            return []
        return [
            int(i) for i in peaks
            if (
                isinstance(i, (int, np.integer))
                or (isinstance(i, float) and not np.isnan(i))
            )
        ]

    peak_igar_results = {}
    window_seconds = 0.04
    window = int(window_seconds * sampling_rate)  # half-width in samples
    for idx, lead in enumerate(model.lead_names):
        ecg_signal = prepared[lead].detach().cpu().numpy().squeeze()
        attr = attributions[idx].detach().cpu().numpy().squeeze()
        norm_attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-10)
        important_indices = np.where(norm_attr >= threshold)[0]
        # Deliberate best-effort: a lead that cannot be delineated is reported
        # and skipped instead of aborting the whole sample.
        try:
            ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sampling_rate)
            rpeaks_dict = nk.ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate)[1]
            r_locs = rpeaks_dict.get('ECG_R_Peaks', [])
            if not isinstance(r_locs, (list, np.ndarray)) or len(r_locs) == 0:
                print(f"No R peaks detected for {lead}. Skipping.")
                continue
            _, waves_peak = nk.ecg_delineate(
                ecg_cleaned,
                rpeaks=r_locs,
                sampling_rate=sampling_rate,
                method="dwt"
            )
            p_locs = safe_peaks(waves_peak.get('ECG_P_Peaks', []))
            q_locs = safe_peaks(waves_peak.get('ECG_Q_Peaks', []))
            s_locs = safe_peaks(waves_peak.get('ECG_S_Peaks', []))
            t_locs = safe_peaks(waves_peak.get('ECG_T_Peaks', []))
        except Exception as e:
            print(f"NeuroKit2 delineation failed for {lead}: {e}")
            continue

        peak_indices = {'P': p_locs, 'Q': q_locs, 'R': r_locs, 'S': s_locs, 'T': t_locs}
        peak_import_points = {}
        peak_igar = {}
        for peak_type, locs in peak_indices.items():
            # Union of all +/- window neighbourhoods around this peak type.
            window_indices = set()
            for idx_peak in locs:
                start = max(0, idx_peak - window)
                end = min(len(ecg_signal), idx_peak + window + 1)
                window_indices.update(range(start, end))
            # Windows are built per peak type, so a point lying in the windows
            # of several peak types is counted once for each of them.
            n_important = sum(1 for ix in important_indices if ix in window_indices)
            peak_import_points[peak_type] = n_important
            peak_igar[peak_type] = n_important / len(ecg_signal)

        total_imp_points = sum(peak_import_points.values())
        peak_percent = {
            peak_type: (peak_import_points[peak_type] / total_imp_points * 100 if total_imp_points > 0 else 0)
            for peak_type in peak_import_points
        }
        result_df = pd.DataFrame({
            'Peak': list(peak_igar.keys()),
            'Peak_IGAR': list(peak_igar.values()),
            'ImportantPoints': list(peak_import_points.values()),
            'PercentOfAllImportant': list(peak_percent.values())
        })
        result_df = result_df.sort_values('PercentOfAllImportant', ascending=False)
        peak_igar_results[lead] = result_df
    return peak_igar_results

# 5-fold stratified cross-validation. For every fold a fresh classifier is built
# on the pretrained MoPoE weights, trained, scored, and then explained with
# Integrated Gradients at lead level and at peak level.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []        # one (accuracy, auc, f1, mcc) tuple per fold
all_fold_igar = []       # one per-lead IGAR DataFrame per fold
all_peak_igar_rows = []  # one per-peak DataFrame per (fold, sample, lead)

# Only the labels drive the stratified split, so a zero array stands in for X.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}\n' + '-'*30)
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
    # ---- Model setup ----
    # The MoPoE backbone is configured for 12 leads, while the classifier,
    # training and evaluation below are all run with use_12_leads=False.
    params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}
    prior_dist = prior_expert(params['latent_dim'])
    pretrained_mopoe = MoPoE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )
    state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)
    # Strip the "_orig_mod." key prefix (presumably left by torch.compile when
    # the checkpoint was saved -- TODO confirm) so keys match the plain module.
    new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    # NOTE(review): strict=False tolerates missing/unexpected keys silently;
    # confirm the checkpoint actually populates the backbone.
    pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    pretrained_mopoe.to(device)
    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=False).to(device)
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=False)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=False)
    fold_results.append((accuracy, auc_score, f1, mcc))
    print(f"Computing IGAR for fold {fold} ...")
    # Lead-level IGAR uses threshold 0.7; the peak-level pass below uses 0.6.
    igar_df = compute_igar_for_fold(model, test_subset, device, threshold=0.7)
    igar_df['fold'] = fold
    all_fold_igar.append(igar_df)
    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')
    # Peak-level IGAR, one test sample at a time. sample_idx is the position
    # within this fold's test subset, not the index in the full dataset.
    for sample_idx in range(len(test_subset)):
        lead_data_sample, label = test_subset[sample_idx]
        # A shallow dict copy is passed so the callee works on its own mapping.
        peak_igar_results = compute_peak_igar_for_sample(
            model, dict(lead_data_sample), device, threshold=0.6, sampling_rate=500
        )
        # Tag every per-lead frame with its origin before collecting it.
        for lead, df in peak_igar_results.items():
            df['fold'] = fold
            df['sample_idx'] = sample_idx
            df['lead'] = lead
            all_peak_igar_rows.append(df)

# Aggregate classification metrics
# Each entry of fold_results is an (accuracy, AUC, F1, MCC) tuple; unzip into
# one sequence per metric and report mean / population std across folds.
accuracies, aucs, f1s, mccs = zip(*fold_results)
for metric_label, fold_values in (
    ("Accuracy", accuracies),
    ("AUC", aucs),
    ("F1", f1s),
    ("MCC", mccs),
):
    print(f'Mean {metric_label}: {np.mean(fold_values):.4f}, STD: {np.std(fold_values):.4f}')

# -------- IGAR per-lead summary (mean only, no std) --------
# Stack the per-fold frames and average both IGAR columns per lead.
igar_all = pd.concat(all_fold_igar, ignore_index=True)
igar_summary = (
    igar_all
    .groupby("lead")
    .agg(IGAR_mean=("IGAR", "mean"), IGAR_percent_mean=("IGAR_percent", "mean"))
    .reset_index()
)
print("Aggregated IGAR metrics across all folds:")
print(igar_summary)

# -------- Per-lead, per-peak IGAR summary (mean only, no std) --------
if not all_peak_igar_rows:
    print("No peak IGAR results to aggregate.")
else:
    # Stack every per-sample frame, turn +/-inf into NaN so the NaN-aware mean
    # skips it, and drop the raw count column that the summary does not use.
    peak_igar_all = (
        pd.concat(all_peak_igar_rows, ignore_index=True)
        .replace([np.inf, -np.inf], np.nan)
        .drop(columns=["ImportantPoints"], errors='ignore')
    )
    peak_igar_summary = (
        peak_igar_all
        .groupby(["lead", "Peak"])
        .agg(
            Peak_IGAR_mean=("Peak_IGAR", lambda col: np.nanmean(col)),
            PercentOfAllImportant_mean=("PercentOfAllImportant", lambda col: np.nanmean(col)),
        )
        .reset_index()
    )
    print("Aggregated IGAR and Percent for each peak and lead (mean only):")
    print(peak_igar_summary)

Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data for the left panel (titled "PH Detection" below).
# lead_data_1: value per lead, drawn as the inner ring and labelled as a percentage.
# NOTE(review): these numbers are hard-coded; presumably pasted from an earlier
# run's summary output -- confirm they match the current results.
lead_data_1 = {
    "LEAD-I": 12.564727,
    "LEAD-II": 21.844531,
    "LEAD-III": 13.018823,
    "LEAD-aVF": 11.387362,
    "LEAD-aVL": 13.751365,
    "LEAD-aVR": 27.433193
}

# peak_data_1: five values per lead, matched by position to peak_labels
# (P, Q, R, S, T); each is scaled by the lead's value to size the outer ring.
peak_data_1 = {
    "LEAD-I":     [3.036466, 9.138413, 49.824495, 30.227844, 7.548567],
    "LEAD-II":    [4.417913, 8.735130, 45.335424, 29.172962, 11.544920],
    "LEAD-III":   [5.750188, 16.674821, 50.809959, 16.893581, 8.047049],
    "LEAD-aVF":   [2.444300, 11.491595, 47.125567, 33.095187, 5.729065],
    "LEAD-aVL":   [6.625456, 12.417031, 48.337901, 21.719252, 10.328932],
    "LEAD-aVR":   [3.843334, 32.156304, 41.643393, 17.645506, 4.141337]
}

# Data for the right panel (titled "Phenotyping PH" below), same layout:
# lead_data_2 holds IGAR_percent_mean per lead and peak_data_2 holds
# PercentOfAllImportant_mean per peak, in P, Q, R, S, T order.
lead_data_2 = {
    "LEAD-I": 12.130735,
    "LEAD-II": 22.511896,
    "LEAD-III": 15.411085,
    "LEAD-aVF": 14.734527,
    "LEAD-aVL": 19.132485,
    "LEAD-aVR": 16.079271
}

peak_data_2 = {
    "LEAD-I":     [3.790666, 9.092196, 46.494400, 32.260213, 8.073088],
    "LEAD-II":    [5.751059, 8.778773, 43.917717, 29.813044, 11.300168],
    "LEAD-III":   [5.370212, 16.532282, 54.104467, 15.730855, 7.378531],
    "LEAD-aVF":   [2.228947, 10.425398, 45.050057, 36.523867, 5.771732],
    "LEAD-aVL":   [7.383158, 12.135778, 49.124514, 20.593080, 10.026007],
    "LEAD-aVR":   [3.511032, 32.936929, 42.726368, 16.250718, 3.991012]
}

# Colour palette for the nested pie charts.
# (An earlier palette that was assigned to the same two names and immediately
# overwritten has been removed; only the definitions that took effect remain.)

# Inner-ring colour for each lead.
base_colors = {
    "LEAD-I":   "#3b7fba",   # Softened Blue
    "LEAD-II":  "#7cab5f",   # Softer Olive Green
    "LEAD-III": "#f3cd94",   # Slightly lighter Gold
    "LEAD-aVF": "#6fb7aa",   # Softer Teal
    "LEAD-aVL": "#b86ab7",   # Softer Purple
    "LEAD-aVR": "#e6434c"    # Softer Red
}

# Outer-ring shades for each lead: five progressively lighter tints, indexed
# by position in peak_labels (P, Q, R, S, T).
lighter_colors = {
    "LEAD-I":    ["#5a97cb", "#7aafdc", "#9ac7ed", "#badaf2", "#def0fa"],
    "LEAD-II":   ["#90bb7c", "#a5cb99", "#badcb6", "#cfecd3", "#e4fce9"],
    "LEAD-III":  ["#f5d6a9", "#f7e0be", "#fae9d4", "#fcf3e9", "#fefcf5"],
    "LEAD-aVF":  ["#8bc6bc", "#a7d6ce", "#c3e5e1", "#dff5f3", "#effcfa"],
    "LEAD-aVL":  ["#c885c7", "#d7a0d6", "#e5bbe5", "#f4d6f4", "#fdf0fd"],
    "LEAD-aVR":  ["#ea6a71", "#f09196", "#f7b8ba", "#fddedf", "#fff5f5"]
}

# Labels for the five outer-ring wedges of every lead, in data order.
peak_labels = ['P', 'Q', 'R', 'S', 'T']

# Function to prepare pie data
def prepare_pie_data(lead_data, peak_data):
    """Build the value / colour / label lists for one nested pie chart.

    Reads the module-level ``lighter_colors``, ``base_colors`` and ``peak_labels``.

    Args:
        lead_data: mapping from lead name to that lead's percentage.
        peak_data: mapping from lead name to a sequence of per-peak percentages,
            matched by position to ``peak_labels``.

    Returns:
        ``(outer_values, outer_colors, outer_labels, inner_values,
        inner_colors, inner_keys)``, all lists. Each outer value is the peak
        percentage scaled by its lead's percentage.
    """
    # One (value, colour, label) triple per outer wedge, lead by lead.
    outer_segments = [
        (peak_pct * percent / 100.0, lighter_colors[lead][pos], peak_labels[pos])
        for lead, percent in lead_data.items()
        for pos, peak_pct in enumerate(peak_data[lead])
    ]
    outer_values = [segment[0] for segment in outer_segments]
    outer_colors = [segment[1] for segment in outer_segments]
    outer_labels = [segment[2] for segment in outer_segments]

    inner_values = list(lead_data.values())
    inner_colors = [base_colors[lead] for lead in lead_data]
    inner_keys = list(lead_data.keys())
    return outer_values, outer_colors, outer_labels, inner_values, inner_colors, inner_keys

outer_values1, outer_colors1, outer_labels1, inner_values1, inner_colors1, inner_keys1 = prepare_pie_data(lead_data_1, peak_data_1)
outer_values2, outer_colors2, outer_labels2, inner_values2, inner_colors2, inner_keys2 = prepare_pie_data(lead_data_2, peak_data_2)

# Plotting side by side: one nested pie per panel (inner ring = leads,
# outer ring = peaks within each lead).
fig, axs = plt.subplots(1, 2, figsize=(10, 6))
for ax, panel_title in zip(axs, ("PH Detection", "Phenotyping PH")):
    ax.text(0.5, 0.95, panel_title, transform=ax.transAxes,
            fontsize=18, fontweight='bold', ha='center')

panels = [
    (outer_values1, outer_colors1, outer_labels1, inner_values1, inner_colors1, inner_keys1),
    (outer_values2, outer_colors2, outer_labels2, inner_values2, inner_colors2, inner_keys2),
]
ring_style = dict(width=0.3, edgecolor='white')

for ax, (outer_values, outer_colors, outer_labels, inner_values, inner_colors, inner_keys) in zip(axs, panels):

    # Outer ring: one wedge per (lead, peak), labelled with the peak letter.
    wedges_outer, _ = ax.pie(
        outer_values, radius=1, labels=outer_labels, colors=outer_colors,
        wedgeprops=dict(ring_style),
        labeldistance=1.01, textprops={'fontsize': 9, 'fontweight': 'bold'}
    )

    # Inner ring: one wedge per lead; labels are drawn manually below.
    wedges_inner, _ = ax.pie(
        inner_values, radius=0.7, labels=None, colors=inner_colors,
        wedgeprops=dict(ring_style)
    )

    # Percentage text at the angular midpoint of every outer wedge.
    for w, val in zip(wedges_outer, outer_values):
        theta = np.deg2rad((w.theta1 + w.theta2) / 2)
        x = 0.85 * np.cos(theta)
        y = 0.85 * np.sin(theta)
        ax.text(x, y, f"{val:.1f}%", ha='center', va='center', fontsize=8)

    # Lead name with its percentage just beneath it on every inner wedge.
    for w, val, name in zip(wedges_inner, inner_values, inner_keys):
        theta = np.deg2rad((w.theta1 + w.theta2) / 2)
        x = 0.5 * np.cos(theta)
        y = 0.5 * np.sin(theta)
        ax.text(x, y + 0.01, f"{name}", ha='center', va='center', fontsize=8, fontweight='bold')
        ax.text(x, y - 0.05, f"{val:.1f}%", ha='center', va='center', fontsize=8)

plt.tight_layout()
plt.show()
fig.savefig("quantitative_analysis.pdf", format="pdf", dpi=600, bbox_inches='tight')

import fitz 
# Trim the margins of the saved figure with PyMuPDF and write a cropped copy.
# The context manager closes the document once the cropped file is saved.
with fitz.open("quantitative_analysis.pdf") as doc:
    page = doc[0]

    # Offsets are in points (72 points = 1 inch): 45 pt off the x0 and x1
    # edges, 15 pt off the y0 edge and 45 pt off the y1 edge.
    # NOTE(review): y0 is presumably the top edge in PyMuPDF's coordinates -- confirm.
    rect = page.rect
    new_rect = fitz.Rect(rect.x0 + 45, rect.y0 + 15, rect.x1 - 45, rect.y1 - 45)
    page.set_cropbox(new_rect)

    # Save
    doc.save("quantitative_analysis_cropped.pdf")
