Task-1 (mPAP)

In [2]:
"""
resnet for 1-d signal data, pytorch version
 
Shenda Hong, Oct 2019
"""

import numpy as np
from collections import Counter
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """
    Map-style dataset that pairs data[i] with label[i].

    Items are converted to tensors on access: the signal as torch.float,
    the label as torch.long.
    """
    def __init__(self, data, label):
        self.data, self.label = data, label

    def __len__(self):
        # the length is taken from `data` only; `label` is not cross-checked here
        return len(self.data)

    def __getitem__(self, index):
        signal = torch.tensor(self.data[index], dtype=torch.float)
        target = torch.tensor(self.label[index], dtype=torch.long)
        return (signal, target)
    
class MyConv1dPadSame(nn.Module):
    """
    nn.Conv1d with "SAME" padding: the input is zero-padded on the fly so the
    output length is ceil(input_length / stride)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        # the wrapped conv itself has no padding; forward() pads explicitly
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            groups=groups)

    def forward(self, x):
        length = x.shape[-1]

        # number of output steps, i.e. ceil(length / stride) in integer arithmetic
        steps = -(-length // self.stride)
        # total padding needed so the last window still fits; never negative
        total_pad = (steps - 1) * self.stride + self.kernel_size - length
        if total_pad < 0:
            total_pad = 0

        # an odd amount of padding puts the extra element on the right
        left = total_pad // 2
        right = total_pad - left
        padded = F.pad(x, (left, right), "constant", 0)

        return self.conv(padded)
        
class MyMaxPool1dPadSame(nn.Module):
    """
    nn.MaxPool1d preceded by explicit zero padding ("SAME"-style)

    The padding amount is derived with a stride of 1 (kernel_size - 1 in
    total), while the pooling layer is created without a stride argument.
    Because the pad value is 0, any window that touches the border yields
    at least 0.
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        # only used for the padding arithmetic in forward()
        self.stride = 1
        self.max_pool = torch.nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        length = x.shape[-1]

        # same padding formula as the conv wrapper, evaluated with self.stride
        steps = -(-length // self.stride)
        total_pad = max(0, (steps - 1) * self.stride + self.kernel_size - length)
        left = total_pad // 2
        right = total_pad - left

        return self.max_pool(F.pad(x, (left, right), "constant", 0))
    
class BasicBlock(nn.Module):
    """
    ResNet Basic Block (pre-activation layout)

    Main path: [BN -> ReLU -> Dropout] -> Conv -> BN -> ReLU -> Dropout -> Conv,
    where the bracketed part is skipped for the first block of the network.
    Shortcut path: the input, max-pooled when the block downsamples and
    zero-padded along the channel axis when the block widens.
    BN and Dropout are only applied when use_bn / use_do are set.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super().__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        self.downsample = downsample
        # the requested stride only takes effect in downsampling blocks
        self.stride = stride if downsample else 1
        self.is_first_block = is_first_block
        self.use_bn = use_bn
        self.use_do = use_do

        # the first conv (carries the stride)
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            groups=groups)

        # the second conv (always stride 1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            groups=groups)

        # pools the shortcut by the same factor as conv1's stride
        self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

    def forward(self, x):

        shortcut = x
        out = x

        # the first conv; pre-activation is skipped for the very first block
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.relu1(out)
            if self.use_do:
                out = self.do1(out)
        out = self.conv1(out)

        # the second conv
        if self.use_bn:
            out = self.bn2(out)
        out = self.relu2(out)
        if self.use_do:
            out = self.do2(out)
        out = self.conv2(out)

        # bring the shortcut to the same length as the main path
        if self.downsample:
            shortcut = self.max_pool(shortcut)

        # bring the shortcut to the same width as the main path
        extra_channels = self.out_channels - self.in_channels
        if extra_channels != 0:
            front = extra_channels // 2
            back = extra_channels - front
            # F.pad reads pairs starting from the last axis:
            # (0, 0) leaves the length axis alone, (front, back) pads channels
            shortcut = F.pad(shortcut, (0, 0, front, back), "constant", 0)

        # residual connection
        out += shortcut

        return out
    
class ResNet1D(nn.Module):
    """
    1-d ResNet classifier: a stem conv, n_block BasicBlocks, global average
    pooling over the length axis and a dense layer.

    Input:
        X: (n_samples, n_channel, n_length)
        Y: (n_samples)

    Output:
        out: (n_samples, n_classes), raw scores (the softmax layer is disabled)

    Parameters:
        in_channels: dim of input, the same as n_channel
        base_filters: number of filters in the first several Conv layers, it doubles every increasefilter_gap blocks
        kernel_size: width of kernel
        stride: stride of kernel moving, only applied inside downsampling blocks
        groups: set larger than 1 for ResNeXt-style grouped convolutions
        n_block: number of blocks
        n_classes: number of classes
        downsample_gap: a block downsamples when i_block % downsample_gap == 1 % downsample_gap,
            i.e. blocks 1, 3, 5, ... for a gap of 2 and every block for a gap of 1
        increasefilter_gap: the channel count doubles at every block index that is a
            non-zero multiple of this value
        use_bn: apply batch normalisation
        use_do: apply dropout inside the blocks
        verbose: print tensor shapes while running forward()

    """

    def __init__(self, in_channels, base_filters, kernel_size, stride, groups, n_block, n_classes, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, verbose=False):
        super(ResNet1D, self).__init__()

        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do

        self.downsample_gap = downsample_gap # 2 for base model
        self.increasefilter_gap = increasefilter_gap # 4 for base model

        # first block (stem): stride 1, so the signal length is preserved
        self.first_block_conv = MyConv1dPadSame(in_channels=in_channels, out_channels=base_filters, kernel_size=self.kernel_size, stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()
        out_channels = base_filters

        # residual blocks
        self.basicblock_list = nn.ModuleList()
        for i_block in range(self.n_block):
            is_first_block = i_block == 0
            downsample = self._is_downsample_block(i_block, self.downsample_gap)
            # in_channels and out_channels
            if is_first_block:
                in_channels = base_filters
                out_channels = in_channels
            else:
                # increase filters at every self.increasefilter_gap blocks
                in_channels = int(base_filters*2**((i_block-1)//self.increasefilter_gap))
                if i_block % self.increasefilter_gap == 0:
                    out_channels = in_channels * 2
                else:
                    out_channels = in_channels

            tmp_block = BasicBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride = self.stride,
                groups = self.groups,
                downsample=downsample,
                use_bn = self.use_bn,
                use_do = self.use_do,
                is_first_block=is_first_block)
            self.basicblock_list.append(tmp_block)

        # final prediction
        self.final_bn = nn.BatchNorm1d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)
        # self.do = nn.Dropout(p=0.5)
        self.dense = nn.Linear(out_channels, n_classes)
        # self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _is_downsample_block(i_block, downsample_gap):
        """Return True when block i_block should downsample.

        Comparing against 1 % downsample_gap instead of a literal 1 keeps the
        schedule unchanged for every gap >= 2, and makes a gap of 1 downsample
        in every block (i_block % 1 is always 0, so it could never equal 1).
        """
        return i_block % downsample_gap == 1 % downsample_gap

    def forward(self, x):

        out = x

        # first conv
        if self.verbose:
            print('input shape', out.shape)
        out = self.first_block_conv(out)
        if self.verbose:
            print('after first conv', out.shape)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)

        # residual blocks, every block has two conv
        for i_block, net in enumerate(self.basicblock_list):
            if self.verbose:
                print('i_block: {0}, in_channels: {1}, out_channels: {2}, downsample: {3}'.format(i_block, net.in_channels, net.out_channels, net.downsample))
            out = net(out)
            if self.verbose:
                print(out.shape)

        # final prediction
        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        # global average pooling over the length axis
        out = out.mean(-1)
        if self.verbose:
            print('final pooling', out.shape)
        # out = self.do(out)
        out = self.dense(out)
        if self.verbose:
            print('dense', out.shape)
        # out = self.softmax(out)
        # softmax is disabled above, so this trace line repeats the dense output shape
        if self.verbose:
            print('softmax', out.shape)

        return out

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    # CPU-side generators
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # GPU generators are only touched when CUDA is usable
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)
    # the cuDNN flags are set whether or not CUDA is available
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(123)

# Define file paths for the 12-lead ECG data and labels
# The directory is stated once; every lead lives in "<data_dir>/<lead>.pt".
data_dir = "../../Data_processing/data_aspire_PAP_1"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III",
        "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3",
        "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels
# map_location="cpu" lets files saved from a GPU session load on CPU-only
# machines; batches are moved to `device` inside the training/evaluation loops.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu")
    for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples
# (the first lead is the reference; labels and every lead are compared to it)
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Dataset yielding (stacked leads, label) for one sample index.

    Only the leads named in `lead_names` are kept, in that order; the order
    defines the channel order of the returned tensor.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        # every lead is expected to hold the same number of samples,
        # so the first one is used as the reference
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # one slice per lead, stacked on a new leading axis:
        # result has shape (num_leads, *shape of a single-lead sample)
        per_lead = [recordings[idx] for recordings in self.ecg_leads.values()]
        label = self.labels[idx]
        return torch.stack(per_lead, dim=0), label

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = True  # Change to False for 12-lead ECG
if use_6_leads:
    # the six leads I, II, III, aVR, aVL, aVF
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
else:
    # every lead that has a file path, in declaration order
    lead_names = list(lead_file_paths)

# **Initialize the dataset and dataloader**
dataset = ECGMultiLeadDatasetWithLabels(
    ecg_leads=ecg_lead_tensors,
    labels=labels,
    lead_names=lead_names,
)
# full-dataset loader; the cross-validation loop builds its own per-fold loaders
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

# **Model Class**
class ECGResNet1D(nn.Module):
    """Thin wrapper around ResNet1D with the hyper-parameters fixed for this task."""

    def __init__(self, input_channels, num_classes=2):
        super().__init__()

        # everything except the channel count and the number of classes is fixed here
        resnet_config = dict(
            base_filters=64,
            kernel_size=7,
            stride=3,
            groups=1,
            n_block=16,
            downsample_gap=2,
            increasefilter_gap=4,
            use_bn=True,
            use_do=True,
            verbose=False,
        )
        self.resnet = ResNet1D(
            in_channels=input_channels,
            n_classes=num_classes,
            **resnet_config,
        )

    def forward(self, x):
        # delegates straight to the wrapped network
        scores = self.resnet(x)
        return scores

# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train `model` in place for `epochs` passes over `train_loader`.

    Prints the sample-weighted average loss after every epoch and returns
    nothing. Batches are moved to the module-level `device` before use.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # weight each batch by its size (assumes criterion returns a per-batch mean)
            running_loss += batch_loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')

def evaluate_model(model, data_loader):
    """Evaluate a two-class model on `data_loader`.

    Returns (accuracy, ROC AUC, F1, MCC). The AUC is computed from the
    softmax probability of class index 1; the other metrics use the argmax
    prediction.
    """
    model.eval()
    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for lead_data, labels in data_loader:
            lead_data = lead_data.to(device)
            labels = labels.to(device)

            probs = torch.softmax(model(lead_data), dim=1)
            preds = probs.argmax(dim=1)

            # move each per-batch result to host memory and append element-wise
            collected = (
                (all_labels, labels),
                (all_probs, probs[:, 1]),
                (all_preds, preds),
            )
            for sink, values in collected:
                sink.extend(values.cpu().numpy())

    return (
        accuracy_score(all_labels, all_preds),
        roc_auc_score(all_labels, all_probs),
        f1_score(all_labels, all_preds),
        matthews_corrcoef(all_labels, all_preds),
    )

# **Cross-Validation Training**
# 5-fold stratified split with a fixed shuffle seed; one (accuracy, AUC, F1, MCC)
# tuple is collected per fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []

# Only the labels drive the split; np.zeros(...) is a placeholder for X.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views onto the shared dataset; no data is copied here.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize ResNet1D model
    # A fresh model, optimizer and loss are built for every fold.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # The held-out fold is scored once, after the final epoch.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# Transpose the per-fold tuples into one sequence per metric.
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')


Task-2 (PAWP)

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    # CPU-side generators
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # GPU generators are only touched when CUDA is usable
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)
    # the cuDNN flags are set whether or not CUDA is available
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(123)

# Define file paths for the 12-lead ECG data and labels
# The directory is stated once; every lead lives in "<data_dir>/<lead>.pt".
data_dir = "../../Data_processing/data_aspire_PAWP_1"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III",
        "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3",
        "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels
# map_location="cpu" lets files saved from a GPU session load on CPU-only
# machines; batches are moved to `device` inside the training/evaluation loops.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu")
    for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples
# (the first lead is the reference; labels and every lead are compared to it)
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Dataset yielding (stacked leads, label) for one sample index.

    Only the leads named in `lead_names` are kept, in that order; the order
    defines the channel order of the returned tensor.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        # every lead is expected to hold the same number of samples,
        # so the first one is used as the reference
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # one slice per lead, stacked on a new leading axis:
        # result has shape (num_leads, *shape of a single-lead sample)
        per_lead = [recordings[idx] for recordings in self.ecg_leads.values()]
        label = self.labels[idx]
        return torch.stack(per_lead, dim=0), label

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = True  # Change to False for 12-lead ECG
if use_6_leads:
    # the six leads I, II, III, aVR, aVL, aVF
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
else:
    # every lead that has a file path, in declaration order
    lead_names = list(lead_file_paths)

# **Initialize the dataset and dataloader**
dataset = ECGMultiLeadDatasetWithLabels(
    ecg_leads=ecg_lead_tensors,
    labels=labels,
    lead_names=lead_names,
)
# full-dataset loader; the cross-validation loop builds its own per-fold loaders
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

# **Model Class**
class ECGResNet1D(nn.Module):
    """Thin wrapper around ResNet1D with the hyper-parameters fixed for this task."""

    def __init__(self, input_channels, num_classes=2):
        super().__init__()

        # everything except the channel count and the number of classes is fixed here
        resnet_config = dict(
            base_filters=64,
            kernel_size=7,
            stride=3,
            groups=1,
            n_block=16,
            downsample_gap=2,
            increasefilter_gap=4,
            use_bn=True,
            use_do=True,
            verbose=False,
        )
        self.resnet = ResNet1D(
            in_channels=input_channels,
            n_classes=num_classes,
            **resnet_config,
        )

    def forward(self, x):
        # delegates straight to the wrapped network
        scores = self.resnet(x)
        return scores

# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train `model` in place for `epochs` passes over `train_loader`.

    Prints the sample-weighted average loss after every epoch and returns
    nothing. Batches are moved to the module-level `device` before use.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # weight each batch by its size (assumes criterion returns a per-batch mean)
            running_loss += batch_loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')

def evaluate_model(model, data_loader):
    """Evaluate a two-class model on `data_loader`.

    Returns (accuracy, ROC AUC, F1, MCC). The AUC is computed from the
    softmax probability of class index 1; the other metrics use the argmax
    prediction.
    """
    model.eval()
    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for lead_data, labels in data_loader:
            lead_data = lead_data.to(device)
            labels = labels.to(device)

            probs = torch.softmax(model(lead_data), dim=1)
            preds = probs.argmax(dim=1)

            # move each per-batch result to host memory and append element-wise
            collected = (
                (all_labels, labels),
                (all_probs, probs[:, 1]),
                (all_preds, preds),
            )
            for sink, values in collected:
                sink.extend(values.cpu().numpy())

    return (
        accuracy_score(all_labels, all_preds),
        roc_auc_score(all_labels, all_probs),
        f1_score(all_labels, all_preds),
        matthews_corrcoef(all_labels, all_preds),
    )

# **Cross-Validation Training**
# 5-fold stratified split with a fixed shuffle seed; one (accuracy, AUC, F1, MCC)
# tuple is collected per fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []

# Only the labels drive the split; np.zeros(...) is a placeholder for X.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views onto the shared dataset; no data is copied here.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize ResNet1D model
    # A fresh model, optimizer and loss are built for every fold.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # The held-out fold is scored once, after the final epoch.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# Transpose the per-fold tuples into one sequence per metric.
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')



Task-3

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# **Set device configuration**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    # CPU-side generators
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # GPU generators are only touched when CUDA is usable
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)
    # the cuDNN flags are set whether or not CUDA is available
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(123)

# Define file paths for the 12-lead ECG data and labels
# The directory is stated once; every lead lives in "<data_dir>/<lead>.pt".
data_dir = "D:/ukbiobank/ECG_PAWP_UKB_Final"
lead_file_paths = {
    lead: f"{data_dir}/{lead}.pt"
    for lead in (
        "LEAD_I", "LEAD_II", "LEAD_III",
        "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
        "LEAD_V1", "LEAD_V2", "LEAD_V3",
        "LEAD_V4", "LEAD_V5", "LEAD_V6",
    )
}
labels_file_path = f"{data_dir}/labels.pt"


# Load all lead tensors and labels
# map_location="cpu" lets files saved from a GPU session load on CPU-only
# machines; batches are moved to `device` inside the training/evaluation loops.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu")
    for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# Ensure all leads have the same number of samples
# (the first lead is the reference; labels and every lead are compared to it)
sample_count = len(next(iter(ecg_lead_tensors.values())))
assert len(labels) == sample_count, "Mismatch between number of labels and samples."
for tensor in ecg_lead_tensors.values():
    assert len(tensor) == sample_count, "All leads must have the same number of samples."

# **Dataset Class**
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Dataset yielding (stacked leads, label) for one sample index.

    Only the leads named in `lead_names` are kept, in that order; the order
    defines the channel order of the returned tensor.
    """

    def __init__(self, ecg_leads, labels, lead_names):
        self.ecg_leads = dict((name, ecg_leads[name]) for name in lead_names)
        self.labels = labels

    def __len__(self):
        # every lead is expected to hold the same number of samples,
        # so the first one is used as the reference
        first_lead = next(iter(self.ecg_leads.values()))
        return len(first_lead)

    def __getitem__(self, idx):
        # one slice per lead, stacked on a new leading axis:
        # result has shape (num_leads, *shape of a single-lead sample)
        per_lead = [recordings[idx] for recordings in self.ecg_leads.values()]
        label = self.labels[idx]
        return torch.stack(per_lead, dim=0), label

# Define the set of leads (choose between 6-lead and 12-lead configurations)
use_6_leads = True  # Change to False for 12-lead ECG
if use_6_leads:
    # the six leads I, II, III, aVR, aVL, aVF
    lead_names = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF"]
else:
    # every lead that has a file path, in declaration order
    lead_names = list(lead_file_paths)

# **Initialize the dataset and dataloader**
dataset = ECGMultiLeadDatasetWithLabels(
    ecg_leads=ecg_lead_tensors,
    labels=labels,
    lead_names=lead_names,
)


# **Model Class**
class ECGResNet1D(nn.Module):
    """Thin wrapper around ResNet1D with the hyper-parameters fixed for this task."""

    def __init__(self, input_channels, num_classes=2):
        super().__init__()

        # everything except the channel count and the number of classes is fixed here;
        # batch norm and dropout are both switched off in this configuration
        resnet_config = dict(
            base_filters=64,
            kernel_size=3,
            stride=3,
            groups=1,
            n_block=4,
            downsample_gap=1,
            increasefilter_gap=2,
            use_bn=False,
            use_do=False,
            verbose=False,
        )
        self.resnet = ResNet1D(
            in_channels=input_channels,
            n_classes=num_classes,
            **resnet_config,
        )

    def forward(self, x):
        # delegates straight to the wrapped network
        scores = self.resnet(x)
        return scores

# **Training and Evaluation Functions**
def train_classifier(model, train_loader, criterion, optimizer, epochs=10):
    """Train `model` in place for `epochs` passes over `train_loader`.

    Prints the sample-weighted average loss after every epoch and returns
    nothing. Batches are moved to the module-level `device` before use.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # weight each batch by its size (assumes criterion returns a per-batch mean)
            running_loss += batch_loss.item() * targets.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')

def evaluate_model(model, data_loader):
    """Evaluate a two-class model on `data_loader`.

    Returns (accuracy, ROC AUC, F1, MCC). The AUC is computed from the
    softmax probability of class index 1; the other metrics use the argmax
    prediction.
    """
    model.eval()
    all_labels, all_probs, all_preds = [], [], []

    with torch.no_grad():
        for lead_data, labels in data_loader:
            lead_data = lead_data.to(device)
            labels = labels.to(device)

            probs = torch.softmax(model(lead_data), dim=1)
            preds = probs.argmax(dim=1)

            # move each per-batch result to host memory and append element-wise
            collected = (
                (all_labels, labels),
                (all_probs, probs[:, 1]),
                (all_preds, preds),
            )
            for sink, values in collected:
                sink.extend(values.cpu().numpy())

    return (
        accuracy_score(all_labels, all_preds),
        roc_auc_score(all_labels, all_probs),
        f1_score(all_labels, all_preds),
        matthews_corrcoef(all_labels, all_preds),
    )

# **Cross-Validation Training**
# 5-fold stratified split with a fixed shuffle seed; one (accuracy, AUC, F1, MCC)
# tuple is collected per fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_results = []

# Only the labels drive the split; np.zeros(...) is a placeholder for X.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Index-based views onto the shared dataset; no data is copied here.
    train_subset = Subset(dataset, train_ids)
    test_subset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

    # Initialize ResNet1D model
    # A fresh model, optimizer and loss are built for every fold.
    model = ECGResNet1D(input_channels=len(lead_names), num_classes=2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate
    # The held-out fold is scored once, after the final epoch.
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    accuracy, auc_score, f1, mcc = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score, f1, mcc))

    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}\n')

# **Calculate Metrics Across Folds**
# Transpose the per-fold tuples into one sequence per metric.
accuracies, aucs, f1s, mccs = zip(*fold_results)
print(f'Mean Accuracy: {np.mean(accuracies):.4f}, STD: {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f}, STD: {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f}, STD: {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f}, STD: {np.std(mccs):.4f}')

