In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import Dataset, DataLoader
import os

# # --- CONFIGURATION ---
# # Set your local folder path where the CSV files are stored
# local_folder = r"ENTER_YOUR_EXISTING_FOLDER_PATH_HERE"  # <<<--- CHANGE THIS, e.g., r"C:\Users\saura\Downloads\heartbeat"
# os.chdir(local_folder)  # Change working directory to make relative loads work
# print("Working in directory:", os.getcwd())
# print("Files available:", os.listdir())

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# SNN Components (unchanged)
class SpikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mem, threshold):
        spike = (mem >= threshold).float()
        ctx.save_for_backward(mem - threshold)
        return spike

    @staticmethod
    def backward(ctx, grad_output):
        mem_minus_th, = ctx.saved_tensors
        surrogate = 1 / (1 + (torch.abs(mem_minus_th) * 5) ** 2)
        return grad_output * surrogate, None

class LIF(nn.Module):
    def __init__(self, in_features, out_features, threshold=1.0, decay=0.99):
        super(LIF, self).__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.threshold = threshold
        self.decay = decay

    def forward(self, x):
        batch_size = x.size(0)
        time_steps = x.size(1)
        mem = torch.zeros(batch_size, self.fc.out_features, device=x.device)
        spikes = []
        for t in range(time_steps):
            curr = self.fc(x[:, t, :])
            mem = self.decay * mem + curr
            spike = SpikeFunction.apply(mem, self.threshold)
            mem = mem - spike * self.threshold
            spikes.append(spike)
        return torch.stack(spikes, dim=1)

class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SNN, self).__init__()
        self.lif1 = LIF(input_size, hidden_size)
        self.lif2 = LIF(hidden_size, output_size)

    def forward(self, x):
        spk1 = self.lif1(x)
        spk2 = self.lif2(spk1)
        return spk2

class ECGDataset(Dataset):
    def __init__(self, signals, labels=None):
        self.signals = signals
        self.labels = labels

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.signals[idx], self.labels[idx]
        return self.signals[idx]

def normalize_signals(signals):
    min_vals = signals.min(axis=1, keepdims=True)
    max_vals = signals.max(axis=1, keepdims=True)
    return (signals - min_vals) / (max_vals - min_vals + 1e-5)

# --- STAGE 1: Building and Training Beat Classifier (Model A) on MIT-BIH Data ---
print("\n--- STAGE 1: Building and Training Beat Classifier (Model A) on MIT-BIH Data ---")
try:
    mit_train_df = pd.read_csv("mitbih_train.csv", header=None)
    mit_test_df = pd.read_csv("mitbih_test.csv", header=None)
    mit_df = pd.concat([mit_train_df, mit_test_df], axis=0)
    print("MIT-BIH dataset loaded successfully.")
    print(f"Full MIT shape: {mit_df.shape}")
except FileNotFoundError:
    print("ERROR: Make sure 'mitbih_train.csv' and 'mitbih_test.csv' are in the same folder.")
    print("If named differently (e.g., mitdb_*), rename them before running.")
    exit()

# Process full MIT for training Model A (multi-class beat classification)
# Note: Last column is label, first 187 are signals
signals_a = mit_df.iloc[:, :-1].values.astype(np.float32)
labels_a = mit_df.iloc[:, -1].values.astype(int)
signals_a = normalize_signals(signals_a)

# Split for train/val (80/20 on full data)
X_train_a, X_val_a, y_train_a, y_val_a = train_test_split(signals_a, labels_a, test_size=0.2, random_state=42, stratify=labels_a)

# Tensors
X_train_a_tensor = torch.tensor(X_train_a).unsqueeze(-1).to(device)
X_val_a_tensor = torch.tensor(X_val_a).unsqueeze(-1).to(device)
y_train_a_tensor = torch.tensor(y_train_a).to(device)
y_val_a_tensor = torch.tensor(y_val_a).to(device)

train_ds_a = ECGDataset(X_train_a_tensor, y_train_a_tensor)
val_ds_a = ECGDataset(X_val_a_tensor, y_val_a_tensor)
train_loader_a = DataLoader(train_ds_a, batch_size=32, shuffle=True)
val_loader_a = DataLoader(val_ds_a, batch_size=32)

# Model A: Multi-Class Beat Classifier (5 classes)
model_a = SNN(input_size=1, hidden_size=256, output_size=5).to(device)
optimizer_a = optim.Adam(model_a.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 15
for epoch in range(num_epochs):
    model_a.train()
    total_loss = 0
    for sig, lab in train_loader_a:
        spk = model_a(sig)
        rates = spk.sum(1)
        loss = loss_fn(rates, lab)
        optimizer_a.zero_grad()
        loss.backward()
        optimizer_a.step()
        total_loss += loss.item()
    print(f"Model A Epoch {epoch+1}/{num_epochs}, Train Loss: {total_loss / len(train_loader_a):.4f}")

    # Val
    model_a.eval()
    val_preds = []
    val_true = []
    with torch.no_grad():
        for sig, lab in val_loader_a:
            spk = model_a(sig)
            rates = spk.sum(1)
            pred = torch.argmax(rates, dim=1)
            val_preds.extend(pred.cpu().numpy())
            val_true.extend(lab.cpu().numpy())
    val_acc = accuracy_score(val_true, val_preds)
    print(f"Model A Val Accuracy: {val_acc:.4f}")

print("Stage 1 Complete: Beat Classifier (Model A) Trained.")
print(classification_report(val_true, val_preds, target_names=['N', 'S', 'V', 'F', 'Q']))  # MIT classes

# --- STAGE 2: Building and Training Anomaly Detector (Model B) on PTBDB Data ---
print("\n--- STAGE 2: Building and Training Anomaly Detector (Model B) on PTBDB Data ---")
try:
    normal_df = pd.read_csv("ptbdb_normal.csv", header=None)
    abnormal_df = pd.read_csv("ptbdb_abnormal.csv", header=None)
    ptbdb_df = pd.concat([normal_df, abnormal_df], axis=0)
    print("PTBDB dataset loaded successfully.")
    print(f"Full PTBDB shape: {ptbdb_df.shape}")
except FileNotFoundError:
    print("ERROR: Make sure 'ptbdb_normal.csv' and 'ptbdb_abnormal.csv' are in the folder.")
    exit()

# Process PTBDB for binary anomaly detection
signals_b = ptbdb_df.iloc[:, :-1].values.astype(np.float32)
labels_b = ptbdb_df.iloc[:, -1].values.astype(int)
signals_b = normalize_signals(signals_b)

# Split
X_train_b, X_val_b, y_train_b, y_val_b = train_test_split(signals_b, labels_b, test_size=0.2, random_state=42, stratify=labels_b)

# Tensors
X_train_b_tensor = torch.tensor(X_train_b).unsqueeze(-1).to(device)
X_val_b_tensor = torch.tensor(X_val_b).unsqueeze(-1).to(device)
y_train_b_tensor = torch.tensor(y_train_b).to(device)
y_val_b_tensor = torch.tensor(y_val_b).to(device)

train_ds_b = ECGDataset(X_train_b_tensor, y_train_b_tensor)
val_ds_b = ECGDataset(X_val_b_tensor, y_val_b_tensor)
train_loader_b = DataLoader(train_ds_b, batch_size=32, shuffle=True)
val_loader_b = DataLoader(val_ds_b, batch_size=32)

# Model B: Binary Anomaly Detector
model_b = SNN(input_size=1, hidden_size=128, output_size=2).to(device)
optimizer_b = optim.Adam(model_b.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model_b.train()
    total_loss = 0
    for sig, lab in train_loader_b:
        spk = model_b(sig)
        rates = spk.sum(1)
        loss = loss_fn(rates, lab)
        optimizer_b.zero_grad()
        loss.backward()
        optimizer_b.step()
        total_loss += loss.item()
    print(f"Model B Epoch {epoch+1}/{num_epochs}, Train Loss: {total_loss / len(train_loader_b):.4f}")

    # Val
    model_b.eval()
    val_preds = []
    val_true = []
    with torch.no_grad():
        for sig, lab in val_loader_b:
            spk = model_b(sig)
            rates = spk.sum(1)
            pred = torch.argmax(rates, dim=1)
            val_preds.extend(pred.cpu().numpy())
            val_true.extend(lab.cpu().numpy())
    val_acc = accuracy_score(val_true, val_preds)
    print(f"Model B Val Accuracy: {val_acc:.4f}")

print("Stage 2 Complete: Anomaly Detector (Model B) Trained.")
print(classification_report(val_true, val_preds, target_names=['Normal', 'Abnormal']))

# --- OPTIONAL: Separate Inference on Original MIT Test (if needed for submission) ---
print("\n--- Inference on Original MIT-BIH Test Split using Model A ---")
# Reload test separately if needed
mit_test_df = pd.read_csv("mitbih_test.csv", header=None)
signals_test = mit_test_df.values.astype(np.float32)  # assuming no label in original test
signals_test = normalize_signals(signals_test)
signals_test_tensor = torch.tensor(signals_test).unsqueeze(-1).to(device)

test_ds = ECGDataset(signals_test_tensor)
test_loader = DataLoader(test_ds, batch_size=32)

model_a.eval()
test_preds = []
with torch.no_grad():
    for sig in test_loader:
        spk = model_a(sig)
        rates = spk.sum(1)
        pred = torch.argmax(rates, dim=1)
        test_preds.extend(pred.cpu().numpy())
print("Sample MIT Test Predictions (first 10):", test_preds[:10])

print("\nAll Done! Models trained on respective datasets. For cascaded use: Apply Model B first to detect anomaly, then Model A for beat details if anomalous.")

Using device: cpu

--- STAGE 1: Building and Training Beat Classifier (Model A) on MIT-BIH Data ---
MIT-BIH dataset loaded successfully.
Full MIT shape: (109446, 188)
Model A Epoch 1/15, Train Loss: 0.7784
Model A Val Accuracy: 0.8277
Model A Epoch 2/15, Train Loss: 0.7216
Model A Val Accuracy: 0.8276
Model A Epoch 3/15, Train Loss: 0.7301
Model A Val Accuracy: 0.8277
Model A Epoch 4/15, Train Loss: 0.7115
Model A Val Accuracy: 0.8277
Model A Epoch 5/15, Train Loss: 0.7164
Model A Val Accuracy: 0.8277
Model A Epoch 6/15, Train Loss: 0.6928
Model A Val Accuracy: 0.8288
Model A Epoch 7/15, Train Loss: 0.6699
Model A Val Accuracy: 0.8271
Model A Epoch 8/15, Train Loss: 0.6535
Model A Val Accuracy: 0.8254
Model A Epoch 9/15, Train Loss: 0.6252
Model A Val Accuracy: 0.8514
Model A Epoch 10/15, Train Loss: 0.5848
Model A Val Accuracy: 0.8513
Model A Epoch 11/15, Train Loss: 0.5547
Model A Val Accuracy: 0.8410
Model A Epoch 12/15, Train Loss: 0.5486
Model A Val Accuracy: 0.8488
Model A Epoch 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


PTBDB dataset loaded successfully.
Full PTBDB shape: (14552, 188)
Model B Epoch 1/15, Train Loss: 0.8914
Model B Val Accuracy: 0.5373
Model B Epoch 2/15, Train Loss: 0.6149
Model B Val Accuracy: 0.4253
Model B Epoch 3/15, Train Loss: 0.6070
Model B Val Accuracy: 0.5070
Model B Epoch 4/15, Train Loss: 0.5902
Model B Val Accuracy: 0.7176
Model B Epoch 5/15, Train Loss: 0.5723
Model B Val Accuracy: 0.6788
Model B Epoch 6/15, Train Loss: 0.5718
Model B Val Accuracy: 0.7224
Model B Epoch 7/15, Train Loss: 0.5679
Model B Val Accuracy: 0.7217
Model B Epoch 8/15, Train Loss: 0.5692
Model B Val Accuracy: 0.4885
Model B Epoch 9/15, Train Loss: 0.5660
Model B Val Accuracy: 0.6853
Model B Epoch 10/15, Train Loss: 0.5616
Model B Val Accuracy: 0.6482
Model B Epoch 11/15, Train Loss: 0.5643
Model B Val Accuracy: 0.5019
Model B Epoch 12/15, Train Loss: 0.5459
Model B Val Accuracy: 0.5854
Model B Epoch 13/15, Train Loss: 0.5669
Model B Val Accuracy: 0.7272
Model B Epoch 14/15, Train Loss: 0.5401
Model 

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import json

# --- CONFIGURATION ---

print("Working in directory:", os.getcwd())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# SNN Classes (same as previous improved version)
class SpikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mem, threshold):
        spike = (mem >= threshold).float()
        ctx.save_for_backward(mem - threshold)
        return spike

    @staticmethod
    def backward(ctx, grad_output):
        mem_minus_th, = ctx.saved_tensors
        surrogate = torch.exp(- (mem_minus_th ** 2) / 0.1)
        return grad_output * surrogate, None

class LIF(nn.Module):
    def __init__(self, in_features, out_features, threshold=1.0, decay=0.95):
        super(LIF, self).__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.threshold = threshold
        self.decay = decay

    def forward(self, x):
        batch_size = x.size(0)
        time_steps = x.size(1)
        mem = torch.zeros(batch_size, self.fc.out_features, device=x.device)
        spikes = []
        for t in range(time_steps):
            curr = self.fc(x[:, t, :])
            mem = self.decay * mem + curr
            spike = SpikeFunction.apply(mem, self.threshold)
            mem = mem - spike * self.threshold
            spikes.append(spike)
        return torch.stack(spikes, dim=1)

class SNN(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super(SNN, self).__init__()
        self.lif1 = LIF(input_size, hidden_size1)
        self.lif2 = LIF(hidden_size1, hidden_size2)
        self.lif_out = LIF(hidden_size2, output_size)

    def forward(self, x):
        spk1 = self.lif1(x)
        spk2 = self.lif2(spk1)
        spk_out = self.lif_out(spk2)
        return spk_out

class ECGDataset(Dataset):
    def __init__(self, signals, labels=None):
        self.signals = signals
        self.labels = labels

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.signals[idx], self.labels[idx]
        return self.signals[idx]

def normalize_signals(signals):
    min_vals = signals.min(axis=1, keepdims=True)
    max_vals = signals.max(axis=1, keepdims=True)
    return (signals - min_vals) / (max_vals - min_vals + 1e-5)

# --- LOAD TRAINED MODELS (Assuming you have the .pth files from previous training) ---
# If you trained in a previous session, save with: torch.save(model.state_dict(), 'model_a.pth')
# Here, we'll assume you need to retrain briefly or load if exists. To dump/save, we'll train and save.

print("\n--- Loading or Retraining Models to Save and Evaluate ---")

# STAGE 1: Model A (Beat Classifier on MIT-BIH)
print("--- Handling Model A ---")
try:
    mit_train_df = pd.read_csv("mitbih_train.csv", header=None)
    print("MIT-BIH train loaded.")
except FileNotFoundError:
    print("ERROR: File not found.")
    exit()

signals_a = mit_train_df.iloc[:, :-1].values.astype(np.float32)
labels_a = mit_train_df.iloc[:, -1].values.astype(int)
signals_a = normalize_signals(signals_a)

class_counts = Counter(labels_a)
class_weights = torch.tensor([1.0 / class_counts[i] if class_counts[i] > 0 else 0 for i in range(5)], dtype=torch.float).to(device)
loss_fn_a = nn.CrossEntropyLoss(weight=class_weights)

X_train_a, X_val_a, y_train_a, y_val_a = train_test_split(signals_a, labels_a, test_size=0.2, random_state=42, stratify=labels_a)
X_train_a_tensor = torch.tensor(X_train_a).unsqueeze(-1).to(device)
X_val_a_tensor = torch.tensor(X_val_a).unsqueeze(-1).to(device)
y_train_a_tensor = torch.tensor(y_train_a).to(device)
y_val_a_tensor = torch.tensor(y_val_a).to(device)

train_ds_a = ECGDataset(X_train_a_tensor, y_train_a_tensor)
val_ds_a = ECGDataset(X_val_a_tensor, y_val_a_tensor)
train_loader_a = DataLoader(train_ds_a, batch_size=64, shuffle=True)
val_loader_a = DataLoader(val_ds_a, batch_size=64)

model_a = SNN(input_size=1, hidden_size1=256, hidden_size2=128, output_size=5).to(device)

# Load if exists, else train briefly (5 epochs to simulate, adjust if you have saved weights)
model_a_path = 'model_a.pth'
if os.path.exists(model_a_path):
    model_a.load_state_dict(torch.load(model_a_path))
    print("Loaded Model A from disk.")
else:
    print("No saved Model A found. Training for 5 epochs to create and save...")
    optimizer_a = optim.Adam(model_a.parameters(), lr=0.0005)
    for epoch in range(5):
        model_a.train()
        for sig, lab in train_loader_a:
            spk = model_a(sig)
            rates = spk.sum(1)
            loss = loss_fn_a(rates, lab)
            optimizer_a.zero_grad()
            loss.backward()
            optimizer_a.step()
    torch.save(model_a.state_dict(), model_a_path)
    print("Trained briefly and saved Model A to 'model_a.pth'")

# Evaluate Model A on val
model_a.eval()
val_preds_a = []
val_true_a = []
with torch.no_grad():
    for sig, lab in val_loader_a:
        spk = model_a(sig)
        rates = spk.sum(1)
        pred = torch.argmax(rates, dim=1)
        val_preds_a.extend(pred.cpu().numpy())
        val_true_a.extend(lab.cpu().numpy())
acc_a = accuracy_score(val_true_a, val_preds_a)
print(f"Model A Val Accuracy: {acc_a:.4f}")
print(classification_report(val_true_a, val_preds_a, target_names=['N', 'S', 'V', 'F', 'Q'], zero_division=0))

# Confusion Matrix for A
cm_a = confusion_matrix(val_true_a, val_preds_a)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_a, annot=True, fmt='d', cmap='Blues', xticklabels=['N', 'S', 'V', 'F', 'Q'], yticklabels=['N', 'S', 'V', 'F', 'Q'])
plt.title('Model A Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('model_a_confusion_matrix.png')
plt.show()

# Per-class accuracy bar
class_acc_a = cm_a.diagonal() / cm_a.sum(axis=1)
plt.figure(figsize=(8, 5))
sns.barplot(x=['N', 'S', 'V', 'F', 'Q'], y=class_acc_a)
plt.title('Model A Per-Class Accuracy')
plt.ylim(0, 1)
plt.savefig('model_a_per_class_acc.png')
plt.show()

# STAGE 2: Model B (Anomaly Detector on PTBDB)
print("\n--- Handling Model B ---")
try:
    normal_df = pd.read_csv("ptbdb_normal.csv", header=None)
    abnormal_df = pd.read_csv("ptbdb_abnormal.csv", header=None)
    print("PTBDB loaded.")
except FileNotFoundError:
    print("ERROR: File not found.")
    exit()

signals_b = np.vstack((normal_df.iloc[:, :-1].values, abnormal_df.iloc[:, :-1].values)).astype(np.float32)
labels_b = np.hstack((normal_df.iloc[:, -1].values, abnormal_df.iloc[:, -1].values)).astype(int)
signals_b = normalize_signals(signals_b)

class_counts_b = Counter(labels_b)
class_weights_b = torch.tensor([1.0 / class_counts_b[i] for i in range(2)], dtype=torch.float).to(device)
loss_fn_b = nn.CrossEntropyLoss(weight=class_weights_b)

X_train_b, X_val_b, y_train_b, y_val_b = train_test_split(signals_b, labels_b, test_size=0.2, random_state=42, stratify=labels_b)
X_train_b_tensor = torch.tensor(X_train_b).unsqueeze(-1).to(device)
X_val_b_tensor = torch.tensor(X_val_b).unsqueeze(-1).to(device)
y_train_b_tensor = torch.tensor(y_train_b).to(device)
y_val_b_tensor = torch.tensor(y_val_b).to(device)

train_ds_b = ECGDataset(X_train_b_tensor, y_train_b_tensor)
val_ds_b = ECGDataset(X_val_b_tensor, y_val_b_tensor)
train_loader_b = DataLoader(train_ds_b, batch_size=64, shuffle=True)
val_loader_b = DataLoader(val_ds_b, batch_size=64)

model_b = SNN(input_size=1, hidden_size1=256, hidden_size2=128, output_size=2).to(device)

model_b_path = 'model_b.pth'
if os.path.exists(model_b_path):
    model_b.load_state_dict(torch.load(model_b_path))
    print("Loaded Model B from disk.")
else:
    print("No saved Model B found. Training for 5 epochs to create and save...")
    optimizer_b = optim.Adam(model_b.parameters(), lr=0.0005)
    for epoch in range(5):
        model_b.train()
        for sig, lab in train_loader_b:
            spk = model_b(sig)
            rates = spk.sum(1)
            loss = loss_fn_b(rates, lab)
            optimizer_b.zero_grad()
            loss.backward()
            optimizer_b.step()
    torch.save(model_b.state_dict(), model_b_path)
    print("Trained briefly and saved Model B to 'model_b.pth'")

# Evaluate Model B on val
model_b.eval()
val_preds_b = []
val_true_b = []
with torch.no_grad():
    for sig, lab in val_loader_b:
        spk = model_b(sig)
        rates = spk.sum(1)
        pred = torch.argmax(rates, dim=1)
        val_preds_b.extend(pred.cpu().numpy())
        val_true_b.extend(lab.cpu().numpy())
acc_b = accuracy_score(val_true_b, val_preds_b)
print(f"Model B Val Accuracy: {acc_b:.4f}")
print(classification_report(val_true_b, val_preds_b, target_names=['Normal', 'Abnormal'], zero_division=0))

# Confusion Matrix for B
cm_b = confusion_matrix(val_true_b, val_preds_b)
plt.figure(figsize=(6, 5))
sns.heatmap(cm_b, annot=True, fmt='d', cmap='Greens', xticklabels=['Normal', 'Abnormal'], yticklabels=['Normal', 'Abnormal'])
plt.title('Model B Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('model_b_confusion_matrix.png')
plt.show()

# Per-class accuracy bar for B
class_acc_b = cm_b.diagonal() / cm_b.sum(axis=1)
plt.figure(figsize=(6, 4))
sns.barplot(x=['Normal', 'Abnormal'], y=class_acc_b)
plt.title('Model B Per-Class Accuracy')
plt.ylim(0, 1)
plt.savefig('model_b_per_class_acc.png')
plt.show()

# --- Dump More Metrics to Files ---
metrics_a = {
    'accuracy': acc_a,
    'classification_report': classification_report(val_true_a, val_preds_a, output_dict=True, zero_division=0)
}
with open('model_a_metrics.json', 'w') as f:
    json.dump(metrics_a, f, indent=4)

metrics_b = {
    'accuracy': acc_b,
    'classification_report': classification_report(val_true_b, val_preds_b, output_dict=True, zero_division=0)
}
with open('model_b_metrics.json', 'w') as f:
    json.dump(metrics_b, f, indent=4)

print("\nModels dumped to 'model_a.pth' and 'model_b.pth'.")
print("Graphs saved as PNGs: confusion matrices and per-class acc.")
print("Metrics dumped to JSON files.")
print("To load later: model.load_state_dict(torch.load('path.pth'))")

Working in directory: c:\Users\saura\Documents\Major Project\Implementing-SNN
Using device: cpu

--- Loading or Retraining Models to Save and Evaluate ---
--- Handling Model A ---
MIT-BIH train loaded.
No saved Model A found. Training for 5 epochs to create and save...


KeyboardInterrupt: 