Task-1 (mPAP)

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 26 14:34:18 2023
Latest Change on Wed May 03 09:56:10 2023

@author: hawkiyc
"""

'Import Libraries'

import torch
import torch.nn as nn

# Module-level defaults.
# `activation`, `data_dim` and `out_dim` are used as *default argument values*
# by the classes below (conv, XResNetBlock, XResNet1d). Default values are
# evaluated when each `class` statement runs, so these names must exist
# whenever the module executes -- including on import, not only when run as
# `__main__`. Hence no `if __name__ == "__main__":` guard here.
stem_k, block_k = 11, 5  # convenience kernel sizes (odd); XResNet1d takes them explicitly
activation = nn.ReLU(inplace=True)  # default activation module shared by conv / XResNetBlock
data_dim = 12  # default `in_ch` of XResNet1d (12, matching a 12-lead ECG)
out_dim = 5  # default `c_out` (number of output logits) of XResNet1d

# Build the model


class conv(nn.Module):
    """Conv1d -> BatchNorm1d -> optional activation -> optional dropout.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k_size: kernel size; must be odd so that ``(k_size - 1) // 2`` padding
            keeps the sequence length unchanged at stride 1.
        stride: convolution stride.
        drop_r: dropout probability; a falsy value (None / 0) disables dropout.
        zero_batch_norm: initialise the BatchNorm scale to 0 instead of 1.
        bias: whether the convolution has a bias term.
        use_act_fun: apply ``act_fun`` after the BatchNorm.
        act_fun: activation module (defaults to the module-level ``activation``).
    """

    def __init__(self, in_ch, out_ch, k_size=25, stride=1,
                 drop_r=None, zero_batch_norm=False,
                 bias=False, use_act_fun=True,
                 act_fun: nn.Module = activation):
        assert k_size % 2 == 1, 'kernel size shall be odd number'
        super().__init__()

        half_width = (k_size - 1) // 2
        self.conv1d = nn.Conv1d(in_ch, out_ch, k_size, stride,
                                padding=half_width, bias=bias)

        self.batch_norm = nn.BatchNorm1d(out_ch)
        gamma = 0. if zero_batch_norm else 1.
        nn.init.constant_(self.batch_norm.weight, gamma)

        self.act_fun = act_fun

        self.drop_r = drop_r
        if drop_r:
            self.drop = nn.Dropout(drop_r)
        else:
            self.drop = None

        self.use_act_fun = use_act_fun

    def forward(self, x):
        """Apply conv, BatchNorm and (when enabled) activation and dropout."""
        out = self.batch_norm(self.conv1d(x))
        if self.use_act_fun:
            out = self.act_fun(out)
        return self.drop(out) if self.drop_r else out


class XResNetBlock(nn.Module):
    """Residual block used by XResNet1d.

    ``expansion == 1``: basic block, two k-tap convs.
    ``expansion == 4``: bottleneck block, 1x1 conv -> k-tap conv -> 1x1 conv.

    The last conv of the residual branch has its BatchNorm scale initialised
    to zero and no activation, so the branch starts out contributing zero.
    The shortcut is average-pooled (window 2) when ``stride != 1`` and
    projected with a 1x1 conv when the channel counts differ. ``act_fun`` is
    applied after the addition.

    Args:
        expansion: 1 or 4.
        in_ch: input width *before* expansion; the block actually receives
            ``in_ch * expansion`` channels.
        between_ch: inner width; the block outputs ``between_ch * expansion``.
        k: kernel size of the non-1x1 conv(s); must be odd (checked by conv).
        stride: stride of the k-tap conv that may downsample.
        b_verbose: if truthy, print intermediate shapes in ``forward``.
        act_fun: activation applied after the residual addition.
    """

    def __init__(self, expansion, in_ch, between_ch, k=9,
                 stride=1, b_verbose=None,
                 act_fun: nn.Module = activation):

        assert expansion in [1, 4], 'expansion shall be 1 or 4'
        super(XResNetBlock, self).__init__()

        in_ch = in_ch * expansion
        out_ch = between_ch * expansion

        if expansion == 1:
            # Basic block: k-tap (possibly strided) conv -> k-tap conv.
            layers = [conv(in_ch, between_ch, k, stride=stride),
                      conv(between_ch, out_ch, k,
                           zero_batch_norm=True, use_act_fun=False)]
        else:
            # Bottleneck: 1x1 conv -> k-tap (possibly strided) conv -> 1x1 conv.
            layers = [conv(in_ch, between_ch, 1),
                      conv(between_ch, between_ch, k, stride=stride),
                      conv(between_ch, out_ch, 1,
                           zero_batch_norm=True, use_act_fun=False)]

        self.xres_block = nn.ModuleList(layers)

        # Shortcut: 1x1 projection when widths differ; 2x average-pool when
        # the branch is strided. NOTE(review): the pool halves the length, so
        # this assumes stride is 1 or 2.
        self.res_conv = conv(in_ch, out_ch, 1, use_act_fun=False
                             ) if in_ch != out_ch else None
        self.res_pool = nn.AvgPool1d(2, ceil_mode=True
                                     ) if stride != 1 else None
        self.act_fun = act_fun
        self.b_verbose = b_verbose if b_verbose else None

    def forward(self, x):
        """Return ``act_fun(branch(x) + shortcut(x))``."""
        identity = x

        for layer in self.xres_block:
            x = layer(x)
            if self.b_verbose:
                print('res_torch_size:', x.shape)

        if self.res_pool is not None:
            identity = self.res_pool(identity)
        if self.res_conv is not None:
            identity = self.res_conv(identity)
        if self.b_verbose:
            # Report the shortcut tensor's shape, not the residual branch's.
            print('identity_torch_size:', identity.shape)

        x += identity
        x = self.act_fun(x)

        return x


class ConcatPool(nn.Module):
    """Concatenate global max-pooling and global average-pooling.

    For an input of shape ``(B, C, L)`` the output has shape ``(B, 2*C, 1)``.
    The first ``C`` channels are the per-channel maxima and the last ``C``
    are the per-channel means.

    Args:
        dim: axis along which the two pooled tensors are concatenated
            (default 1, the channel axis).
    """

    def __init__(self, dim=1):
        super().__init__()

        self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.dim = dim

    def forward(self, x):
        # Both pooled tensors are (B, C, 1). Keep the trailing length-1 axis.
        # Squeezing `self.dim` (the channel axis) would only remove it when
        # C == 1, which made the output rank depend on the channel count.
        maxpooled = self.maxpool(x)
        avgpooled = self.avgpool(x)

        return torch.cat((maxpooled, avgpooled), dim=self.dim)


class XResNet1d(nn.Module):
    """1-D XResNet: conv stem -> max-pool -> residual stages -> pooled MLP head.

    Args:
        expansion: block expansion factor, 1 (basic) or 4 (bottleneck).
        num_layers: number of residual blocks per stage. Stage 0 keeps the
            temporal resolution; every later stage downsamples by 2.
        stem_k: odd kernel size of the three stem convolutions.
        block_k: odd kernel size of the residual-block convolutions.
        in_ch: number of input channels.
        c_out: number of output logits.
        model_drop_r: dropout rate after every stem conv and after every
            residual block; falsy disables it.
        verbose: print tensor shapes through the stem / stages / pooling.
        b_verbose: forwarded to every XResNetBlock (prints inside blocks).
        original_f_number: use stage widths 64/128/256/512 (+256 for extra
            stages) instead of 64 everywhere (+32 for extra stages).
        fc_drop: dropout rate between the two fully connected layers.
    """

    def __init__(self, expansion, num_layers, stem_k,
                 block_k, in_ch=data_dim, c_out=out_dim,
                 model_drop_r=None, verbose=False,
                 b_verbose=False, original_f_number=False,
                 fc_drop=None):

        super(XResNet1d, self).__init__()

        # Stem: in_ch -> 32 -> 32 -> 64 channels; only the first conv is strided.
        stem_filters = [in_ch, 32, 32, 64]

        stem = [conv(stem_filters[i], stem_filters[i + 1], k_size=stem_k,
                     stride=2 if i == 0 else 1, drop_r=model_drop_r,
                     ) for i in range(3)]
        self.stem = nn.ModuleList(stem)

        self.stem_pool = nn.MaxPool1d(3, 2, padding=1)
        # A single Dropout module, reused after every residual block (make_layers).
        self.model_drop_r = nn.Dropout(model_drop_r
                                       ) if model_drop_r else None
        self.b_verbose = b_verbose if b_verbose else None

        # Stage widths before expansion: entry 0 is stage 0's input width and
        # entry i + 1 is the output width of stage i.
        if original_f_number:
            block_filters = ([64 // expansion] + [64, 128, 256, 512]
                             + [256] * (len(num_layers) - 4))
        else:
            block_filters = ([64 // expansion] + [64, 64, 64, 64]
                             + [32] * (len(num_layers) - 4))
        # With fewer than four stages, drop the unused widths so that
        # block_filters[-1] is the width the last stage really outputs;
        # fc1's input size is derived from it.
        block_filters = block_filters[:len(num_layers) + 1]

        self.block_k = block_k
        block = [self.make_layers(expansion, block_filters[i],
                                  block_filters[i + 1], n_blocks=l,
                                  stride=1 if i == 0 else 2,
                                  ) for i, l in enumerate(num_layers)]
        self.block = nn.ModuleList(block)

        # Head: concat(max, avg) pooling doubles the channel count.
        self.concat_pool = ConcatPool()
        self.fc1 = nn.Linear(block_filters[-1] * expansion * 2, 128)
        self.fc_batch_norm = nn.BatchNorm1d(128)
        self.fc_drop = nn.Dropout(fc_drop) if fc_drop else None
        self.fc_out = nn.Linear(128, c_out)
        self.expansion = expansion
        self.verbose = verbose

        # Kaiming-normal init for every Conv1d / Linear weight, and zero every
        # bias that exists (this also covers BatchNorm biases).
        for m in self.modules():

            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0)

    def make_layers(self, expansion, n_inputs,
                    n_filters, n_blocks, stride,
                    ):
        """Build one stage: an ``nn.Sequential`` of ``n_blocks`` XResNetBlocks.

        Only the first block changes width (``n_inputs`` -> ``n_filters``) and
        uses ``stride``. When model dropout is enabled, the shared Dropout
        module follows every block.
        """
        sub_block = []

        for i in range(n_blocks):
            sub_block.append(XResNetBlock(expansion,
                                          n_inputs if i == 0 else n_filters,
                                          n_filters, self.block_k,
                                          stride if i == 0 else 1,
                                          b_verbose=self.b_verbose,
                                          ))
            if self.model_drop_r:
                sub_block.append(self.model_drop_r)

        return nn.Sequential(*sub_block)

    def forward(self, x):
        """Return logits of shape ``(batch, c_out)``.

        ``x`` is presumably ``(batch, in_ch, length)`` -- confirm with callers.
        """
        for layer in self.stem:
            x = layer(x)
            if self.verbose:
                print('stem_torch_size:', x.shape)

        x = self.stem_pool(x)

        for stage in self.block:
            x = stage(x)
            if self.verbose:
                print('block_torch_size:', x.shape)

        x = self.concat_pool(x)
        if self.verbose:
            print('concat_pool_torch_size:', x.shape)
        x = x.view(x.size(0), -1)  # flatten to (batch, features) for fc1
        x = self.fc1(x)
        x = self.fc_batch_norm(x)
        # NOTE(review): no nonlinearity between fc1 and fc_out, so the head is
        # affine apart from BatchNorm / dropout -- confirm this is intended.
        if self.fc_drop is not None:
            x = self.fc_drop(x)
        x = self.fc_out(x)

        return x



In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
# Use the CUDA device when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed_value: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # Favour reproducibility over speed: request deterministic cuDNN kernels
    # and turn off the autotuner, whose kernel choice can vary between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(123)

# Define file paths for the 12-lead ECG data and labels.
# Every file lives in one directory and is named "<lead>.pt" / "labels.pt",
# so the paths are built from a single base directory.
data_dir = "../../Data_processing/data_aspire_PAP_1"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels.
# NOTE: torch.load is pickle-based; only load files from a trusted source.
ecg_lead_tensors = {lead: torch.load(path) for lead, path in lead_file_paths.items()}
labels = torch.load(labels_file_path)

# Ensure every lead and the labels cover the same number of samples.
# Raise explicitly (not `assert`) so the check survives `python -O`.
sample_count = len(next(iter(ecg_lead_tensors.values())))
if len(labels) != sample_count:
    raise ValueError("Mismatch between number of labels and samples.")
for tensor in ecg_lead_tensors.values():
    if len(tensor) != sample_count:
        raise ValueError("All leads must have the same number of samples.")

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing multi-lead ECG samples with their labels.

    Args:
        ecg_leads: mapping from lead name to a tensor indexed by sample along
            its first axis (presumably ``(N, sequence_length)`` -- confirm).
        labels: per-sample labels, indexed in the same order as the leads.
        lead_names: which leads to keep, in the channel order to emit.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # One row per selected lead: torch.stack inserts the new leading
        # "lead" axis, equivalent to unsqueeze(0) per lead + cat on dim 0.
        signal = torch.stack(
            [series[idx] for series in self.ecg_leads.values()], dim=0
        )
        return signal, self.labels[idx]

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = False  # True -> 6 limb leads (I, II, III, aVR, aVL, aVF); False -> all 12 leads
lead_names = (
    ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
    if use_6_leads
    else list(lead_file_paths.keys())
)

# **Initialize the dataset and dataloader**
dataset = ECGMultiLeadDatasetWithLabels(ecg_lead_tensors, labels, lead_names)
# NOTE(review): `dataloader` is not used below -- the cross-validation loop
# builds its own per-fold loaders from Subsets of `dataset`.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

class ECGResNet1D(nn.Module):
    """Thin wrapper that configures an XResNet1d for ECG classification.

    Args:
        input_channels: number of ECG leads stacked as input channels.
        num_classes: number of output logits (default 2).
    """

    def __init__(self, input_channels, num_classes=2):
        super().__init__()
        # Basic blocks (expansion=1) in a 3-4-6-3 stage layout, 11-tap stem
        # kernels, 5-tap residual kernels and 7% dropout after each block.
        self.resnet = XResNet1d(
            expansion=1,
            num_layers=[3, 4, 6, 3],
            stem_k=11,
            block_k=5,
            in_ch=input_channels,
            c_out=num_classes,
            model_drop_r=0.07,
            verbose=False,
            b_verbose=False,
        )

    def forward(self, x):
        """Return the wrapped XResNet1d's logits for ``x``."""
        return self.resnet(x)


# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Every batch is moved to the module-level ``device``. After each epoch the
    per-sample mean training loss is printed.

    Args:
        model: network to optimise; put in training mode once, up front.
        train_loader: iterable of ``(inputs, targets)`` batches exposing a
            ``.dataset`` whose length is the sample count.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimiser over ``model``'s parameters.
        epochs: number of passes over the data.
    """
    model.train()
    for epoch_idx in range(1, epochs + 1):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # Assumes the criterion averages over the batch (default
            # reduction); re-weighting gives a per-sample epoch mean.
            running_loss += batch_loss.item() * targets.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_idx}/{epochs}, Loss: {epoch_loss:.4f}')

def evaluate_model(model, data_loader):
    """Score a binary classifier on every batch of ``data_loader``.

    Batches are moved to the module-level ``device``; gradients are disabled.

    Returns:
        ``(accuracy, roc_auc, f1, mcc)``. ROC-AUC uses the softmax probability
        of class 1; the other metrics use argmax predictions.
    """
    model.eval()
    y_true, y_score, y_pred = [], [], []

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            prob = torch.softmax(model(inputs), dim=1)

            y_true.extend(targets.cpu().numpy())
            y_score.extend(prob[:, 1].cpu().numpy())
            y_pred.extend(torch.argmax(prob, dim=1).cpu().numpy())

    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_score),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )

# **Cross-Validation Training**
# 5-fold stratified CV; shuffled with a fixed random_state so folds are reproducible.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []  # one (accuracy, auc_score, f1, mcc) tuple per fold

# Only the length of X matters for the split, so a zeros array stands in for
# the ECG data; stratification is on `labels`.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views into the full dataset for this fold.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize XResNet1d model
    # A fresh ECGResNet1D (wrapping XResNet1d) and optimizer for every fold.
    # NOTE(review): RNGs are seeded once before the loop, not per fold, so a
    # fold's initial weights depend on randomness consumed by earlier folds.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Assumes labels are integer class indices (0/1) -- confirm.
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # Fixed 50 epochs, no early stopping; the held-out fold is used only for
    # the final metrics.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# np.std defaults to the population standard deviation (ddof=0).
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')


Task-2 (PAWP)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
# Use the CUDA device when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed_value: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # Favour reproducibility over speed: request deterministic cuDNN kernels
    # and turn off the autotuner, whose kernel choice can vary between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(123)

# Define file paths for the 12-lead ECG data and labels.
# Every file lives in one directory and is named "<lead>.pt" / "labels.pt",
# so the paths are built from a single base directory.
data_dir = "../../Data_processing/data_aspire_PAWP_1"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels.
# NOTE: torch.load is pickle-based; only load files from a trusted source.
ecg_lead_tensors = {lead: torch.load(path) for lead, path in lead_file_paths.items()}
labels = torch.load(labels_file_path)

# Ensure every lead and the labels cover the same number of samples.
# Raise explicitly (not `assert`) so the check survives `python -O`.
sample_count = len(next(iter(ecg_lead_tensors.values())))
if len(labels) != sample_count:
    raise ValueError("Mismatch between number of labels and samples.")
for tensor in ecg_lead_tensors.values():
    if len(tensor) != sample_count:
        raise ValueError("All leads must have the same number of samples.")

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing multi-lead ECG samples with their labels.

    Args:
        ecg_leads: mapping from lead name to a tensor indexed by sample along
            its first axis (presumably ``(N, sequence_length)`` -- confirm).
        labels: per-sample labels, indexed in the same order as the leads.
        lead_names: which leads to keep, in the channel order to emit.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # One row per selected lead: torch.stack inserts the new leading
        # "lead" axis, equivalent to unsqueeze(0) per lead + cat on dim 0.
        signal = torch.stack(
            [series[idx] for series in self.ecg_leads.values()], dim=0
        )
        return signal, self.labels[idx]

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = False  # True -> 6 limb leads (I, II, III, aVR, aVL, aVF); False -> all 12 leads
lead_names = (
    ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
    if use_6_leads
    else list(lead_file_paths.keys())
)

# **Initialize the dataset and dataloader**
dataset = ECGMultiLeadDatasetWithLabels(ecg_lead_tensors, labels, lead_names)
# NOTE(review): `dataloader` is not used below -- the cross-validation loop
# builds its own per-fold loaders from Subsets of `dataset`.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

class ECGResNet1D(nn.Module):
    """Thin wrapper that configures an XResNet1d for ECG classification.

    Args:
        input_channels: number of ECG leads stacked as input channels.
        num_classes: number of output logits (default 2).
    """

    def __init__(self, input_channels, num_classes=2):
        super().__init__()
        # Basic blocks (expansion=1) in a 3-4-6-3 stage layout, 11-tap stem
        # kernels, 5-tap residual kernels and 7% dropout after each block.
        self.resnet = XResNet1d(
            expansion=1,
            num_layers=[3, 4, 6, 3],
            stem_k=11,
            block_k=5,
            in_ch=input_channels,
            c_out=num_classes,
            model_drop_r=0.07,
            verbose=False,
            b_verbose=False,
        )

    def forward(self, x):
        """Return the wrapped XResNet1d's logits for ``x``."""
        return self.resnet(x)
# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Every batch is moved to the module-level ``device``. After each epoch the
    per-sample mean training loss is printed.

    Args:
        model: network to optimise; put in training mode once, up front.
        train_loader: iterable of ``(inputs, targets)`` batches exposing a
            ``.dataset`` whose length is the sample count.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimiser over ``model``'s parameters.
        epochs: number of passes over the data.
    """
    model.train()
    for epoch_idx in range(1, epochs + 1):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # Assumes the criterion averages over the batch (default
            # reduction); re-weighting gives a per-sample epoch mean.
            running_loss += batch_loss.item() * targets.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_idx}/{epochs}, Loss: {epoch_loss:.4f}')

def evaluate_model(model, data_loader):
    """Score a binary classifier on every batch of ``data_loader``.

    Batches are moved to the module-level ``device``; gradients are disabled.

    Returns:
        ``(accuracy, roc_auc, f1, mcc)``. ROC-AUC uses the softmax probability
        of class 1; the other metrics use argmax predictions.
    """
    model.eval()
    y_true, y_score, y_pred = [], [], []

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            prob = torch.softmax(model(inputs), dim=1)

            y_true.extend(targets.cpu().numpy())
            y_score.extend(prob[:, 1].cpu().numpy())
            y_pred.extend(torch.argmax(prob, dim=1).cpu().numpy())

    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_score),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )

# **Cross-Validation Training**
# 5-fold stratified CV; shuffled with a fixed random_state so folds are reproducible.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []  # one (accuracy, auc_score, f1, mcc) tuple per fold

# Only the length of X matters for the split, so a zeros array stands in for
# the ECG data; stratification is on `labels`.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views into the full dataset for this fold.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize XResNet1d model
    # A fresh ECGResNet1D (wrapping XResNet1d) and optimizer for every fold.
    # NOTE(review): RNGs are seeded once before the loop, not per fold, so a
    # fold's initial weights depend on randomness consumed by earlier folds.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Assumes labels are integer class indices (0/1) -- confirm.
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # Fixed 50 epochs, no early stopping; the held-out fold is used only for
    # the final metrics.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# np.std defaults to the population standard deviation (ddof=0).
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')


Task-3 (PAWP, UK Biobank)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
# Use the CUDA device when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed_value: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # Favour reproducibility over speed: request deterministic cuDNN kernels
    # and turn off the autotuner, whose kernel choice can vary between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(123)

# Define file paths for the 12-lead ECG data and labels.
# Every file lives in one directory and is named "<lead>.pt" / "labels.pt",
# so the paths are built from a single base directory.
data_dir = "D:/ukbiobank/ECG_PAWP_UKB_Final"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels.
# NOTE: torch.load is pickle-based; only load files from a trusted source.
ecg_lead_tensors = {lead: torch.load(path) for lead, path in lead_file_paths.items()}
labels = torch.load(labels_file_path)

# Ensure every lead and the labels cover the same number of samples.
# Raise explicitly (not `assert`) so the check survives `python -O`.
sample_count = len(next(iter(ecg_lead_tensors.values())))
if len(labels) != sample_count:
    raise ValueError("Mismatch between number of labels and samples.")
for tensor in ecg_lead_tensors.values():
    if len(tensor) != sample_count:
        raise ValueError("All leads must have the same number of samples.")

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing multi-lead ECG samples with their labels.

    Args:
        ecg_leads: mapping from lead name to a tensor indexed by sample along
            its first axis (presumably ``(N, sequence_length)`` -- confirm).
        labels: per-sample labels, indexed in the same order as the leads.
        lead_names: which leads to keep, in the channel order to emit.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # One row per selected lead: torch.stack inserts the new leading
        # "lead" axis, equivalent to unsqueeze(0) per lead + cat on dim 0.
        signal = torch.stack(
            [series[idx] for series in self.ecg_leads.values()], dim=0
        )
        return signal, self.labels[idx]

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = True  # True -> 6 limb leads only; set False to use all 12 leads
if use_6_leads:
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
else:
    lead_names = list(lead_file_paths)

# **Initialize the dataset**
dataset = ECGMultiLeadDatasetWithLabels(
    ecg_leads=ecg_lead_tensors, labels=labels, lead_names=lead_names
)

class ECGResNet1D(nn.Module):
    """Thin wrapper that configures an XResNet1d for ECG classification.

    Args:
        input_channels: number of ECG leads stacked as input channels.
        num_classes: number of output logits (default 2).
    """

    def __init__(self, input_channels, num_classes=2):
        super().__init__()
        # Basic blocks (expansion=1) in a 3-4-6-3 stage layout, 9-tap stem
        # kernels, 3-tap residual kernels and 10% dropout after each block.
        self.resnet = XResNet1d(
            expansion=1,
            num_layers=[3, 4, 6, 3],
            stem_k=9,
            block_k=3,
            in_ch=input_channels,
            c_out=num_classes,
            model_drop_r=0.1,
            verbose=False,
            b_verbose=False,
        )

    def forward(self, x):
        """Return the wrapped XResNet1d's logits for ``x``."""
        return self.resnet(x)
# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Every batch is moved to the module-level ``device``. After each epoch the
    per-sample mean training loss is printed.

    Args:
        model: network to optimise; put in training mode once, up front.
        train_loader: iterable of ``(inputs, targets)`` batches exposing a
            ``.dataset`` whose length is the sample count.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimiser over ``model``'s parameters.
        epochs: number of passes over the data.
    """
    model.train()
    for epoch_idx in range(1, epochs + 1):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # Assumes the criterion averages over the batch (default
            # reduction); re-weighting gives a per-sample epoch mean.
            running_loss += batch_loss.item() * targets.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_idx}/{epochs}, Loss: {epoch_loss:.4f}')

def evaluate_model(model, data_loader):
    """Score a binary classifier on every batch of ``data_loader``.

    Batches are moved to the module-level ``device``; gradients are disabled.

    Returns:
        ``(accuracy, roc_auc, f1, mcc)``. ROC-AUC uses the softmax probability
        of class 1; the other metrics use argmax predictions.
    """
    model.eval()
    y_true, y_score, y_pred = [], [], []

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            prob = torch.softmax(model(inputs), dim=1)

            y_true.extend(targets.cpu().numpy())
            y_score.extend(prob[:, 1].cpu().numpy())
            y_pred.extend(torch.argmax(prob, dim=1).cpu().numpy())

    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_score),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )

# **Cross-Validation Training**
# 5-fold stratified CV; shuffled with a fixed random_state so folds are reproducible.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []  # one (accuracy, auc_score, f1, mcc) tuple per fold

# Only the length of X matters for the split, so a zeros array stands in for
# the ECG data; stratification is on `labels`.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views into the full dataset for this fold.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize XResNet1d model
    # A fresh ECGResNet1D (wrapping XResNet1d) and optimizer for every fold.
    # NOTE(review): RNGs are seeded once before the loop, not per fold, so a
    # fold's initial weights depend on randomness consumed by earlier folds.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Assumes labels are integer class indices (0/1) -- confirm.
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # Fixed 50 epochs, no early stopping; the held-out fold is used only for
    # the final metrics.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# np.std defaults to the population standard deviation (ddof=0).
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')
