In [1]:
# 1. Imports
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score
import wfdb
import ast
import time
from tqdm import tqdm
import wandb 
wandb.finish()

In [2]:
# 2. Load CSVs
ptbxl_path = "/media/nicholas/Storage/Datasets/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/ptbxl_database.csv"
scp_path = "/media/nicholas/Storage/Datasets/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/scp_statements.csv"
waveform_path = "/media/nicholas/Storage/Datasets/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"

df = pd.read_csv(ptbxl_path)
scp_df = pd.read_csv(scp_path, index_col=0)

In [3]:
# Parse SCP codes
df['scp_codes'] = df['scp_codes'].apply(ast.literal_eval)
df['scp_keys'] = df['scp_codes'].apply(lambda x: list(x.keys()))
scp_df = scp_df[scp_df.diagnostic == 1]

In [4]:
# 3. Select Top 10
target_labels = ['NORM', 'SR', 'AFIB', 'PVC', 'LVH', 'ABQRS', 'IMI', 'ASMI', 'LAFB', 'IRBBB']

In [5]:
# Filter dataset
df['scp_filtered'] = df['scp_keys'].apply(lambda codes: [code for code in codes if code in target_labels])
df = df[df['scp_filtered'].map(len) > 0]

In [6]:
# Binarize
mlb = MultiLabelBinarizer(classes=target_labels)
y = mlb.fit_transform(df['scp_filtered'])

In [7]:
# 4. Split
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=42)

In [8]:
# 5. Dataset
class PTBXL_Dataset(Dataset):
    def __init__(self, df, labels, base_dir, signal_len=5000):
        self.df = df
        self.labels = labels
        self.base_dir = base_dir
        self.signal_len = signal_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.base_dir, row['filename_lr'])
        record = wfdb.rdrecord(path)
        signal = record.p_signal.T

        if signal.shape[1] < self.signal_len:
            pad = self.signal_len - signal.shape[1]
            signal = np.pad(signal, ((0, 0), (0, pad)), 'constant')
        else:
            signal = signal[:, :self.signal_len]

        return torch.tensor(signal, dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)

train_dataset = PTBXL_Dataset(X_train, y_train, waveform_path)
test_dataset = PTBXL_Dataset(X_test, y_test, waveform_path)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4, pin_memory=True)

In [9]:
# 6. CNN Model
class ECG_CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(12, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [10]:
# 7. Initialize Weights & Biases
epochs = 15
wandb.init(settings=wandb.Settings(start_method="fork", _disable_stats=False))
wandb.init(
    project="ptbxl-cnn",
    name=f"cnn-multilabel-run-{int(time.time())}",# unique name using timestamp
    settings=wandb.Settings(start_method="fork", _disable_stats=False)
)


wandb.config.update({
    "epochs": epochs,
    "batch_size": 32,
    "learning_rate": 0.001,
    "architecture": "1D CNN Multi-label PTB-XL",
    "dataset": "PTB-XL Top 10 Classes"
})

[34m[1mwandb[0m: Currently logged in as: [33mngkanos[0m ([33mngkanos-youngstown-state-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# 8. Train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ECG_CNN(num_classes=len(target_labels)).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for signals, labels in train_loader_tqdm:
        signals, labels = signals.to(device), labels.to(device)
        signals = signals.permute(0, 1, 2)

        optimizer.zero_grad()
        outputs = model(signals)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = (torch.sigmoid(outputs) > 0.5).float()
        acc = (preds == labels).float().mean().item()
        correct += acc
        total += 1

        train_loader_tqdm.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    avg_acc = correct / total

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
    wandb.log({"epoch": epoch+1, "loss": avg_loss, "accuracy": avg_acc})


Epoch 1/15: 100%|█████████████████| 523/523 [00:11<00:00, 47.08it/s, loss=0.337]


Epoch 1, Loss: 0.3688, Accuracy: 0.8568


Epoch 2/15: 100%|█████████████████| 523/523 [00:10<00:00, 48.48it/s, loss=0.293]


Epoch 2, Loss: 0.3170, Accuracy: 0.8803


Epoch 3/15: 100%|██████████████████| 523/523 [00:11<00:00, 46.70it/s, loss=0.33]


Epoch 3, Loss: 0.2986, Accuracy: 0.8845


Epoch 4/15: 100%|█████████████████| 523/523 [00:10<00:00, 49.71it/s, loss=0.257]


Epoch 4, Loss: 0.2896, Accuracy: 0.8868


Epoch 5/15: 100%|█████████████████| 523/523 [00:11<00:00, 47.43it/s, loss=0.277]


Epoch 5, Loss: 0.2839, Accuracy: 0.8886


Epoch 6/15: 100%|█████████████████| 523/523 [00:10<00:00, 51.64it/s, loss=0.265]


Epoch 6, Loss: 0.2773, Accuracy: 0.8897


Epoch 7/15: 100%|█████████████████| 523/523 [00:10<00:00, 50.04it/s, loss=0.223]


Epoch 7, Loss: 0.2674, Accuracy: 0.8930


Epoch 8/15: 100%|█████████████████| 523/523 [00:11<00:00, 47.21it/s, loss=0.256]


Epoch 8, Loss: 0.2622, Accuracy: 0.8946


Epoch 9/15: 100%|█████████████████| 523/523 [00:10<00:00, 50.23it/s, loss=0.266]


Epoch 9, Loss: 0.2572, Accuracy: 0.8961


Epoch 10/15: 100%|████████████████| 523/523 [00:10<00:00, 51.59it/s, loss=0.303]


Epoch 10, Loss: 0.2537, Accuracy: 0.8978


Epoch 11/15: 100%|████████████████| 523/523 [00:10<00:00, 51.30it/s, loss=0.256]


Epoch 11, Loss: 0.2489, Accuracy: 0.8996


Epoch 12/15: 100%|████████████████| 523/523 [00:10<00:00, 51.29it/s, loss=0.251]


Epoch 12, Loss: 0.2428, Accuracy: 0.9020


Epoch 13/15: 100%|████████████████| 523/523 [00:09<00:00, 52.49it/s, loss=0.196]


Epoch 13, Loss: 0.2364, Accuracy: 0.9047


Epoch 14/15: 100%|█████████████████| 523/523 [00:10<00:00, 50.82it/s, loss=0.23]


Epoch 14, Loss: 0.2312, Accuracy: 0.9061


Epoch 15/15: 100%|████████████████| 523/523 [00:10<00:00, 49.69it/s, loss=0.183]

Epoch 15, Loss: 0.2271, Accuracy: 0.9072





In [12]:
wandb.finish()

0,1
accuracy,▁▄▅▅▅▆▆▆▆▇▇▇███
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
loss,█▅▅▄▄▃▃▃▂▂▂▂▁▁▁

0,1
accuracy,0.90719
epoch,15.0
loss,0.22715
