Task-1

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Define file paths for the 12-lead ECG data and labels.
# Every file lives in the same directory, so the directory is spelled once and
# the per-lead paths are derived from it (one "<LEAD>.pt" file per lead).
data_dir = "../Data_processing/data_aspire_PAP"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"


# Load all lead tensors and labels.
# map_location="cpu" keeps the loaded data on the CPU even if the files were
# saved from a GPU session; the train/eval loops move each batch to the
# training device themselves.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu") for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples as the labels
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG recordings with one label per sample.

    `ecg_leads` maps a lead name to an indexable collection whose first axis is
    the sample index; `labels` is indexed with that same sample index.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The sample count is taken from whichever lead comes first in the mapping.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return `(lead_data, label)` for sample `idx`.

        `lead_data` is a dict with one entry per lead; each entry is that lead's
        sample with a new leading axis of size 1 added by `unsqueeze(0)`.
        """
        lead_data = {}
        for lead_name, recordings in self.ecg_leads.items():
            lead_data[lead_name] = recordings[idx].unsqueeze(0)
        return lead_data, self.labels[idx]

# Build the dataset from the loaded lead tensors and labels
# (data loaders are created per fold in the cross-validation cell).
dataset = ECGMultiLeadDatasetWithLabels(ecg_leads=ecg_lead_tensors, labels=labels)



In [2]:
# **Classifier Model Class**
class ECGLeadClassifier(nn.Module):
    """Classifier head on top of frozen, pretrained per-lead encoders.

    Each selected lead is passed through its own encoder; the first element of
    each encoder's output (`mu`) is concatenated across leads along dim 1 and
    fed to a small MLP that produces `num_classes` logits.

    Args:
        pretrained_mopoe: object exposing `encoders` (indexable by integer, one
            encoder module per lead) and `latent_dim` (size of each encoder's
            feature vector).
        num_classes: number of output logits.
        use_12_leads: if True use all 12 leads, otherwise only the first 6
            (limb) leads.
    """

    # NOTE(review): aVF is listed before aVL here, whereas the data-loading cell
    # lists aVL before aVF. Encoders are paired with leads purely by position
    # (encoder i <-> lead_names[i]), so confirm this matches the lead order the
    # pretrained encoders were trained with.
    _LIMB_LEADS = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    _PRECORDIAL_LEADS = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]

    def __init__(self, pretrained_mopoe, num_classes, use_12_leads=True):
        super().__init__()

        # Define lead names based on the mode
        self.use_12_leads = use_12_leads
        self.lead_names = list(self._LIMB_LEADS)
        if use_12_leads:
            self.lead_names += self._PRECORDIAL_LEADS

        # Select encoders based on the leads to use: the first 12 or the first 6
        self.lead_encoders = nn.ModuleList(
            [pretrained_mopoe.encoders[i] for i in range(len(self.lead_names))]
        )

        self.feature_dim = pretrained_mopoe.latent_dim

        # Freeze encoder weights and put the encoders in inference mode
        for param in self.lead_encoders.parameters():
            param.requires_grad = False
        self.lead_encoders.eval()

        # Define classifier network
        self.classifier = nn.Sequential(
            nn.Linear(len(self.lead_names) * self.feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def train(self, mode=True):
        """Switch the classifier head to train/eval mode; encoders always stay in eval.

        `requires_grad = False` only stops weight updates. Without this override,
        `model.train()` would also put the frozen encoders in training mode, so
        any Dropout/BatchNorm layers inside them would stay stochastic or keep
        updating their running statistics.
        """
        super().train(mode)
        self.lead_encoders.eval()
        return self

    def forward(self, lead_data):
        """Return logits of shape (batch, num_classes).

        Args:
            lead_data: mapping from lead name to that lead's input batch; must
                contain every name in `self.lead_names`.
        """
        lead_features = []
        for lead_name, encoder in zip(self.lead_names, self.lead_encoders):
            # Each encoder returns a pair; only the first element is used as features.
            mu, _ = encoder(lead_data[lead_name])
            lead_features.append(mu)

        combined_features = torch.cat(lead_features, dim=1)
        return self.classifier(combined_features)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn

# **Set device configuration**
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if available, CUDA).

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any model or data loader is created.
set_seed(seed_value=123)

# **Training Function**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train `model` for `epochs` passes over `train_loader`, printing the mean loss per epoch.

    Each batch must be a `(lead_data, labels)` pair where `lead_data` maps lead
    names to tensors. Only the 6 limb leads (or all 12 when `use_12_leads` is
    True) are moved to the device and passed to the model. Relies on the
    module-level `device`.
    """
    model.train()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    for epoch in range(epochs):
        running_loss = 0
        for batch_leads, targets in train_loader:
            targets = targets.to(device)

            # Keep only the selected leads, moved to the training device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a per-sample mean
            running_loss += loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')


# **Evaluation Function**
def evaluate_model(model, data_loader, use_12_leads=True):
    """Evaluate `model` on `data_loader` and return `(accuracy, auc, f1, mcc)`.

    Predictions are the argmax of the softmax over the logits; the AUC is
    computed from the softmax probability in column 1. Only the 6 limb leads
    (or all 12 when `use_12_leads` is True) are passed to the model. Relies on
    the module-level `device`.
    """
    model.eval()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for batch_leads, batch_labels in data_loader:
            batch_labels = batch_labels.to(device)

            # Keep only the selected leads, moved to the evaluation device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            probs = torch.softmax(model(inputs), dim=1)
            preds = probs.argmax(dim=1)

            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    return accuracy, auc_score, f1, mcc


# **Cross-Validation Training**
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# NOTE(review): `prior_expert` and `LSEMVAE_MOE` are neither defined nor imported
# in the cells shown here; they must already be in the kernel namespace.

# Single switch for the lead configuration used by the classifier, the training
# loop and the evaluation, so the three can never disagree.
USE_12_LEADS = False

# Define parameters
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# Read the pretrained checkpoint once. Each fold builds a fresh model and copies
# these weights into it, so the file does not need to be re-read per fold.
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_w_o_poe.pth", map_location=device)

# Fix key mismatch by removing '_orig_mod.' prefix
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Instantiate the prior distribution
    prior_dist = prior_expert(params['latent_dim'])

    # Create the MoPoE model (specific for 12-lead ECG data)
    pretrained_mopoe = LSEMVAE_MOE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )

    # Load the pretrained weights. strict=False tolerates key mismatches, so
    # report them: missing keys mean those layers keep their random
    # initialisation. The report is identical for every fold, so print it once.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(f"WARNING: checkpoint/model key mismatch: "
              f"{len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected")
        print(f"  missing (first 5): {load_result.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_result.unexpected_keys[:5]}")

    # Move the model to the device (GPU/CPU)
    pretrained_mopoe.to(device)

    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=USE_12_LEADS).to(device)

    # Only parameters that still require gradients are optimised
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=USE_12_LEADS)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=USE_12_LEADS)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')  # MCC Statistics


Task-2

In [21]:
import torch
from torch.utils.data import Dataset, DataLoader

# File paths for the 12-lead ECG data and labels.
# Every file lives in the same directory, so the directory is spelled once and
# the per-lead paths are derived from it (one "<LEAD>.pt" file per lead).
data_dir = "../Data_processing/data_aspire_PAWP"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"


# Load all lead tensors and labels.
# map_location="cpu" keeps the loaded data on the CPU even if the files were
# saved from a GPU session; the train/eval loops move each batch to the
# training device themselves.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu") for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples as the labels
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG recordings with one label per sample.

    `ecg_leads` maps a lead name to an indexable collection whose first axis is
    the sample index; `labels` is indexed with that same sample index.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The sample count is taken from whichever lead comes first in the mapping.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return `(lead_data, label)` for sample `idx`.

        `lead_data` is a dict with one entry per lead; each entry is that lead's
        sample with a new leading axis of size 1 added by `unsqueeze(0)`.
        """
        lead_data = {}
        for lead_name, recordings in self.ecg_leads.items():
            lead_data[lead_name] = recordings[idx].unsqueeze(0)
        return lead_data, self.labels[idx]

# Build the dataset over the loaded tensors, plus a shuffled loader over all of it
dataset = ECGMultiLeadDatasetWithLabels(ecg_leads=ecg_lead_tensors, labels=labels)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)


In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn

# **Set device configuration**
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if available, CUDA).

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any model or data loader is created.
set_seed(seed_value=123)

# **Training Function**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train `model` for `epochs` passes over `train_loader`, printing the mean loss per epoch.

    Each batch must be a `(lead_data, labels)` pair where `lead_data` maps lead
    names to tensors. Only the 6 limb leads (or all 12 when `use_12_leads` is
    True) are moved to the device and passed to the model. Relies on the
    module-level `device`.
    """
    model.train()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    for epoch in range(epochs):
        running_loss = 0
        for batch_leads, targets in train_loader:
            targets = targets.to(device)

            # Keep only the selected leads, moved to the training device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a per-sample mean
            running_loss += loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')


# **Evaluation Function**
def evaluate_model(model, data_loader, use_12_leads=True):
    """Evaluate `model` on `data_loader` and return `(accuracy, auc, f1, mcc)`.

    Predictions are the argmax of the softmax over the logits; the AUC is
    computed from the softmax probability in column 1. Only the 6 limb leads
    (or all 12 when `use_12_leads` is True) are passed to the model. Relies on
    the module-level `device`.
    """
    model.eval()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for batch_leads, batch_labels in data_loader:
            batch_labels = batch_labels.to(device)

            # Keep only the selected leads, moved to the evaluation device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            probs = torch.softmax(model(inputs), dim=1)
            preds = probs.argmax(dim=1)

            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    return accuracy, auc_score, f1, mcc


# **Cross-Validation Training**
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# NOTE(review): `prior_expert` and `LSEMVAE_MOE` are neither defined nor imported
# in the cells shown here; they must already be in the kernel namespace.

# Single switch for the lead configuration used by the classifier, the training
# loop and the evaluation, so the three can never disagree.
USE_12_LEADS = False

# Define parameters
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# Read the pretrained checkpoint once. Each fold builds a fresh model and copies
# these weights into it, so the file does not need to be re-read per fold.
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_w_o_poe.pth", map_location=device)

# Fix key mismatch by removing '_orig_mod.' prefix
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Instantiate the prior distribution
    prior_dist = prior_expert(params['latent_dim'])

    # Create the MoPoE model (specific for 12-lead ECG data)
    pretrained_mopoe = LSEMVAE_MOE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )

    # Load the pretrained weights. strict=False tolerates key mismatches, so
    # report them: missing keys mean those layers keep their random
    # initialisation. The report is identical for every fold, so print it once.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(f"WARNING: checkpoint/model key mismatch: "
              f"{len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected")
        print(f"  missing (first 5): {load_result.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_result.unexpected_keys[:5]}")

    # Move the model to the device (GPU/CPU)
    pretrained_mopoe.to(device)

    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=USE_12_LEADS).to(device)

    # Only parameters that still require gradients are optimised
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=USE_12_LEADS)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=USE_12_LEADS)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')  # MCC Statistics

Task-3

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
# File paths
# Every file lives in the same directory, so the directory is spelled once and
# the per-lead paths are derived from it (one "<LEAD>.pt" file per lead).
# NOTE(review): this is an absolute, machine-specific path; change it here when
# running elsewhere.
data_dir = "D:/ukbiobank/ECG_PAWP_UKB_Final"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels.
# map_location="cpu" keeps the loaded data on the CPU even if the files were
# saved from a GPU session; the train/eval loops move each batch to the
# training device themselves.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu") for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples as the labels
sample_count = len(next(iter(ecg_lead_tensors.values())))
print(len(labels))
print(sample_count)
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG recordings with one label per sample.

    `ecg_leads` maps a lead name to an indexable collection whose first axis is
    the sample index; `labels` is indexed with that same sample index.
    """

    def __init__(self, ecg_leads, labels):
        self.labels = labels
        self.ecg_leads = ecg_leads

    def __len__(self):
        # The sample count is taken from whichever lead comes first in the mapping.
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        """Return `(lead_data, label)` for sample `idx`.

        `lead_data` is a dict with one entry per lead; each entry is that lead's
        sample with a new leading axis of size 1 added by `unsqueeze(0)`.
        """
        lead_data = {}
        for lead_name, recordings in self.ecg_leads.items():
            lead_data[lead_name] = recordings[idx].unsqueeze(0)
        return lead_data, self.labels[idx]

# Build the dataset from the loaded lead tensors and labels
# (data loaders are created per fold in the cross-validation cell).
dataset = ECGMultiLeadDatasetWithLabels(ecg_leads=ecg_lead_tensors, labels=labels)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.optim as optim
import torch.nn as nn

# **Set device configuration**
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if available, CUDA).

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any model or data loader is created.
set_seed(seed_value=123)

# **Training Function**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10, use_12_leads=True):
    """Train `model` for `epochs` passes over `train_loader`, printing the mean loss per epoch.

    Each batch must be a `(lead_data, labels)` pair where `lead_data` maps lead
    names to tensors. Only the 6 limb leads (or all 12 when `use_12_leads` is
    True) are moved to the device and passed to the model. Relies on the
    module-level `device`.
    """
    model.train()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    for epoch in range(epochs):
        running_loss = 0
        for batch_leads, targets in train_loader:
            targets = targets.to(device)

            # Keep only the selected leads, moved to the training device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a per-sample mean
            running_loss += loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')


# **Evaluation Function**
def evaluate_model(model, data_loader, use_12_leads=True):
    """Evaluate `model` on `data_loader` and return `(accuracy, auc, f1, mcc)`.

    Predictions are the argmax of the softmax over the logits; the AUC is
    computed from the softmax probability in column 1. Only the 6 limb leads
    (or all 12 when `use_12_leads` is True) are passed to the model. Relies on
    the module-level `device`.
    """
    model.eval()

    # Select lead names based on `use_12_leads`
    limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
    chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
    lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for batch_leads, batch_labels in data_loader:
            batch_labels = batch_labels.to(device)

            # Keep only the selected leads, moved to the evaluation device
            inputs = {}
            for name in lead_names:
                inputs[name] = batch_leads[name].to(device)

            probs = torch.softmax(model(inputs), dim=1)
            preds = probs.argmax(dim=1)

            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    return accuracy, auc_score, f1, mcc


# **Cross-Validation Training**
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# NOTE(review): `prior_expert` and `LSEMVAE_MOE` are neither defined nor imported
# in the cells shown here; they must already be in the kernel namespace.

# Single switch for the lead configuration used by the classifier, the training
# loop and the evaluation, so the three can never disagree.
USE_12_LEADS = False

# Define parameters
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# Read the pretrained checkpoint once. Each fold builds a fresh model and copies
# these weights into it, so the file does not need to be re-read per fold.
state_dict = torch.load("../Main/HPC/pretrain/LS_EMVAE_with_reg_12_w_o_poe.pth", map_location=device)

# Fix key mismatch by removing '_orig_mod.' prefix
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Instantiate the prior distribution
    prior_dist = prior_expert(params['latent_dim'])

    # Create the MoPoE model (specific for 12-lead ECG data)
    pretrained_mopoe = LSEMVAE_MOE(
        prior_dist=prior_dist,
        latent_dim=params['latent_dim'],
        num_leads=params['num_leads'],
        input_dim_per_lead=params['input_dim_per_lead']
    )

    # Load the pretrained weights. strict=False tolerates key mismatches, so
    # report them: missing keys mean those layers keep their random
    # initialisation. The report is identical for every fold, so print it once.
    load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
    if fold == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(f"WARNING: checkpoint/model key mismatch: "
              f"{len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected")
        print(f"  missing (first 5): {load_result.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_result.unexpected_keys[:5]}")

    # Move the model to the device (GPU/CPU)
    pretrained_mopoe.to(device)

    model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=USE_12_LEADS).to(device)

    # Only parameters that still require gradients are optimised
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    train_classifier(model, train_loader, criterion, optimizer, epochs=50, use_12_leads=USE_12_LEADS)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader, use_12_leads=USE_12_LEADS)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')  # MCC Statistics