In [None]:
"""
Training SNN untuk ECG dengan snnTorch (3-way split: Train / Val / Test)
------------------------------------------------------------------------

Pipeline:
1. Baca dataset spike encoding dari folder:
       dataset_ecg_encoded/<class>/beat_xxxx.npz

2. Dataset berisi channel:
   - dm_up        : [T]
   - dm_down      : [T]
   - lc_up        : [T]   (level crossing LC+ 2-channel)
   - lc_down      : [T]   (LC-)
   - energy_spikes: [T]

   Disusun menjadi spike tensor [T, C]:
   - channel 0 : dm_up
   - channel 1 : dm_down
   - channel 2 : lc_up
   - channel 3 : lc_down
   - channel 4 : energy_spikes

3. Model SNN (snnTorch):
   Input C  → FC(30) + LIF (trainable leak & threshold)
            → FC(30) + LIF
            → FC(5)  + LIF

   - Reset               : "subtract"
   - Surrogate gradient  : atan
   - Output              : spike count per kelas
   - Klasifikasi         : argmax(spike_count)

4. Dataset split (stratified per kelas):
   - Train: 70%
   - Val  : 15%
   - Test : 15%

5. Training:
   - Loss        : CrossEntropy(spike_counts, label)
   - Optimizer   : Adam
   - Early stopping berdasarkan validation loss

6. Output file:
   - snn_ecg_snntorch_30_30_best.pt           (model terbaik)
   - training_curves_snntorch.png             (loss + accuracy)
   - confusion_matrix_val_snntorch.png        (confusion matrix validation)
   - confusion_matrix_test_snntorch.png       (confusion matrix test)
"""

import os
import glob
import random
from typing import List, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import snntorch as snn
from snntorch import surrogate

from sklearn.metrics import confusion_matrix, classification_report


# ============================================================
# 1. Label mapping AAMI
# ============================================================

# AAMI class letters mapped to integer indices; the tuple order fixes the
# index assigned to each letter.
AAMI_LABEL_TO_IDX: Dict[str, int] = {
    letter: position
    for position, letter in enumerate(("N", "S", "V", "F", "Q"))
}
# Inverse lookup (index -> class letter), derived from the mapping above.
IDX_TO_AAMI_LABEL: Dict[int, str] = dict(enumerate(AAMI_LABEL_TO_IDX))


# ============================================================
# 2. Dataset: membaca spike encoding dari .npz
# ============================================================

class SpikeECGDataset(Dataset):
    """
    Dataset that reads spike-encoded ECG beats from a class-per-folder tree.

    Folder layout:
      root_dir/
          N/*.npz
          S/*.npz
          V/*.npz
          F/*.npz
          Q/*.npz

    Every .npz file must provide at least these arrays:
      - dm_up
      - dm_down
      - lc_up
      - lc_down
      - energy_spikes

    __getitem__ returns:
      - spikes: float32 tensor with the five arrays above as channels, in that
                order. Shape is [T, C] assuming each stored array is 1-D of
                length T (TODO confirm against the encoder output).
      - label : int looked up in AAMI_LABEL_TO_IDX

    Notes:
      - self.samples holds (filepath, class_string) pairs; it is used by
        split_dataset_3way for the stratified split.
      - Class folders that do not exist are skipped silently.

    Raises:
      RuntimeError: if no .npz file is found under any requested class folder.
    """

    # Order of the encoded arrays along the channel axis of the spike tensor.
    _CHANNEL_KEYS: Tuple[str, ...] = (
        "dm_up",
        "dm_down",
        "lc_up",
        "lc_down",
        "energy_spikes",
    )

    def __init__(self, root_dir: str, classes: List[str] = None):
        super().__init__()
        self.root_dir = root_dir
        if classes is None:
            classes = ["N", "S", "V", "F", "Q"]
        # Folder names are matched in upper case
        self.classes = [c.upper() for c in classes]

        # List of (filepath, class_string)
        self.samples: List[Tuple[str, str]] = []

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            # Sorted so that sample indices are reproducible across runs
            files = sorted(glob.glob(os.path.join(cls_dir, "*.npz")))
            self.samples.extend((fp, cls) for fp in files)

        if not self.samples:
            raise RuntimeError(f"Tidak ada file .npz di {root_dir}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        fp, cls = self.samples[idx]

        # np.load on an .npz returns a lazily-read archive that keeps the file
        # open; the context manager closes it once the arrays are extracted,
        # so handles do not pile up over an epoch.
        with np.load(fp) as data:
            channels = [
                data[key].astype(np.float32) for key in self._CHANNEL_KEYS
            ]

        # Stack to [C, T], then transpose to time-major [T, C] for the SNN
        spikes_np = np.stack(channels, axis=0).T
        spikes = torch.from_numpy(spikes_np)

        # Class string -> integer index
        y = AAMI_LABEL_TO_IDX[cls]
        return spikes, y


# ============================================================
# 3. Stratified 3-way split: Train / Val / Test
# ============================================================

def split_dataset_3way(
    dataset: SpikeECGDataset,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
):
    """
    Partition a dataset into train / val / test subsets, stratified by class.

    For every class the sample indices are shuffled with a seeded RNG; the
    first int(n * val_ratio) go to validation, the next int(n * test_ratio)
    go to test and the remainder goes to training. A one-line summary is
    printed per non-empty class.

    Returns:
      (train_set, val_set, test_set) as torch.utils.data.Subset objects that
      wrap the given dataset.
    """
    shuffler = random.Random(seed)

    # One bucket of sample positions per AAMI class, in mapping order
    buckets: Dict[str, List[int]] = {label: [] for label in AAMI_LABEL_TO_IDX}
    for position in range(len(dataset)):
        label = dataset.samples[position][1]  # samples holds (filepath, class)
        buckets[label].append(position)

    train_idx: List[int] = []
    val_idx: List[int] = []
    test_idx: List[int] = []

    for cls, members in buckets.items():
        if len(members) == 0:
            continue

        shuffler.shuffle(members)

        n = len(members)
        val_end = int(n * val_ratio)
        test_end = val_end + int(n * test_ratio)

        # Consecutive slices of the shuffled bucket: val | test | train
        cls_val = members[:val_end]
        cls_test = members[val_end:test_end]
        cls_train = members[test_end:]

        val_idx += cls_val
        test_idx += cls_test
        train_idx += cls_train

        print(
            f"Class {cls}: total={n}, "
            f"train={len(cls_train)}, val={len(cls_val)}, test={len(cls_test)}"
        )

    # Wrap the index lists as views on the original dataset
    make_subset = torch.utils.data.Subset
    return (
        make_subset(dataset, train_idx),
        make_subset(dataset, val_idx),
        make_subset(dataset, test_idx),
    )


# ============================================================
# 4. SNN dengan snnTorch (LIF + atan surrogate)
# ============================================================

class SNN_Snntorch_ECG(nn.Module):
    """
    Three-layer fully connected spiking network built on snnTorch.

    Input C  -> FC(hidden1)   + Leaky LIF
             -> FC(hidden2)   + Leaky LIF
             -> FC(n_classes) + Leaky LIF

    Every snn.Leaky layer is created with the same settings:
        * beta            = 0.9  (learn_beta=True)
        * threshold       = 1.0  (learn_threshold=True)
        * reset_mechanism = "subtract"
        * spike_grad      = surrogate.atan()

    The forward pass returns the per-class output spike count accumulated
    over all time steps.
    """

    def __init__(self, input_dim: int, hidden1: int = 30, hidden2: int = 30, n_classes: int = 5):
        super().__init__()

        # Settings shared by all three LIF layers; each layer still gets its
        # own surrogate-gradient object.
        lif_settings = dict(
            beta=0.9,
            threshold=1.0,
            learn_beta=True,
            learn_threshold=True,
            reset_mechanism="subtract",
        )

        # Hidden layer 1
        self.fc1 = nn.Linear(input_dim, hidden1, bias=True)
        self.lif1 = snn.Leaky(spike_grad=surrogate.atan(), **lif_settings)

        # Hidden layer 2
        self.fc2 = nn.Linear(hidden1, hidden2, bias=True)
        self.lif2 = snn.Leaky(spike_grad=surrogate.atan(), **lif_settings)

        # Output layer
        self.fc3 = nn.Linear(hidden2, n_classes, bias=True)
        self.lif3 = snn.Leaky(spike_grad=surrogate.atan(), **lif_settings)

    def forward(self, spikes: torch.Tensor):
        """
        spikes: [B, T, C]  (batch, time, channel)
        return:
          spike counts per output neuron: [B, n_classes]
        """
        batch_size, _, _ = spikes.shape
        dev = spikes.device

        def zero_state(width: int) -> torch.Tensor:
            # Fresh all-zero [B, width] tensor on the input's device
            return torch.zeros(batch_size, width, device=dev)

        # Membrane potentials start from zero for every new sequence
        mem1 = zero_state(self.fc1.out_features)
        mem2 = zero_state(self.fc2.out_features)
        mem3 = zero_state(self.fc3.out_features)

        # Running total of output spikes, used for classification
        out_count = zero_state(self.fc3.out_features)

        # Go time-major ([T, B, C]) and feed one [B, C] frame per step
        for frame in spikes.permute(1, 0, 2).unbind(0):
            spk1, mem1 = self.lif1(self.fc1(frame), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk3, mem3 = self.lif3(self.fc3(spk2), mem3)
            out_count += spk3

        # The spike totals double as "logits" for CrossEntropy
        return out_count


# ============================================================
# 5. Training & Evaluation helpers
# ============================================================

def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
):
    """
    Run one optimisation pass over ``loader``.

    The model output is used directly as logits for CrossEntropyLoss, and the
    predicted class is the argmax of that output. Loss is weighted by batch
    size so the reported mean is per sample, not per batch.

    Args:
      model:     network called as ``model(spikes)``.
      loader:    iterable of (spikes, labels) tensor batches; labels are
                 expected to be class indices.
      optimizer: optimizer stepped once per batch.
      device:    device the batches are moved to.

    Returns:
      (avg_loss, avg_acc) over all samples seen in this pass.

    Raises:
      ValueError: if the loader yields no samples, since the averages would
                  be undefined.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for spikes, labels in loader:
        spikes = spikes.to(device)
        labels = labels.to(device)
        batch_n = labels.size(0)

        optimizer.zero_grad()
        logits = model(spikes)  # output spike counts act as logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_n

    if n_seen == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        raise ValueError(
            "train_one_epoch: loader produced no samples; "
            "cannot compute average loss/accuracy"
        )

    return loss_sum / n_seen, n_correct / n_seen


def eval_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
):
    """
    Evaluate ``model`` on ``loader`` without updating any weights.

    The model is put in eval mode and run under ``torch.no_grad()``. Its
    output is used directly as logits for CrossEntropyLoss, and the predicted
    class is the argmax of that output.

    Args:
      model:  network called as ``model(spikes)``.
      loader: iterable of (spikes, labels) tensor batches; labels are
              expected to be class indices.
      device: device the batches are moved to.

    Returns:
      (avg_loss, avg_acc) over all samples, with loss weighted by batch size.

    Raises:
      ValueError: if the loader yields no samples (for example, a split that
                  ended up empty), since the averages would be undefined.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for spikes, labels in loader:
            spikes = spikes.to(device)
            labels = labels.to(device)
            batch_n = labels.size(0)

            logits = model(spikes)  # output spike counts act as logits
            loss = criterion(logits, labels)

            loss_sum += loss.item() * batch_n
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_n

    if n_seen == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        raise ValueError(
            "eval_one_epoch: loader produced no samples; "
            "cannot compute average loss/accuracy"
        )

    return loss_sum / n_seen, n_correct / n_seen


def get_all_predictions(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
):
    """
    Collect true labels and argmax predictions for every sample in ``loader``.

    The model is put in eval mode and run under ``torch.no_grad()``.

    Args:
      model:  network called as ``model(spikes)``; the argmax over dim 1 of
              its output is taken as the predicted class.
      loader: iterable of (spikes, labels) tensor batches.
      device: device the input batches are moved to.

    Returns:
      (all_labels, all_preds) as 1-D NumPy arrays in loader order. If the
      loader yields no batches, two empty int64 arrays are returned.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for spikes, labels in loader:
            preds = model(spikes.to(device)).argmax(dim=1)
            pred_chunks.append(preds.cpu().numpy())
            # Labels are only collected, never fed to the model, so they are
            # not sent to the device and back.
            label_chunks.append(labels.cpu().numpy())

    if not pred_chunks:
        # np.concatenate refuses an empty list; hand back empty results
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    all_labels = np.concatenate(label_chunks, axis=0)
    all_preds = np.concatenate(pred_chunks, axis=0)
    return all_labels, all_preds


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    filename: str,
    title: str,
):
    """
    Render the 5-class confusion matrix as an annotated heat map and save it.

    Rows are true labels and columns are predicted labels, both over class
    indices 0..4, with tick names taken from IDX_TO_AAMI_LABEL. The figure is
    written to ``filename`` and then closed.
    """
    class_ids = list(range(5))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(6, 5))
    heat = plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title(title)
    plt.colorbar(heat, fraction=0.046, pad=0.04)

    positions = np.arange(5)
    names = [IDX_TO_AAMI_LABEL[k] for k in class_ids]
    plt.xticks(positions, names)
    plt.yticks(positions, names)

    # Write each count into its cell; switch to white text on dark cells
    peak = cm.max()
    cutoff = peak / 2.0 if peak > 0 else 0.5
    for row, col in np.ndindex(cm.shape):
        count = cm[row, col]
        if count > cutoff:
            text_color = "white"
        else:
            text_color = "black"
        plt.text(
            col,
            row,
            str(count),
            horizontalalignment="center",
            color=text_color,
        )

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches="tight")
    print(f"Confusion matrix saved as '{filename}'")
    plt.close()


# ============================================================
# 6. Main: Training loop + plot + val/test confusion matrix
# ============================================================

def main():
    """
    Train, validate and test the ECG spiking network end to end.

    Steps:
      1. Load the spike-encoded dataset and split it into train / val / test.
      2. Train with Adam, with early stopping on validation loss.
      3. Restore the best-scoring weights and save them to
         'snn_ecg_snntorch_30_30_best.pt'.
      4. Save the loss/accuracy curves and the validation and test confusion
         matrices as PNG files, and print classification reports.
    """
    # Local import: only this function needs it
    import copy

    # ---------------- Basic configuration ----------------
    dataset_dir = "dataset_ecg_encoded"   # folder holding the spike-encoded beats
    batch_size = 32
    max_epochs = 100
    lr = 1e-3
    val_ratio = 0.15
    test_ratio = 0.15
    patience = 10               # early stopping patience (epochs)
    num_workers = 0

    # Class indices and names passed to the classification reports
    class_indices = list(range(5))
    class_names = [IDX_TO_AAMI_LABEL[i] for i in class_indices]

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device = {device}")

    # 1) Build the dataset from the encoded files
    full_dataset = SpikeECGDataset(dataset_dir, classes=["N", "S", "V", "F", "Q"])

    # Stratified 3-way split
    train_set, val_set, test_set = split_dataset_3way(
        full_dataset,
        val_ratio=val_ratio,
        test_ratio=test_ratio,
        seed=42,
    )

    print(f"\nTotal samples: {len(full_dataset)}")
    print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}\n")

    # Inspect one sample to determine the input dimension
    spikes0, _ = full_dataset[0]   # [T, C]
    T0, C0 = spikes0.shape
    input_dim = C0
    print(f"Sample shape: T={T0}, C={C0} -> input_dim={input_dim}")

    # DataLoaders (only the training loader is shuffled)
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    # 2) Build the SNN (snnTorch)
    model = SNN_Snntorch_ECG(input_dim=input_dim, hidden1=30, hidden2=30, n_classes=5)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # History for plotting
    train_loss_hist: List[float] = []
    val_loss_hist: List[float] = []
    train_acc_hist: List[float] = []
    val_acc_hist: List[float] = []

    best_val_loss = float("inf")
    best_state = None
    epochs_no_improve = 0

    # 3) Training loop with early stopping (monitors val_loss)
    for epoch in range(1, max_epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = eval_one_epoch(model, val_loader, device)

        train_loss_hist.append(train_loss)
        val_loss_hist.append(val_loss)
        train_acc_hist.append(train_acc)
        val_acc_hist.append(val_acc)

        print(
            f"Epoch {epoch:03d}: "
            f"train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
            f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f}"
        )

        # An improvement must beat the best loss by more than 1e-4
        if val_loss < best_val_loss - 1e-4:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Deep-copy so the
            # snapshot really holds this epoch's weights.
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
                break

    # 4) Restore best model & save
    if best_state is not None:
        model.load_state_dict(best_state)
        torch.save(best_state, "snn_ecg_snntorch_30_30_best.pt")
        print("Best model saved to 'snn_ecg_snntorch_30_30_best.pt'")

    # --------------------------------------------------------
    # 5) Plot training curves (loss & accuracy)
    # --------------------------------------------------------
    epochs_range = range(1, len(train_loss_hist) + 1)

    plt.figure(figsize=(10, 8))

    # Loss curve
    plt.subplot(2, 1, 1)
    plt.plot(epochs_range, train_loss_hist, label="Train loss")
    plt.plot(epochs_range, val_loss_hist, label="Val loss")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss (snnTorch)")
    plt.grid(True, alpha=0.3)
    plt.legend()

    # Accuracy curve
    plt.subplot(2, 1, 2)
    plt.plot(epochs_range, train_acc_hist, label="Train acc")
    plt.plot(epochs_range, val_acc_hist, label="Val acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training & Validation Accuracy (snnTorch)")
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.tight_layout()
    plt.savefig("training_curves_snntorch.png", dpi=150, bbox_inches="tight")
    print("Training curves saved as 'training_curves_snntorch.png'")
    plt.close()

    # --------------------------------------------------------
    # 6) Confusion matrix (validation set)
    # --------------------------------------------------------
    y_val_true, y_val_pred = get_all_predictions(model, val_loader, device)

    print("\nClassification report (VALIDATION set):")
    print(
        classification_report(
            y_val_true,
            y_val_pred,
            # Explicit labels keep the report aligned with target_names even
            # when a class is absent from this split.
            labels=class_indices,
            target_names=class_names,
            digits=3,
        )
    )

    plot_confusion_matrix(
        y_val_true,
        y_val_pred,
        filename="confusion_matrix_val_snntorch.png",
        title="Confusion Matrix (Validation, snnTorch)",
    )

    # --------------------------------------------------------
    # 7) Evaluate on TEST set
    # --------------------------------------------------------
    test_loss, test_acc = eval_one_epoch(model, test_loader, device)
    print(f"\n[TEST] loss={test_loss:.4f}, acc={test_acc:.3f}")

    y_test_true, y_test_pred = get_all_predictions(model, test_loader, device)

    print("\nClassification report (TEST set):")
    print(
        classification_report(
            y_test_true,
            y_test_pred,
            labels=class_indices,
            target_names=class_names,
            digits=3,
        )
    )

    plot_confusion_matrix(
        y_test_true,
        y_test_pred,
        filename="confusion_matrix_test_snntorch.png",
        title="Confusion Matrix (Test, snnTorch)",
    )


# Run the full training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()


Using device = cuda
Class N: total=300, train=210, val=45, test=45
Class S: total=300, train=210, val=45, test=45
Class V: total=300, train=210, val=45, test=45
Class F: total=300, train=210, val=45, test=45
Class Q: total=300, train=210, val=45, test=45

Total samples: 1500
Train: 1050, Val: 225, Test: 225

Sample shape: T=234, C=5 -> input_dim=5
Epoch 001: train_loss=3.0226, train_acc=0.294, val_loss=1.4625, val_acc=0.360
Epoch 002: train_loss=1.4716, train_acc=0.352, val_loss=1.4428, val_acc=0.369
Epoch 003: train_loss=1.4418, train_acc=0.386, val_loss=1.4484, val_acc=0.396
Epoch 004: train_loss=1.3305, train_acc=0.451, val_loss=1.0939, val_acc=0.547
Epoch 005: train_loss=1.0300, train_acc=0.573, val_loss=0.8814, val_acc=0.587
Epoch 006: train_loss=0.6530, train_acc=0.700, val_loss=0.6014, val_acc=0.756
Epoch 007: train_loss=0.4913, train_acc=0.796, val_loss=0.5680, val_acc=0.747
Epoch 008: train_loss=0.3809, train_acc=0.830, val_loss=0.3225, val_acc=0.938
Epoch 009: train_loss=0.21