In [1]:
import pandas as pd
import numpy as np

## 1. Dataset Setup

In [2]:
import os
import wfdb
import ast

In [6]:
# Folder holding ptbxl_database.csv, scp_statements.csv and the record files they reference.
DATA_DIR = "../data"
# 100 loads the `filename_lr` records; any other value loads the `filename_hr` records.
SAMPLING_RATE = 100 # 100Hz or 500Hz

In [7]:
# Read the annotation table; rows are indexed by `ecg_id`.
# Column reference: https://www.nature.com/articles/s41597-020-0495-6/tables/3
database_file_path = os.path.join(DATA_DIR, "ptbxl_database.csv")
Y = pd.read_csv(database_file_path, index_col="ecg_id")

# `scp_codes` is read as text; evaluate each cell as a Python literal to get the label dictionary.
Y["scp_codes"] = Y["scp_codes"].apply(ast.literal_eval)

In [8]:
# Load raw signal data
def load_raw_data(df, sampling_rate, path):
    """Read the waveform of every record listed in `df` and stack them into one array.

    Parameters
    ----------
    df : table exposing `filename_lr` and `filename_hr` columns whose entries
        are record names relative to `path`.
    sampling_rate : 100 selects the `filename_lr` column; any other value
        selects `filename_hr`.
    path : base directory that each record name is joined to.

    Returns
    -------
    numpy.ndarray built from the signal (first) element of each `wfdb.rdsamp`
    result, in the row order of `df`. The accompanying metadata is discarded.
    """
    record_names = df.filename_lr if sampling_rate == 100 else df.filename_hr

    signals = []
    for record_name in record_names:
        signal, _meta = wfdb.rdsamp(os.path.join(path, record_name))
        signals.append(signal)
    return np.array(signals)
# Load the signal of every record listed in Y; X is later reported as (Samples, Timepoints, Leads).
X = load_raw_data(Y, SAMPLING_RATE, DATA_DIR)

In [9]:
# Read the SCP statement table and keep only the rows flagged as diagnostic;
# these rows are what the diagnostic aggregation looks codes up in.
# Column reference: https://www.nature.com/articles/s41597-020-0495-6/tables/13
statements_file_path = os.path.join(DATA_DIR, "scp_statements.csv")
scp_statements = pd.read_csv(statements_file_path, index_col=0)
agg_df = scp_statements.loc[scp_statements["diagnostic"] == 1]

def aggregate_diagnostic(y_dic):
    """Return the diagnostic superclasses present in one ECG's SCP-code dict.

    Every key of `y_dic` that appears in the index of the module-level
    `agg_df` table contributes that row's `diagnostic_class`; keys missing
    from the index are ignored. The dict values are not used.

    Parameters
    ----------
    y_dic : dict
        Mapping whose keys are SCP statement codes.

    Returns
    -------
    list
        The distinct superclasses in sorted order (empty when no key matches).
        Sorting makes the result reproducible: a list built directly from a
        `set` of strings can come out in a different order on each
        interpreter run because string hashing is randomised.
    """
    known_codes = agg_df.index
    superclasses = {
        agg_df.loc[code].diagnostic_class for code in y_dic if code in known_codes
    }
    # key=str keeps the ordering well defined even if a label is not a string.
    return sorted(superclasses, key=str)

# Attach to each record the list of diagnostic superclasses derived from its SCP codes
Y["diagnostic_superclass"] = Y["scp_codes"].apply(aggregate_diagnostic)

In [10]:
## Info about sets
print(f"Feature data shape (X): {X.shape}") # (Samples, Timepoints, Leads)
print(f"Label data shape (Y): {Y.shape}")

# Pool the per-record superclass lists into one flat list so every label occurrence is counted
all_diagnostics = []
for record_labels in Y["diagnostic_superclass"]:
    all_diagnostics.extend(record_labels)

# Occurrences of each superclass across all records
class_counts = pd.Series(all_diagnostics).value_counts()
print("\n--- Class Distribution (Diagnostic Superclass) ---")
print(class_counts)

# How many records carry 0, 1, 2, ... superclass labels
labels_per_record = Y["diagnostic_superclass"].apply(len)
multi_label_counts = labels_per_record.value_counts().sort_index()
print("\n--- Samples by Number of Labels ---")
print(multi_label_counts)

print("\n--- Missing Values in Metadata (Y) ---")
# Null count per metadata column; only columns with at least one null are printed
missing_meta = Y.isnull().sum()
print(missing_meta.loc[missing_meta > 0])

print("\n--- Missing/Corrupt Values in Signal Data (X) ---")
# Flag any NaN or infinite sample anywhere in the signal array
has_nans = np.isnan(X).any()
has_infs = np.isinf(X).any()
print(f"X contains NaNs: {has_nans}")
print(f"X contains Infinite values: {has_infs}")

Feature data shape (X): (21799, 1000, 12)
Label data shape (Y): (21799, 28)

--- Class Distribution (Diagnostic Superclass) ---
NORM    9514
MI      5469
STTC    5235
CD      4898
HYP     2649
Name: count, dtype: int64

--- Samples by Number of Labels ---
diagnostic_superclass
0      411
1    16244
2     4068
3      919
4      157
Name: count, dtype: int64

--- Missing Values in Metadata (Y) ---
height                 14825
weight                 12378
nurse                   1473
site                      17
heart_axis              8468
infarction_stadium1    16187
infarction_stadium2    21696
validated_by            9378
baseline_drift         20201
static_noise           18539
burst_noise            21186
electrodes_problems    21769
extra_beats            19850
pacemaker              21508
dtype: int64

--- Missing/Corrupt Values in Signal Data (X) ---
X contains NaNs: False
X contains Infinite values: False


In [11]:
# Hold out one value of `strat_fold` for testing; every other record is used for training
test_fold = 10
is_test = (Y.strat_fold == test_fold).to_numpy()
# Train
X_train = X[~is_test]
y_train = Y.loc[~is_test, "diagnostic_superclass"]
# Test
X_test = X[is_test]
y_test = Y.loc[is_test, "diagnostic_superclass"]

## 2. Model Setup

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report, roc_auc_score

### 2.1 Data Preprocess

In [14]:
# Multi-hot encode the label lists; the encoder is fitted on the training labels only
mlb = MultiLabelBinarizer()
y_train_bin = mlb.fit_transform(y_train)
y_test_bin = mlb.transform(y_test)
num_classes = len(mlb.classes_)
print(f"\nClasses found: {mlb.classes_}")

BATCH_SIZE = 32

# Training split.
# The signal array is (Samples, Time, Leads) while Conv1d expects (Samples, Leads, Time),
# hence the permute. The loader reshuffles the samples every epoch.
X_train_t = torch.tensor(X_train, dtype=torch.float32).permute(0, 2, 1)
y_train_t = torch.tensor(y_train_bin, dtype=torch.float32)
train_dataset = TensorDataset(X_train_t, y_train_t)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test split: same layout conversion, iterated in a fixed order
X_test_t = torch.tensor(X_test, dtype=torch.float32).permute(0, 2, 1)
y_test_t = torch.tensor(y_test_bin, dtype=torch.float32)
test_dataset = TensorDataset(X_test_t, y_test_t)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


Classes found: ['CD' 'HYP' 'MI' 'NORM' 'STTC']


### 2.2 Model Definition

In [15]:
class ECGNet(nn.Module):
    """1-D convolutional classifier that outputs one raw logit per class.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages (32, 64 and 128
    channels, kernel size 5 with padding 2) are followed by global average
    pooling over the time axis and a single linear layer.

    Args:
        num_leads: number of input channels.
        num_classes: number of output logits.
    """

    def __init__(self, num_leads, num_classes):
        super().__init__()

        # Stage 1 (the ReLU and max-pool layers are stateless and reused by every stage)
        self.conv1 = nn.Conv1d(num_leads, 32, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        # Stage 2
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(64)

        # Stage 3
        self.conv3 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm1d(128)

        # Averages each channel over the whole time axis, so the input length may vary
        self.gap = nn.AdaptiveAvgPool1d(1)

        # Linear head mapping the 128 pooled features to the class logits
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a [Batch, Leads, Time] tensor to [Batch, num_classes] logits."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = self.pool(self.relu(bn(conv(x))))

        pooled = self.gap(x)  # [Batch, 128, 1]
        features = pooled.view(pooled.size(0), -1)  # [Batch, 128]
        return self.fc(features)  # raw logits, no sigmoid applied

# Seed PyTorch's global random number generator before the model is built so that
# the weight initialisation, and any later sampling that draws from the global
# generator, is repeatable after a kernel restart.
# NOTE(review): GPU kernels may still not be bit-for-bit deterministic.
SEED = 42
torch.manual_seed(SEED)

# Initialize Model
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on device: {device}")

# One input channel per ECG lead, one output logit per label class
model = ECGNet(num_leads=12, num_classes=num_classes).to(device)

Training on device: cpu


## 3. Training

In [16]:
# Binary cross-entropy on raw logits, one independent term per class (multi-label setup)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
EPOCHS = 10

print("\n--- Starting Training ---")
for epoch in range(EPOCHS):
    model.train()
    # Sum of the per-batch mean losses seen during this epoch
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Reported value is the running sum divided by the number of batches
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {running_loss/len(train_loader):.4f}")


--- Starting Training ---
Epoch 1/10 - Loss: 0.3656
Epoch 2/10 - Loss: 0.3132
Epoch 3/10 - Loss: 0.2972
Epoch 4/10 - Loss: 0.2885
Epoch 5/10 - Loss: 0.2823
Epoch 6/10 - Loss: 0.2755
Epoch 7/10 - Loss: 0.2726
Epoch 8/10 - Loss: 0.2687
Epoch 9/10 - Loss: 0.2658
Epoch 10/10 - Loss: 0.2617


## 4. Evaluation

In [17]:
print("\n--- Evaluating on Test Set ---")
model.eval()
all_preds = []
all_targets = []

# Gradients are not needed for inference
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)

        # Sigmoid turns each logit into a per-class probability
        probs = torch.sigmoid(outputs)

        all_preds.append(probs.cpu().numpy())
        all_targets.append(labels.cpu().numpy())

# Join the per-batch arrays along the sample axis
y_pred_probs = np.concatenate(all_preds, axis=0)
y_true = np.concatenate(all_targets, axis=0)

# A class counts as predicted when its probability exceeds 0.5
y_pred_bin = (y_pred_probs > 0.5).astype(int)


--- Evaluating on Test Set ---


In [18]:
# Per-class precision, recall and F1 computed from the thresholded predictions
print("\nClassification Report:")
report = classification_report(y_true, y_pred_bin, target_names=mlb.classes_, zero_division=0)
print(report)

# Macro-averaged ROC AUC computed from the predicted probabilities
try:
    roc_auc = roc_auc_score(y_true, y_pred_probs, average="macro")
except ValueError:
    print("Could not calculate ROC AUC (possibly only one class present in test subset).")
else:
    print(f"Macro ROC AUC Score: {roc_auc:.4f}")

# Persist the trained weights (state dict only), creating the output folder if needed
model_dir = "../models"
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, "ecg_model.pth")
torch.save(model.state_dict(), model_path)


Classification Report:
              precision    recall  f1-score   support

          CD       0.81      0.63      0.71       496
         HYP       0.79      0.39      0.52       262
          MI       0.76      0.70      0.73       550
        NORM       0.81      0.91      0.85       963
        STTC       0.79      0.62      0.69       521

   micro avg       0.80      0.71      0.75      2792
   macro avg       0.79      0.65      0.70      2792
weighted avg       0.80      0.71      0.74      2792
 samples avg       0.74      0.72      0.72      2792

Macro ROC AUC Score: 0.9152
