In [None]:
"""
Training SNN untuk ECG dengan snnTorch (3-way split: Train / Val / Test)
------------------------------------------------------------------------
Encoder terbaru: sampler_centered_var

Dataset structure:
    spikes_sampler_centered_var/
        N/*_sampler_centered_var.npz
        S/*_sampler_centered_var.npz
        V/*_sampler_centered_var.npz
        F/*_sampler_centered_var.npz
        Q/*_sampler_centered_var.npz

Setiap file .npz minimal punya:
    - spikes         : [N_anchor, SLOTS] (uint8 0/1)
    - (optional) label, fs, t_rel, anchor_times, anchor_indices, ...

Input ke SNN:
    spikes_np shape [N_anchor, SLOTS]
    -> spikes_tensor [T, C] = [SLOTS, N_anchor]

Model:
    Input C -> FC(30)+LIF -> FC(30)+LIF -> FC(5)+LIF
"""

import os
import glob
import random
from typing import List, Tuple, Dict, Optional

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import snntorch as snn
from snntorch import surrogate

from sklearn.metrics import confusion_matrix, classification_report


# ============================================================
# 1. Label mapping AAMI
# ============================================================

# Class label -> integer index used as the training target.
AAMI_LABEL_TO_IDX: Dict[str, int] = {
    label: index for index, label in enumerate(("N", "S", "V", "F", "Q"))
}
# Reverse lookup: integer index -> class label.
IDX_TO_AAMI_LABEL: Dict[int, str] = {
    index: label for label, index in AAMI_LABEL_TO_IDX.items()
}


# ============================================================
# 2. Dataset: membaca sampler spike encoding dari .npz
# ============================================================

class SamplerSpikeECGDataset(Dataset):
    """
    Dataset that reads "sampler_centered_var" spike encodings from .npz files.

    Expected folder layout:
      root_dir/
          N/*.npz
          S/*.npz
          V/*.npz
          F/*.npz
          Q/*.npz

    Each .npz file must contain at least:
      - spikes: [N_anchor, SLOTS]  (documented as uint8 0/1)

    __getitem__ returns:
      - spikes: float32 Tensor [T, C] = [SLOTS, N_anchor], binarised to 0/1
      - label : int class index looked up in AAMI_LABEL_TO_IDX

    Args:
      root_dir: folder containing one sub-folder per class.
      classes : class folder names to scan (upper-cased before use);
                defaults to N, S, V, F, Q. Missing folders are skipped.

    Raises:
      RuntimeError: if no .npz file is found under any class folder.
    """

    def __init__(self, root_dir: str, classes: Optional[List[str]] = None):
        super().__init__()
        self.root_dir = root_dir
        if classes is None:
            classes = ["N", "S", "V", "F", "Q"]
        self.classes = [c.upper() for c in classes]

        # (file path, class name) pairs; files are sorted inside each class so
        # the sample order is deterministic across runs.
        self.samples: List[Tuple[str, str]] = []
        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            files = sorted(glob.glob(os.path.join(cls_dir, "*.npz")))
            self.samples.extend((fp, cls) for fp in files)

        if not self.samples:
            raise RuntimeError(f"Tidak ada file .npz di {root_dir}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Load sample `idx` and return (spikes [T, C] float32 tensor, label int).

        Raises:
          KeyError  : the file has no 'spikes' array.
          ValueError: the 'spikes' array is not 2-D.
        """
        fp, cls = self.samples[idx]

        # The context manager closes the archive on every exit path.
        # allow_pickle=False: only the numeric 'spikes' array is read, so
        # pickle deserialization is never needed here.
        with np.load(fp, allow_pickle=False) as d:
            if "spikes" not in d.files:
                raise KeyError(f"{fp} tidak memiliki key 'spikes'.")
            spikes_np = d["spikes"]  # [N_anchor, SLOTS]

        if spikes_np.ndim != 2:
            # Fail here with the offending file name instead of later in
            # batching or in the model's forward pass.
            raise ValueError(
                f"{fp}: 'spikes' must be 2-D [N_anchor, SLOTS], "
                f"got shape {spikes_np.shape}"
            )

        # Force a binary float32 (0/1) representation.
        spikes_np = (spikes_np > 0).astype(np.float32)

        # Stored as [C, T] = [N_anchor, SLOTS]; transpose to [T, C].
        spikes = torch.from_numpy(spikes_np.T)  # [T, C]

        y = AAMI_LABEL_TO_IDX[cls]
        return spikes, y


# ============================================================
# 3. Stratified 3-way split: Train / Val / Test
# ============================================================

def split_dataset_3way(
    dataset: "SamplerSpikeECGDataset",
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
):
    """
    Split `dataset` into train / val / test subsets, stratified per class.

    For every class the sample indices are shuffled with a `random.Random`
    seeded by `seed`; the first `int(n * val_ratio)` go to validation, the
    next `int(n * test_ratio)` to test and the remainder to train. A one-line
    summary is printed for each non-empty class.

    Args:
        dataset   : object exposing `samples`, a list of (path, class) pairs.
        val_ratio : fraction of each class assigned to validation.
        test_ratio: fraction of each class assigned to test.
        seed      : seed for the per-class shuffles.

    Returns:
        (train_subset, val_subset, test_subset) as `torch.utils.data.Subset`.

    Raises:
        ValueError: a ratio is negative or the two ratios sum to more than 1.
        KeyError  : a sample's class is not a key of AAMI_LABEL_TO_IDX.
    """
    # Negative ratios would turn the slices below into "everything but the
    # last k" and make the splits overlap; a sum above 1 silently starves the
    # test and train splits. Reject both up front.
    if val_ratio < 0 or test_ratio < 0:
        raise ValueError(
            f"val_ratio and test_ratio must be >= 0, got {val_ratio} and {test_ratio}"
        )
    if val_ratio + test_ratio > 1:
        raise ValueError(
            f"val_ratio + test_ratio must be <= 1, got {val_ratio + test_ratio}"
        )

    rng = random.Random(seed)

    # Buckets are created in AAMI_LABEL_TO_IDX order so the shuffle sequence
    # (and therefore the split) depends only on `seed` and the sample list.
    class_to_indices: Dict[str, List[int]] = {c: [] for c in AAMI_LABEL_TO_IDX}
    for i, (_, cls) in enumerate(dataset.samples):
        class_to_indices[cls].append(i)

    train_idx: List[int] = []
    val_idx: List[int] = []
    test_idx: List[int] = []

    for cls, indices in class_to_indices.items():
        if not indices:
            continue
        rng.shuffle(indices)
        n = len(indices)
        n_val = int(n * val_ratio)
        n_test = int(n * test_ratio)

        cls_val = indices[:n_val]
        cls_test = indices[n_val:n_val + n_test]
        cls_train = indices[n_val + n_test:]

        train_idx.extend(cls_train)
        val_idx.extend(cls_val)
        test_idx.extend(cls_test)

        print(f"Class {cls}: total={n}, train={len(cls_train)}, val={len(cls_val)}, test={len(cls_test)}")

    return (
        torch.utils.data.Subset(dataset, train_idx),
        torch.utils.data.Subset(dataset, val_idx),
        torch.utils.data.Subset(dataset, test_idx),
    )


# ============================================================
# 4. SNN model (snnTorch)
# ============================================================

class SNN_Snntorch_ECG(nn.Module):
    """
    Fully connected spiking network built from three Linear + snn.Leaky stages:

        Linear(input_dim, hidden1) -> Leaky
        Linear(hidden1, hidden2)   -> Leaky
        Linear(hidden2, n_classes) -> Leaky

    The forward pass returns, per sample, the output-layer spikes summed over
    all time steps.
    """

    def __init__(self, input_dim: int, hidden1: int = 30, hidden2: int = 30, n_classes: int = 5):
        super().__init__()

        # Settings shared by the three spiking layers; each layer still gets
        # its own surrogate-gradient object.
        lif_kwargs = dict(
            beta=0.9,
            threshold=1.0,
            learn_beta=True,
            learn_threshold=True,
            reset_mechanism="subtract",
        )

        self.fc1 = nn.Linear(input_dim, hidden1, bias=True)
        self.lif1 = snn.Leaky(spike_grad=surrogate.atan(), **lif_kwargs)

        self.fc2 = nn.Linear(hidden1, hidden2, bias=True)
        self.lif2 = snn.Leaky(spike_grad=surrogate.atan(), **lif_kwargs)

        self.fc3 = nn.Linear(hidden2, n_classes, bias=True)
        self.lif3 = snn.Leaky(spike_grad=surrogate.atan(), **lif_kwargs)

    def forward(self, spikes: torch.Tensor):
        """
        spikes: [B, T, C]
        return: [B, n_classes] output spike counts accumulated over T
        """
        batch_size, _n_steps, _n_channels = spikes.shape
        dev = spikes.device

        stages = (
            (self.fc1, self.lif1),
            (self.fc2, self.lif2),
            (self.fc3, self.lif3),
        )

        # One membrane-potential tensor per stage, all starting at zero.
        membranes = [
            torch.zeros(batch_size, fc.out_features, device=dev)
            for fc, _ in stages
        ]
        spike_count = torch.zeros(batch_size, self.fc3.out_features, device=dev)

        # Iterating the time-major view yields one [B, C] frame per step.
        for frame in spikes.permute(1, 0, 2):
            signal = frame
            for k, (fc, lif) in enumerate(stages):
                signal, membranes[k] = lif(fc(signal), membranes[k])
            spike_count += signal

        return spike_count


# ============================================================
# 5. Train / Eval helpers
# ============================================================

def train_one_epoch(model, loader, optimizer, device):
    """
    Run one optimisation pass over `loader`.

    Puts the model in train mode, and for every batch moves it to `device`,
    computes the cross-entropy loss on the model output, back-propagates and
    steps the optimizer.

    Returns:
        (mean loss per sample, accuracy) over all samples seen.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)  # [B, T, C]
        batch_y = batch_y.to(device)
        n_batch = batch_y.size(0)

        optimizer.zero_grad()
        outputs = model(batch_x)
        batch_loss = loss_fn(outputs, batch_y)
        batch_loss.backward()
        optimizer.step()

        # CrossEntropyLoss returns a batch mean; weight it by the batch size
        # so the final division gives a per-sample mean.
        loss_sum += batch_loss.item() * n_batch
        n_correct += (outputs.argmax(dim=1) == batch_y).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def eval_one_epoch(model, loader, device):
    """
    Evaluate the model over `loader` without updating it.

    Puts the model in eval mode and runs with gradients disabled.

    Returns:
        (mean cross-entropy loss per sample, accuracy) over all samples seen.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        n_batch = batch_y.size(0)

        outputs = model(batch_x)
        batch_loss = loss_fn(outputs, batch_y)

        # Weight the batch-mean loss by batch size for a per-sample mean.
        loss_sum += batch_loss.item() * n_batch
        n_correct += (outputs.argmax(dim=1) == batch_y).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def get_all_predictions(model, loader, device):
    """
    Collect ground-truth labels and argmax predictions over `loader`.

    Puts the model in eval mode and runs with gradients disabled.

    Returns:
        (y_true, y_pred) as 1-D NumPy arrays concatenated across batches.
    """
    model.eval()
    true_chunks, pred_chunks = [], []
    for batch_x, batch_y in loader:
        scores = model(batch_x.to(device))
        pred_chunks.append(scores.argmax(dim=1).cpu().numpy())
        # Labels are converted as delivered by the loader (not moved to
        # `device`).
        true_chunks.append(batch_y.numpy())
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def plot_confusion_matrix(y_true, y_pred, filename, title):
    """
    Draw the 5-class confusion matrix of `y_true` vs `y_pred` and save it.

    Rows are true labels, columns are predicted labels, and each cell is
    annotated with its count. The figure is written to `filename` and closed;
    a confirmation line is printed.
    """
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3, 4])

    plt.figure(figsize=(6, 5))
    heatmap = plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title(title)
    plt.colorbar(heatmap, fraction=0.046, pad=0.04)

    names = [IDX_TO_AAMI_LABEL[k] for k in range(5)]
    positions = np.arange(5)
    plt.xticks(positions, names)
    plt.yticks(positions, names)

    # Cells above half the largest count get white text so the number stays
    # readable on the dark end of the colormap.
    peak = cm.max()
    cutoff = peak / 2.0 if peak > 0 else 0.5
    for row, col in np.ndindex(*cm.shape):
        count = cm[row, col]
        text_color = "white" if count > cutoff else "black"
        plt.text(col, row, str(count), ha="center", color=text_color)

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Confusion matrix saved as '{filename}'")


# ============================================================
# 6. Main
# ============================================================

def _save_training_curves(train_loss_hist, val_loss_hist, train_acc_hist, val_acc_hist):
    """Plot loss and accuracy histories in two stacked panels and save the PNG."""
    epochs_range = range(1, len(train_loss_hist) + 1)
    plt.figure(figsize=(10, 8))

    plt.subplot(2, 1, 1)
    plt.plot(epochs_range, train_loss_hist, label="Train loss")
    plt.plot(epochs_range, val_loss_hist, label="Val loss")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss (Sampler Encoder)")
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.subplot(2, 1, 2)
    plt.plot(epochs_range, train_acc_hist, label="Train acc")
    plt.plot(epochs_range, val_acc_hist, label="Val acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training & Validation Accuracy (Sampler Encoder)")
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.tight_layout()
    plt.savefig("training_curves_snntorch_sampler.png", dpi=150, bbox_inches="tight")
    plt.close()
    print("Training curves saved as 'training_curves_snntorch_sampler.png'")


def _report_split(model, loader, device, header, cm_filename, cm_title):
    """Print the classification report for one split and save its confusion matrix."""
    y_true, y_pred = get_all_predictions(model, loader, device)
    print(header)
    print(classification_report(
        y_true, y_pred,
        # Pin the label set so the report still works (and keeps five rows)
        # when a class is absent from this split's labels and predictions;
        # without it the five target_names would not match and sklearn raises.
        labels=list(range(5)),
        target_names=[IDX_TO_AAMI_LABEL[i] for i in range(5)],
        digits=3,
    ))
    plot_confusion_matrix(y_true, y_pred, filename=cm_filename, title=cm_title)


def main():
    """
    Train, validate and test the SNN on the sampler spike dataset.

    Steps: build the dataset and a stratified train/val/test split, train with
    Adam and early stopping on validation loss, restore and save the best
    weights, save the training curves, then print classification reports and
    save confusion matrices for the validation and test splits.
    """
    dataset_dir = "spikes_sampler_centered_var"  # <-- CHANGE to your latest encoder output folder
    batch_size = 32
    max_epochs = 100
    lr = 1e-3
    val_ratio = 0.15
    test_ratio = 0.15
    patience = 10
    num_workers = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device = {device}")

    full_dataset = SamplerSpikeECGDataset(dataset_dir, classes=["N", "S", "V", "F", "Q"])
    train_set, val_set, test_set = split_dataset_3way(full_dataset, val_ratio=val_ratio, test_ratio=test_ratio, seed=42)

    print(f"\nTotal samples: {len(full_dataset)}")
    print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}\n")

    # Infer the input width from the first sample, shaped [T, C].
    spikes0, _ = full_dataset[0]
    T0, C0 = spikes0.shape
    input_dim = C0
    print(f"Sample shape: T={T0}, C={C0} -> input_dim={input_dim}")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)
    val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)
    test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

    model = SNN_Snntorch_ECG(input_dim=input_dim, hidden1=30, hidden2=30, n_classes=5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loss_hist, val_loss_hist = [], []
    train_acc_hist,  val_acc_hist  = [], []

    best_val_loss = float("inf")
    best_state = None
    no_improve = 0

    for epoch in range(1, max_epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, device)
        va_loss, va_acc = eval_one_epoch(model, val_loader, device)

        train_loss_hist.append(tr_loss)
        val_loss_hist.append(va_loss)
        train_acc_hist.append(tr_acc)
        val_acc_hist.append(va_acc)

        print(f"Epoch {epoch:03d}: train_loss={tr_loss:.4f}, train_acc={tr_acc:.3f}, val_loss={va_loss:.4f}, val_acc={va_acc:.3f}")

        # Count an epoch as an improvement only if val loss drops by > 1e-4.
        if va_loss < best_val_loss - 1e-4:
            best_val_loss = va_loss
            # state_dict() returns tensors that share storage with the live
            # parameters, so they must be cloned; otherwise later optimizer
            # steps overwrite the snapshot and the "best" weights end up
            # being the last epoch's.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
        torch.save(best_state, "snn_ecg_snntorch_sampler_30_30_best.pt")
        print("Best model saved to 'snn_ecg_snntorch_sampler_30_30_best.pt'")

    # curves
    _save_training_curves(train_loss_hist, val_loss_hist, train_acc_hist, val_acc_hist)

    # validation report + CM
    _report_split(
        model, val_loader, device,
        header="\nClassification report (VALIDATION set):",
        cm_filename="confusion_matrix_val_snntorch_sampler.png",
        cm_title="Confusion Matrix (Validation, Sampler)",
    )

    # test report + CM
    te_loss, te_acc = eval_one_epoch(model, test_loader, device)
    print(f"\n[TEST] loss={te_loss:.4f}, acc={te_acc:.3f}")

    _report_split(
        model, test_loader, device,
        header="\nClassification report (TEST set):",
        cm_filename="confusion_matrix_test_snntorch_sampler.png",
        cm_title="Confusion Matrix (Test, Sampler)",
    )

# Run the training pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


Using device = cuda
Class N: total=300, train=210, val=45, test=45
Class S: total=300, train=210, val=45, test=45
Class V: total=300, train=210, val=45, test=45
Class F: total=300, train=210, val=45, test=45
Class Q: total=300, train=210, val=45, test=45

Total samples: 1500
Train: 1050, Val: 225, Test: 225

Sample shape: T=64, C=30 -> input_dim=30
Epoch 001: train_loss=1.9027, train_acc=0.160, val_loss=1.6094, val_acc=0.200
Epoch 002: train_loss=1.5357, train_acc=0.200, val_loss=1.4644, val_acc=0.200
Epoch 003: train_loss=1.1177, train_acc=0.368, val_loss=1.0256, val_acc=0.400
Epoch 004: train_loss=0.7591, train_acc=0.608, val_loss=0.2058, val_acc=0.898
Epoch 005: train_loss=0.0816, train_acc=0.973, val_loss=0.0352, val_acc=0.987
Epoch 006: train_loss=0.0288, train_acc=0.993, val_loss=0.0175, val_acc=0.996
