In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import wfdb
from scipy.signal import butter, filtfilt
from sklearn.metrics import classification_report

In [17]:
# ===============================
# 1. ECG Dataset Loader
# ===============================
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [19]:
# ===============================
# 2. Preprocessing
# ===============================
def bandpass_filter(signal, lowcut=0.5, highcut=40, fs=360, order=5):
    nyquist = 0.5 * fs
    b, a = butter(order, [lowcut/nyquist, highcut/nyquist], btype="band")
    return filtfilt(b, a, signal)

def load_mitbih(path="./mit-bih-arrhythmia-database-1.0.0/", record_list=None, segment_length=1000):
    """
    Loads MIT-BIH dataset and prepares it for testing.
    - path: dataset folder
    - record_list: which records to use (e.g. ["100", "101"])
    - segment_length: fixed length for CNN input
    """
    if record_list is None:
        record_list = ["100", "101", "103", "105"]  # you can expand

    X, y = [], []
    for rec in record_list:
        record = wfdb.rdrecord(path + rec)
        ann = wfdb.rdann(path + rec, "atr")

        signal = record.p_signal[:,0]  # use lead 0
        signal = bandpass_filter(signal, fs=record.fs)

        # Slice into segments aligned with annotations
        for i, sample in enumerate(ann.sample):
            start = max(0, sample - segment_length//2)
            end = start + segment_length
            if end > len(signal):
                continue
            segment = signal[start:end]
            if len(segment) == segment_length:
                X.append(segment)
                # Convert annotation symbol to binary label (Normal vs Abnormal)
                label = 0 if ann.symbol[i] == "N" else 1
                y.append(label)

    X = np.expand_dims(np.array(X), 1)  # (N, 1, segment_length)
    y = np.array(y)
    return X, y

In [21]:
# ===============================
# 3. ResNet Model (same as training)
# ===============================
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet1D(nn.Module):
    def __init__(self, block, layers, num_classes=2, n_leads=1):
        super(ResNet1D, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv1d(n_leads, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)
    def _make_layer(self, block, out_channels, blocks, stride):
        strides = [stride] + [1]*(blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)

def ResNet18_1D(num_classes=2, n_leads=1):
    return ResNet1D(ResidualBlock, [2,2,2,2], num_classes=num_classes, n_leads=n_leads)


In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import wfdb
from scipy.signal import butter, filtfilt
import os

# ===============================
# 1. ECG Dataset Loader
# ===============================
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# ===============================
# 2. Preprocessing
# ===============================
def bandpass_filter(signal, lowcut=0.5, highcut=40, fs=360, order=5):
    nyquist = 0.5 * fs
    b, a = butter(order, [lowcut/nyquist, highcut/nyquist], btype="band")
    return filtfilt(b, a, signal)

def load_mitbih(path, segment_length=1800):
    """
    Load MIT-BIH arrhythmia dataset.
    - path: path to dataset directory
    - segment_length: number of samples per segment (default 5s at 360 Hz)
    """
    X, y = [], []
    records = [f[:-4] for f in os.listdir(path) if f.endswith(".dat")]

    for rec in records:
        sig, fields = wfdb.rdsamp(os.path.join(path, rec))
        ann = wfdb.rdann(os.path.join(path, rec), "atr")

        # use lead 0
        lead = sig[:,0]
        lead = bandpass_filter(lead, fs=fields["fs"])

        # segment signal
        n_segments = len(lead) // segment_length
        for i in range(n_segments):
            segment = lead[i*segment_length:(i+1)*segment_length]
            if len(segment) == segment_length:
                # duplicate into 12 channels
                segment = np.tile(segment, (12,1))  # shape = (12, L)
                X.append(segment)
                y.append(1 if any(l in ann.symbol for l in ["V","A","L","R"]) else 0)

    X = np.array(X)
    y = np.array(y)
    return X, y

# ===============================
# 3. ResNet Blocks and Model
# ===============================
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet1D(nn.Module):
    def __init__(self, block, layers, num_classes=2, n_leads=12):
        super(ResNet1D, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv1d(n_leads, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        strides = [stride] + [1]*(blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)

def ResNet18_1D(num_classes=2, n_leads=12):
    return ResNet1D(ResidualBlock, [2,2,2,2], num_classes=num_classes, n_leads=n_leads)

# ===============================
# 4. Main Testing
# ===============================
if __name__ == "__main__":
    path = "mit-bih-arrhythmia-database-1.0.0/" 
    X, y = load_mitbih(path)

    # Dataset
    test_ds = ECGDataset(X, y)
    test_loader = DataLoader(test_ds, batch_size=32)

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ResNet18_1D(num_classes=2, n_leads=12)
    model.load_state_dict(torch.load("ptb-xl/models/second.pth", map_location=device))
    model.to(device)
    model.eval()

    # Evaluation
    preds, probs, targets = [], [], []
    with torch.no_grad():
        for Xb, yb in test_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            out = model(Xb)
            prob = F.softmax(out, dim=1)[:,1].cpu().numpy()
            pred = torch.argmax(out, dim=1).cpu().numpy()
            preds.extend(pred)
            probs.extend(prob)
            targets.extend(yb.cpu().numpy())

print(classification_report(targets, preds, target_names=["Normal", "Arrhythmia"]))

              precision    recall  f1-score   support

      Normal       0.31      0.33      0.32      1083
  Arrhythmia       0.96      0.95      0.95     16245

    accuracy                           0.91     17328
   macro avg       0.63      0.64      0.64     17328
weighted avg       0.91      0.91      0.91     17328



# **1dcnn**

In [26]:
import torch
import torch.nn as nn
import numpy as np
import wfdb
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report

In [31]:
# ======================
# 1. Load Your Model
# ======================
class CNN1D(nn.Module):
    def __init__(self, num_classes=2):
        super(CNN1D, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(12, 32, kernel_size=7, padding=3),  # 12-lead input
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)  # makes it [batch, 128, 1]
        )
        
        self.classifier = nn.Linear(128, num_classes)  # matches checkpoint

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten [batch, 128, 1] -> [batch, 128]
        x = self.classifier(x)
        return x

# -------------------
# Load trained model
# -------------------
model = CNN1D(num_classes=2)
state_dict = torch.load("ptb-xl/models/first.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
print("✅ Model loaded successfully!")


✅ Model loaded successfully!


In [32]:
# ======================
# 2. MIT-BIH Dataset Loader
# ======================
class MITBIHDataset(Dataset):
    def __init__(self, data_dir, record_list, window_size=500):
        self.samples = []
        self.labels = []
        self.window_size = window_size

        for record in record_list:
            record_path = os.path.join(data_dir, record)
            sig, fields = wfdb.rdsamp(record_path)
            ann = wfdb.rdann(record_path, 'atr')

            sig = sig[:, 0]  # use lead MLII
            sig = (sig - np.mean(sig)) / np.std(sig)

            for idx, sym in zip(ann.sample, ann.symbol):
                if sym in ['N', 'L', 'R']:  # Normal beats
                    label = 0
                elif sym in ['V', 'A']:     # Arrhythmia
                    label = 1
                else:
                    continue

                start = max(0, idx - window_size // 2)
                end = start + window_size
                if end <= len(sig):
                    beat = sig[start:end]
                    self.samples.append(beat)
                    self.labels.append(label)

        self.samples = np.array(self.samples)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        y = self.labels[idx]
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # (1, L)
        return x, y

In [34]:
# ======================
# 3. Run Testing
# ======================
data_dir = "mit-bih-arrhythmia-database-1.0.0"
records = ["100", "101", "102"]  # pick a few records
dataset = MITBIHDataset(data_dir, records)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

y_true, y_pred = [], []

with torch.no_grad():
    for x, y in dataloader:
        # x shape: [batch, 1, length]
        x = x.repeat(1, 12, 1)   # expand single-lead -> 12 leads
        outputs = model(x)
        preds = torch.argmax(outputs, dim=1)
        y_true.extend(y.numpy())
        y_pred.extend(preds.numpy())


print(classification_report(y_true, y_pred, target_names=["Normal", "Arrhythmia"]))

              precision    recall  f1-score   support

      Normal       0.98      0.11      0.20      4196
  Arrhythmia       0.01      0.78      0.02        41

    accuracy                           0.11      4237
   macro avg       0.49      0.44      0.11      4237
weighted avg       0.97      0.11      0.19      4237

