# Disease prediction from ECG

## Data processing

In [3]:
import os
import numpy as np
import wfdb

In [None]:
# Where the raw WFDB records live, where the annotation table lives, and the
# directory that will receive the saved batch archives.
data_path, label_path, output_path = (
    "./ecg_resources/data",
    "./ecg_resources/annotations.csv",
    "./data",
)

# Every recording is cut to this many samples (the author notes this is the
# length of the shortest recording); `batch` collects signals for the batch
# currently being built.
min_length, batch = 2200, []

# Annotation table: read everything as text (header row skipped), drop the
# first four columns and the last column, then convert what remains to float32.
labels = np.loadtxt(label_path, delimiter=",", skiprows=1, dtype=str)
trimed_labels = np.delete(labels, [0, 1, 2, 3, -1], axis=1)
casted_labels = trimed_labels.astype(np.float32)

# Archive name prefix, and the first record number to process.
filename, file_group = "batch", 1

os.makedirs(output_path, exist_ok=True)

In [None]:
# Read records TNMG1_N1 .. TNMG39999_N1 and cut each to `min_length` samples.
# Every 100 readable records are saved as one .npz archive holding `signals`
# (float32) and the matching label rows (`labels`).
#
# NOTE(review): label rows are chosen by record number (row k-1 for record
# TNMG{k}). The previous code consumed `casted_labels` sequentially, which
# implies the same correspondence whenever every record loads. Confirm against
# the annotations file.
if batch:
    # A partial batch left by an interrupted run has no tracked labels.
    # Continuing would pair its signals with the wrong rows.
    raise RuntimeError("`batch` is not empty; re-run the setup cell before this one")
batch_labels = []

while file_group <= 39_999:
    record_name = f"TNMG{file_group}_N1"
    print(f"Processing record: {record_name}")

    if file_group > len(casted_labels):
        print(f"No label row for record: {record_name}, stopping.\n")
        break

    try:
        record_signal = wfdb.rdrecord(os.path.join(data_path, record_name)).p_signal
    except Exception as e:
        # Best-effort: report and skip unreadable or missing records.
        print(f"Can't load file: {record_name}, error: {e}\n")
        file_group += 1
        continue

    if record_signal.shape[0] < min_length:
        # A shorter signal cannot be stacked with the rest of the batch below.
        print(f"Skipping record: {record_name}, only {record_signal.shape[0]} samples\n")
        file_group += 1
        continue

    # Keep the first `min_length` rows (samples) of the signal.
    batch.append(record_signal[:min_length])
    # Take this record's own label row, so a skipped record cannot shift the
    # labels of every record that follows it.
    batch_labels.append(casted_labels[file_group - 1])

    if len(batch) == 100:
        numpy_array = np.array(batch, dtype=np.float32)
        trimmed_labels = np.array(batch_labels, dtype=np.float32)
        batch_file = f"{filename}-{file_group // 100}.npz"

        np.savez(os.path.join(output_path, batch_file),
                 signals=numpy_array, labels=trimmed_labels)

        print(f"Batch saved as {batch_file}\n")

        batch = []
        batch_labels = []

    file_group += 1

if batch:
    # Only full batches are written, so every archive has the same shape.
    print(f"{len(batch)} records in the last, incomplete batch were not saved\n")

## Creating a PyTorch dataset

In [1]:
import torch

# Display the installed PyTorch version as this cell's output.
torch.__version__

'2.5.1+cu124'

In [None]:
# Gather every saved batch archive into two arrays and store them in
# dataset.pt:
#   - data:   one entry per archive, holding that archive's `signals`
#   - labels: one entry per archive, holding that archive's `labels`
input_path = "./data"

data = []
labels = []

# Sort the listing so the file order is deterministic (os.listdir order is
# arbitrary). Skip anything that is not a saved .npz archive.
batch_files = sorted(f for f in os.listdir(input_path) if f.endswith(".npz"))

for i, file in enumerate(batch_files):
    print(f"{i}. processing file: {file}")
    file_path = os.path.join(input_path, file)

    # np.load on an .npz returns an NpzFile that keeps the file open. The
    # context manager closes it once both arrays have been read out.
    with np.load(file_path) as record:
        data.append(record["signals"])
        labels.append(record["labels"])

# Stacking requires every archive to have the same shape.
data = np.array(data)
labels = np.array(labels)

torch.save({'data': data, 'labels': labels}, 'dataset.pt')

## Dataset class

In [28]:
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    """Map-style dataset over batched ECG recordings, one item per recording.

    ``data`` and ``labels`` are indexed as ``[file, record, ...]``. Flat index
    ``idx`` maps to ``(idx // records_per_file, idx % records_per_file)``.

    Args:
        data: array-like of signals whose first two axes are (file, record).
        labels: array-like of labels sharing those first two axes.
        label_dtype: dtype of the returned label tensors. Defaults to
            ``torch.long`` (the previous behaviour); pass ``torch.float32``
            for losses that need float targets.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in their first two axes.
    """

    def __init__(self, data, labels, label_dtype=torch.long):
        # as_tensor reuses the NumPy buffer when the dtype already matches,
        # instead of holding a second full copy of the dataset in memory.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=label_dtype)

        if self.labels.shape[:2] != self.data.shape[:2]:
            raise ValueError(
                f"labels leading shape {tuple(self.labels.shape[:2])} does not "
                f"match data leading shape {tuple(self.data.shape[:2])}"
            )

        self.num_samples = self.data.shape[0] * self.data.shape[1]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        file_idx, matrix_idx = divmod(idx, self.data.shape[1])
        return self.data[file_idx, matrix_idx], self.labels[file_idx, matrix_idx]

## Model class

In [12]:
import torch.nn as nn

class MulticlassClassifier(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the input's last axis (``nn.Linear`` acts on it).
        hidden_dim: width of the hidden layer.
        output_dim: number of output scores per example.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names fc1 / relu / fc2 keep the state_dict keys stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Returns raw, unnormalised scores; no softmax is applied here.
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

## Model training

In [29]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [30]:
# Network sizes.
# NOTE(review): nn.Linear acts on the last axis of its input. Each dataset item
# is a 2-D slice record_signal[:min_length, :] from the processing cell, so
# input_dim presumably has to equal that slice's last-axis size. Confirm
# against the data.
input_dim, hidden_dim, output_dim = 300, 128, 6

model = MulticlassClassifier(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
).to(device)

# NOTE(review): CrossEntropyLoss expects either class indices of shape (batch,)
# or float class probabilities of shape (batch, C). ECGDataset yields a whole
# label row per item as torch.long. If the 6 outputs are independent labels,
# BCEWithLogitsLoss with float targets may be what is intended. Confirm.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [31]:
from torch.utils.data import DataLoader

# NOTE(review): weights_only=False unpickles arbitrary objects. It is needed
# here because the file stores NumPy arrays, but only load a dataset.pt that
# you produced yourself.
dataset_load = torch.load("dataset.pt", weights_only=False)

dataset = ECGDataset(
    data=dataset_load["data"],
    labels=dataset_load["labels"],
)

# Mini-batches of 32 items, reshuffled every epoch.
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
# Train `model` for a fixed number of epochs over `data_loader`, printing the
# mean training loss of each epoch.
epochs = 100
for epoch in range(epochs):
    model.train()
    # Accumulate on-device so each step does not force a host sync via .item().
    epoch_loss = torch.zeros((), device=device)
    num_batches = 0

    # Named `targets` so the loop does not overwrite the module-level `labels`.
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # The model was moved to `device` when created, so its outputs are there too.
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("data_loader yielded no batches; nothing to train on")

    # Report the mean over the epoch rather than only the final mini-batch.
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss.item() / num_batches:.4f}")