# Desease prediction by ECG

## Data processing

In [3]:
import os
import numpy as np
import wfdb

In [None]:
data_path = "./ecg_resources/data"
label_path = "./ecg_resources/annotations.csv"
output_path = "./data"

min_length = 2200 # this is length of the shortest timestamp
batch = []

labels = np.loadtxt(label_path, delimiter=',', skiprows=1, dtype=str)
trimed_labels = np.delete(labels, [0, 1, 2, 3, -1], axis=1)
casted_labels = trimed_labels.astype(np.float32)

filename = "batch"
file_group = 1

os.makedirs(output_path, exist_ok=True)

In [None]:
while file_group <= 39_999:
    record_name = f"TNMG{file_group}_N1"
    print(f"Processing record: {record_name}")
    try:
        record_signal = wfdb.rdrecord(os.path.join(data_path, record_name)).p_signal

        if record_signal.shape[0] > min_length:
            trimmed_signal = record_signal[:min_length, :]
        else:
            trimmed_signal = record_signal

        batch.append(trimmed_signal)

    except Exception as e:
        print(f"Can't load file: {record_name}, error: {e}\n")
        file_group += 1
        continue

    if len(batch) == 100:
        numpy_array = np.array(batch, dtype=np.float32)
        trimmed_labels = casted_labels[:100]

        np.savez(os.path.join(output_path, f"{filename}-{int(file_group / 100)}.npz"),
                 signals=numpy_array, labels=trimmed_labels)

        print(f"Batch saved as {filename}-{int(file_group / 100)}.npz\n")

        batch = []
        casted_labels = casted_labels[100:]

    file_group += 1

## Creating a PyTorch dataset

In [4]:
import torch

torch.__version__

'2.5.1+cu124'

In [None]:
input_path = "./data"

data = []
labels = []

for i, file in enumerate(os.listdir(input_path)):
    print(f"{i}. processing file: {file}")
    file_path = os.path.join(input_path, file)

    record = np.load(file_path)
    signals = record["signals"]
    label = record["labels"]

    data.append(signals)
    labels.append(label)

data = np.array(data)
labels = np.array(labels)

torch.save({'data': data, 'labels': labels}, 'dataset.pt')

## Dataset class

In [44]:
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

        self.data = self.data.view(-1, 2200 * 8)
        self.labels = self.labels.view(-1, 6)

        self.num_samples = self.data.shape[0]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        return x, y

## Model class

In [45]:
import torch.nn as nn

class MultiLabelClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MultiLabelClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

## Model training

In [46]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [47]:
import torch.optim as optim

input_dim = 2200 * 8
hidden_dim = 128
output_dim = 6

model = MultiLabelClassifier(input_dim, hidden_dim, output_dim).to(device)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [48]:
from torch.utils.data import DataLoader

dataset_load = torch.load("dataset.pt", weights_only=False)
dataset = ECGDataset(dataset_load["data"], dataset_load["labels"])

data_loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [49]:
epochs = 100
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    train_correct = 0
    train_total = 0

    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.float()

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

        predicted = (outputs > 0.5).float()
        train_correct += (predicted == labels).sum().item()
        train_total += labels.numel()

    avg_train_loss = total_train_loss / len(data_loader)
    train_accuracy = train_correct / train_total


    model.eval()
    total_test_loss = 0
    test_correct = 0
    test_total = 0

    with torch.inference_mode():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.float()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_test_loss += loss.item()

            predicted = (outputs > 0.5).float()
            test_correct += (predicted == labels).sum().item()
            test_total += labels.numel()

    avg_test_loss = total_test_loss / len(data_loader)
    test_accuracy = test_correct / test_total

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f} | Test Loss: {avg_test_loss:.4f}, Test Acc: {test_accuracy:.4f}")


Epoch 1/100 | Train Loss: 0.3363, Train Acc: 0.9710 | Test Loss: 0.2585, Test Acc: 0.9767
Epoch 2/100 | Train Loss: 0.2820, Train Acc: 0.9765 | Test Loss: 1.2255, Test Acc: 0.9612
Epoch 3/100 | Train Loss: 0.2638, Train Acc: 0.9763 | Test Loss: 0.2426, Test Acc: 0.9778
Epoch 4/100 | Train Loss: 0.2510, Train Acc: 0.9768 | Test Loss: 0.2279, Test Acc: 0.9780
Epoch 5/100 | Train Loss: 0.2510, Train Acc: 0.9770 | Test Loss: 0.2321, Test Acc: 0.9788
Epoch 6/100 | Train Loss: 0.2636, Train Acc: 0.9774 | Test Loss: 0.2591, Test Acc: 0.9791
Epoch 7/100 | Train Loss: 0.2885, Train Acc: 0.9778 | Test Loss: 0.2713, Test Acc: 0.9797
Epoch 8/100 | Train Loss: 0.2775, Train Acc: 0.9785 | Test Loss: 0.2980, Test Acc: 0.9809
Epoch 9/100 | Train Loss: 0.2859, Train Acc: 0.9797 | Test Loss: 0.2692, Test Acc: 0.9821
Epoch 10/100 | Train Loss: 0.2854, Train Acc: 0.9807 | Test Loss: 0.2712, Test Acc: 0.9833
Epoch 11/100 | Train Loss: 0.2974, Train Acc: 0.9814 | Test Loss: 0.2908, Test Acc: 0.9838
Epoch 12

## Save model

In [50]:
torch.save(model.state_dict(), "ecg_model.pth")

## Test Model Prediction Accuracy

In [60]:
def print_pred(vals, ans):
    model.eval()
    with torch.inference_mode():
        vals = vals.clone().detach().float().to(device)

        vals = vals.view(1, -1)

        pred = model(vals)
        probabilities = pred.cpu().numpy()
        predicted_labels = (pred > 0.5).int().cpu().numpy()

        print(f"\nProbabilities: {probabilities}, \nPrediction: {predicted_labels}, \nAnswer: {ans} \n")

In [62]:
test_model = MultiLabelClassifier(input_dim, hidden_dim, output_dim).to(device)
test_model.load_state_dict(torch.load("ecg_model.pth", weights_only=False, map_location=device))

sample_0 = dataset.data[0]
sample_1 = dataset.data[1]
sample_2 = dataset.data[2]
sample_3 = dataset.data[3]
sample_4 = dataset.data[4]

label_0 = dataset.labels[0].cpu().numpy()
label_1 = dataset.labels[1].cpu().numpy()
label_2 = dataset.labels[2].cpu().numpy()
label_3 = dataset.labels[3].cpu().numpy()
label_4 = dataset.labels[4].cpu().numpy()


print_pred(sample_0, label_0)
print_pred(sample_1, label_1)
print_pred(sample_2, label_2)
print_pred(sample_3, label_3)
print_pred(sample_4, label_4)


Probabilities: [[6.8153909e-08 2.4420189e-07 5.6282011e-17 1.3305446e-13 8.2911463e-08
  1.9636718e-07]], 
Prediction: [[0 0 0 0 0 0]], 
Answer: [0. 0. 0. 0. 0. 0.] 


Probabilities: [[3.2594273e-25 5.2787104e-21 6.1008863e-14 1.0344354e-22 1.0000000e+00
  9.3258457e-12]], 
Prediction: [[0 0 0 0 1 0]], 
Answer: [0. 0. 0. 0. 1. 0.] 


Probabilities: [[0. 0. 0. 0. 0. 0.]], 
Prediction: [[0 0 0 0 0 0]], 
Answer: [1. 0. 0. 0. 0. 0.] 


Probabilities: [[0.0000000e+00 3.2699278e-24 1.0000000e+00 1.2412947e-32 0.0000000e+00
  1.1415087e-38]], 
Prediction: [[0 0 1 0 0 0]], 
Answer: [0. 0. 1. 0. 0. 0.] 


Probabilities: [[0. 0. 0. 0. 0. 0.]], 
Prediction: [[0 0 0 0 0 0]], 
Answer: [0. 0. 0. 0. 0. 0.] 

