In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.metrics import recall_score, precision_score, f1_score, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split


In [2]:
# =========================
# 1. Load data & build tensors
# =========================
# Build the path with os.path.join so the notebook also runs on non-Windows
# systems (a literal ".\archive" is only a directory path on Windows).
data_dir = os.path.join(".", "archive")
train_path = os.path.join(data_dir, "mitbih_train.csv")
test_path = os.path.join(data_dir, "mitbih_test.csv")

# The CSV files have no header row; the last column is the class label and
# every preceding column is one sample of the signal.
train_df = pd.read_csv(train_path, header=None)
test_df = pd.read_csv(test_path, header=None)

X_train = train_df.iloc[:, :-1].values
y_train = train_df.iloc[:, -1].values.astype(int)

X_test = test_df.iloc[:, :-1].values
y_test = test_df.iloc[:, -1].values.astype(int)

# Fail early with a clear message instead of a shape error deep in the model.
assert X_train.shape[1] == X_test.shape[1], (
    f"train/test signal length mismatch: {X_train.shape[1]} vs {X_test.shape[1]}"
)


def zscore_per_sample(signals, eps=1e-8):
    """Standardise each row of a 2-D array to zero mean and unit variance.

    Statistics are computed per row (per sample), so nothing is shared between
    samples or between the train and test sets.

    Args:
        signals: 2-D NumPy array of shape (n_samples, n_timesteps).
        eps: Small constant added to the standard deviation so that constant
            rows do not cause a division by zero.

    Returns:
        A new array with the same shape as ``signals``.
    """
    mean = signals.mean(axis=1, keepdims=True)
    std = signals.std(axis=1, keepdims=True)
    return (signals - mean) / (std + eps)


# Z-score per sample
X_train = zscore_per_sample(X_train)
X_test = zscore_per_sample(X_test)

# Insert a channel axis for Conv1d: (samples, timesteps) -> (samples, 1, timesteps)
X_train = X_train[:, np.newaxis, :]
X_test = X_test[:, np.newaxis, :]

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)


In [3]:
# Couple Conv Block
class CoupleConvBlock(nn.Module):
    """Two Conv1d -> BatchNorm1d -> ReLU stages, then max-pooling and dropout.

    Both convolutions use kernel_size=5 with padding=2, so they keep the
    sequence length; the MaxPool1d(kernel_size=2, stride=2) afterwards halves it.

    Args:
        in_channels: Number of channels of the input signal.
        out_channels: Number of channels produced by both convolutions.
        dropout: Dropout probability applied after pooling.
    """

    def __init__(self, in_channels, out_channels, dropout=0.3):
        super().__init__()

        # Assemble the layers in the same order as before so the positions
        # inside the Sequential (and therefore the state_dict keys) are unchanged.
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv1d(stage_in, out_channels, kernel_size=5, padding=2),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
        layers.append(nn.Dropout(dropout))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.block(x)
        return out

class ECG_CNN(nn.Module):
    """1-D CNN classifier: two CoupleConvBlocks followed by a two-layer head.

    Args:
        num_classes: Number of output logits.
        input_length: Number of time steps in each input signal. Defaults to
            187, which reproduces the original fixed ``128 * 46`` flatten size.

    Raises:
        ValueError: If ``input_length`` is too short to survive both pooling
            stages (fewer than 4 time steps).
    """

    def __init__(self, num_classes=5, input_length=187):
        super().__init__()
        self.block1 = CoupleConvBlock(1, 64, dropout=0.2)
        self.block2 = CoupleConvBlock(64, 128, dropout=0.3)

        # Each block ends in MaxPool1d(kernel_size=2, stride=2), which floors
        # the length to half: 187 -> 93 -> 46 for the default input.
        pooled_length = (input_length // 2) // 2
        if pooled_length < 1:
            raise ValueError(
                f"input_length={input_length} is too short for two 2x pooling stages"
            )

        self.fc_in = nn.Linear(128 * pooled_length, 128)
        self.fc_bn = nn.BatchNorm1d(128)
        self.fc_drop = nn.Dropout(0.5)
        self.fc_out = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits for ``x`` of shape (batch, 1, input_length)."""
        x = self.block1(x)
        x = self.block2(x)
        # flatten() handles non-contiguous tensors, unlike view().
        x = torch.flatten(x, start_dim=1)
        # NOTE(review): there is no activation between fc_in and fc_out, so at
        # inference the head is an affine map. Left as-is because adding one
        # would change the output of already-trained checkpoints - confirm intent.
        x = self.fc_drop(self.fc_bn(self.fc_in(x)))
        x = self.fc_out(x)
        return x


In [27]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoints go into a "models" folder inside the data directory; create it
# up front so torch.save does not fail on a missing directory.
save_dir = os.path.join(data_dir, "models")
model_path = os.path.join(save_dir, "ECG_couple_BN_CNN_V2_Adam.pt")
os.makedirs(save_dir, exist_ok=True)


In [28]:
# =========================
# 3.Training
# =========================
# Keep the same seed for all models. NumPy drives the split below; PyTorch
# drives weight initialisation, dropout and DataLoader shuffling, so both must
# be seeded for a run to be repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Stratified train/validation split ----
# Class membership is read from the labels themselves rather than from
# hard-coded row ranges, so the split stays correct if the CSV is re-ordered
# or has a different number of rows. When the rows are sorted by class this
# yields exactly the same index arrays (and the same shuffles) as fixed ranges.
y_all = y_train_tensor.numpy()
num_classes = int(y_all.max()) + 1
val_fraction = 0.2

train_indices = []
val_indices = []

for cls in range(num_classes):
    idx = np.flatnonzero(y_all == cls)
    np.random.shuffle(idx)  # shuffle within the class
    n_val = int(len(idx) * val_fraction)
    val_indices.extend(idx[:n_val])
    train_indices.extend(idx[n_val:])

# Index with LongTensors instead of Python lists of NumPy scalars.
train_index_tensor = torch.as_tensor(np.asarray(train_indices, dtype=np.int64))
val_index_tensor = torch.as_tensor(np.asarray(val_indices, dtype=np.int64))

train_dataset = TensorDataset(X_train_tensor[train_index_tensor], y_train_tensor[train_index_tensor])
val_dataset = TensorDataset(X_train_tensor[val_index_tensor], y_train_tensor[val_index_tensor])

# data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# ---- Class-weighted loss ----
# Per-class sample counts, computed from the labels.
class_counts = np.bincount(y_all, minlength=num_classes).tolist()
total_samples = sum(class_counts)
# "Balanced" weights: total / (num_classes * count). A class with no samples
# gets weight 0 instead of raising ZeroDivisionError.
base_weights = [total_samples / (num_classes * n) if n > 0 else 0.0 for n in class_counts]

# alpha = 1: fully weight according to class frequency, alpha = 0: no weighting
alpha = 0
adjusted_weights = [1 + alpha * (w - 1) for w in base_weights]
weights_tensor = torch.tensor(adjusted_weights, dtype=torch.float32).to(device)

model = ECG_CNN().to(device)
# Only pass the weights when alpha asks for them, so alpha = 0 keeps the plain
# unweighted loss and any other value actually takes effect.
if alpha > 0:
    criterion = nn.CrossEntropyLoss(weight=weights_tensor)
else:
    criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
# Scheduler: reduce LR on plateau (val loss)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6)

num_epochs = 100
best_val_loss = float("inf")
patience_es = 15       # epochs without validation improvement before stopping
no_improve_count = 0
for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; scale by batch size to get a per-sample average below
        train_loss += loss.item() * X.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            loss = criterion(outputs, y)
            val_loss += loss.item() * X.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == y).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    # scheduler step (use val loss)
    #scheduler.step(val_loss)

    # ---- Save Best / early stopping ----
    # Only the weights with the lowest validation loss are kept on disk.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), model_path)
        print("Best model saved.")
        no_improve_count = 0
    else:
        no_improve_count += 1
        if no_improve_count >= patience_es:
            print(f"Early stopping at epoch {epoch+1}")
            break



Epoch 1/100 - Train Loss: 0.2633 | Val Loss: 0.1094 | Val Acc: 0.9706
Best model saved.
Epoch 2/100 - Train Loss: 0.1291 | Val Loss: 0.0966 | Val Acc: 0.9733
Best model saved.
Epoch 3/100 - Train Loss: 0.1004 | Val Loss: 0.0887 | Val Acc: 0.9753
Best model saved.
Epoch 4/100 - Train Loss: 0.0865 | Val Loss: 0.0708 | Val Acc: 0.9805
Best model saved.
Epoch 5/100 - Train Loss: 0.0792 | Val Loss: 0.0669 | Val Acc: 0.9816
Best model saved.
Epoch 6/100 - Train Loss: 0.0718 | Val Loss: 0.0642 | Val Acc: 0.9830
Best model saved.
Epoch 7/100 - Train Loss: 0.0664 | Val Loss: 0.0622 | Val Acc: 0.9825
Best model saved.
Epoch 8/100 - Train Loss: 0.0609 | Val Loss: 0.0620 | Val Acc: 0.9822
Best model saved.
Epoch 9/100 - Train Loss: 0.0575 | Val Loss: 0.0554 | Val Acc: 0.9843
Best model saved.
Epoch 10/100 - Train Loss: 0.0536 | Val Loss: 0.0548 | Val Acc: 0.9845
Best model saved.
Epoch 11/100 - Train Loss: 0.0519 | Val Loss: 0.0517 | Val Acc: 0.9859
Best model saved.
Epoch 12/100 - Train Loss: 0.0

In [29]:
# =========================
# evaluation
# =========================
test_ds = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False)

# Load the best checkpoint into a fresh model. map_location lets a checkpoint
# that was saved on a GPU be restored on a CPU-only machine (and vice versa).
model = ECG_CNN().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Collect predictions and ground-truth labels over the whole test set.
preds, labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        out = model(xb.to(device))
        pred = torch.argmax(out, dim=1)
        preds.extend(pred.cpu().numpy())
        # Labels are only recorded, never fed to the model, so they stay on the CPU.
        labels.extend(yb.numpy())

# Save the per-sample results as CSV.
csv_path = os.path.join(save_dir, "test_pred.csv")
pd.DataFrame({"y_true": labels, "y_pred": preds}).to_csv(csv_path, index=False)
print(f"Evaluation CSV saved at {csv_path}")


Evaluation CSV saved at .\archive\models\test_pred.csv


In [30]:
# =========================
# confusion matrix & metrics
# =========================
# Rows of cm are true classes, columns are predicted classes.
cm = confusion_matrix(labels, preds, labels=list(range(5)))

print("===== Confusion Matrix =====")
print(cm)

# Row sums are the number of true samples per class; they double as the
# weights for the weighted averages further down.
class_counts = cm.sum(axis=1)
total_samples = class_counts.sum()
weights = class_counts / total_samples

# Marginals that do not change from class to class are computed once.
predicted_counts = cm.sum(axis=0)
grand_total = cm.sum()

# calculate one-vs-rest metrics for the 5 classes
metrics_per_class = {"recall": [], "specificity": [], "precision": [], "f1": []}

for i in range(5):
    TP = cm[i, i]
    FN = class_counts[i] - TP
    FP = predicted_counts[i] - TP
    TN = grand_total - (TP + FP + FN)

    # The 1e-8 in each denominator guards against division by zero.
    recall_i = TP / (TP + FN + 1e-8)
    specificity_i = TN / (TN + FP + 1e-8)
    precision_i = TP / (TP + FP + 1e-8)
    f1_i = 2 * recall_i * precision_i / (recall_i + precision_i + 1e-8)

    # The tuple follows the key order of metrics_per_class.
    scores_i = (recall_i, specificity_i, precision_i, f1_i)
    for key, value in zip(metrics_per_class, scores_i):
        metrics_per_class[key].append(value)

# Macro average = plain mean over classes; weighted average = mean weighted
# by each class's share of the true samples.
macro_avg_metrics = {}
weighted_avg_metrics = {}
for k, v in metrics_per_class.items():
    macro_avg_metrics[k] = np.mean(v)
    weighted_avg_metrics[k] = np.sum(np.array(v) * weights)

print("\n===== Per-Class Metrics =====")
for k, v in metrics_per_class.items():
    print(f"{k}: {np.round(v, 4)}")

print("\n===== Macro-Average Metrics =====")
for k, v in macro_avg_metrics.items():
    print(f"{k}: {v:.4f}")

print("\n===== Weighted-Average Metrics =====")
for k, v in weighted_avg_metrics.items():
    print(f"{k}: {v:.4f}")

===== Confusion Matrix =====
[[18052    39    23     2     2]
 [  118   428     8     2     0]
 [   37     2  1388    19     2]
 [   20     0     5   137     0]
 [   14     0     1     0  1593]]

===== Per-Class Metrics =====
recall: [0.9964 0.7698 0.9586 0.8457 0.9907]
specificity: [0.9499 0.9981 0.9982 0.9989 0.9998]
precision: [0.9896 0.9126 0.974  0.8562 0.9975]
f1: [0.993  0.8351 0.9662 0.8509 0.9941]

===== Macro-Average Metrics =====
recall: 0.9122
specificity: 0.9890
precision: 0.9460
f1: 0.9279

===== Weighted-Average Metrics =====
recall: 0.9866
specificity: 0.9584
precision: 0.9862
f1: 0.9862
