In [None]:
from __future__ import annotations

# ============================================================
# Imports & Constants
# ============================================================

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Dict

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from grid_dataset import GridSeqDataset
from collate_grid import collate_grid

try:
    import snntorch as snn
    from snntorch import surrogate
except ImportError as e:
    raise ImportError(
        "    pip install snntorch\n"
    ) from e

# === (CHANGE IT) sim_name ===
PROJECT_ROOT = Path("project_root/scripts")
SIM_NAME = "5_24"

_ALL_MODES = ["logic", "rulemap", "random", "intermediate_case1", "intermediate_case2"]


# ============================================================
# Config
# ============================================================

@dataclass
class TrainConfigSNN:
    """
    SNN training configuration
    data_root: project root directory (contains data/)
    sim_name:  corresponds to data/<sim_name>/ directory
    mode:      'all' or a single mode (e.g., 'random')
    """
    data_root: Path = PROJECT_ROOT
    sim_name: str = SIM_NAME
    mode: str = "all"    # or a single mode (e.g., 'random')
    # model
    hidden_dim: int = 256
    e_char_dim: int = 64
    d_model: int = 256
    dropout: float = 0.1

    # train
    batch_size: int = 32
    lr: float = 3e-4
    weight_decay: float = 1e-2
    max_epochs: int = 20
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    save_dir: Path = Path("../checkpoints")

    # transfer
    # e.g., pretrain_mode='random' means pretrain on random mode, then fine-tune on other modes
    pretrain_mode: Optional[str] = None
    freeze_encoder_on_ft: bool = True    # Freeze front-end conv+LIF during fine-tuning, only train head

    # checkpoint save location
    save_dir: Path = PROJECT_ROOT / "checkpoints_snn"



In [None]:
# ============================================================
# Spiking Version:  Character Net + Prediction Head
# ============================================================

class SpikingCharNet(nn.Module):
    """
    input:
      grid_seq: [B,T,H,W,C]
      tmask:    [B,T]  True=valid time step

    per-frame:
      Conv2d( stride=2 ) x2:
        H0×W0 → H1×W1 → H2×W2
        Here, use a formula to automatically infer H2=W2 from input_hw, supporting different sizes like 12x12, 24x24, etc.

    temporal:
      One layer of Leaky-Integrate-and-Fire (LIF) accumulates spikes
      Use time-avg spike rate as representation -> map to e_char_dim
    output:
      e_char: [B, e_char_dim]
    """
    def __init__(self, in_ch: int, hidden_dim: int = 256,
                 e_char_dim: int = 64, input_hw: int = 24):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Spatial encoding: two stride=2 conv layers
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)

        # Infer conv2 output H2=W2 from input_hw
        def _out_size(size, k=3, s=2, p=1):
            return (size + 2 * p - k) // s + 1

        h1 = _out_size(input_hw)   # after first conv
        h2 = _out_size(h1)         # after second conv
        self.h2 = h2
        self.flatten_dim = 64 * h2 * h2

        # flatten → hidden_dim
        self.fc = nn.Linear(self.flatten_dim, hidden_dim)

        # Temporal SNN: one layer of LIF hidden
        self.lif_hidden = snn.Leaky(beta=0.9, spike_grad=surrogate.atan())

        # time-avg spike-rate → e_char
        self.to_e = nn.Linear(hidden_dim, e_char_dim)
        self.ln = nn.LayerNorm(e_char_dim)

    def forward(self, grid_seq: torch.Tensor, tmask: torch.Tensor) -> torch.Tensor:
        """
        grid_seq: [B,T,H,W,C]
        tmask:    [B,T] (bool) True=valid time step
        return:   e_char [B, e_char_dim]
        """
        B, T, H, W, C = grid_seq.shape

        # Convert to conv2d format [B,T,C,H,W]
        x = grid_seq.permute(0, 1, 4, 2, 3).contiguous()

        # Initialize LIF state and spike accumulation
        mem_h = torch.zeros(B, self.fc.out_features, device=grid_seq.device)
        spike_sum = torch.zeros(B, self.fc.out_features, device=grid_seq.device)

        # Valid length for time-average
        valid_len = (
            tmask.sum(dim=1)            # [B]
            .clamp_min(1)
            .to(grid_seq.device)
            .unsqueeze(1)               # [B,1]
            .float()
        )

        for t in range(T):
            # Take the frame at time step t [B,C,H,W]
            xt = x[:, t]
            m = tmask[:, t].float().view(B, 1, 1, 1)  # [B,1,1,1]
            # Zero out PAD steps (tmask=False) to avoid contaminating membrane potential
            xt = xt * m

            # Spatial conv
            h = F.relu(self.conv1(xt))
            h = F.relu(self.conv2(h))
            h = h.view(B, -1)          # [B, flatten_dim]
            # Here flatten_dim equals self.flatten_dim
            h = F.relu(self.fc(h))     # [B, hidden_dim]

            # Temporal LIF (hidden spikes)
            spk_h, mem_h = self.lif_hidden(h, mem_h)  # [B, hidden_dim], [B, hidden_dim]

            # Accumulate spikes only at valid steps
            spike_sum = spike_sum + spk_h * m.view(B, 1)

        # time-average spike rate
        rate = spike_sum / valid_len                 # [B, hidden_dim]
        e_char = self.to_e(rate)                     # [B, e_char_dim]
        e_char = self.ln(e_char)
        return e_char

class PredHeadSNN(nn.Module):
    """
    Prediction head：
      e_char + choice embeddings → MLP → 4 logits

    - choices_ids: [B,4,3] (each choice consists of 3 cell ids)
      First do nn.Embedding, then average over the 3 positions
    - e_char is linearly projected to d_model, followed by LayerNorm
    - For each choice:
        feat = concat(pf, ch, |pf-ch|, pf*ch)  → MLP → scalar logit
    """
    def __init__(
        self,
        e_char_dim: int,
        d_model: int = 256,
        cell_vocab: int = 24 * 24,
        choice_emb_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.proj_e = nn.Linear(e_char_dim, d_model)
        self.choice_emb = nn.Embedding(cell_vocab, choice_emb_dim)
        self.proj_choice = nn.Linear(choice_emb_dim, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(4 * d_model, 2 * d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * d_model, 1),
        )

    def _embed_choice_triplet(self, ids: torch.Tensor) -> torch.Tensor:
        """
        ids: [B,4,3]
        return: [B,4,Dc] averaged choice embeddings
        """
        emb = self.choice_emb(ids)  # [B,4,3,Dc]
        return emb.mean(dim=2)      # [B,4,Dc]

    def forward(self, e_char: torch.Tensor, choices_ids: torch.Tensor) -> torch.Tensor:
        """
        e_char:      [B, e_char_dim]
        choices_ids: [B, 4, 3]
        return:      logits [B, 4]
        """
        B = e_char.size(0)

        # prefix fused
        e = self.proj_e(e_char)   # [B,D]
        pf = self.ln(e)           # [B,D]

        # choices
        ch = self._embed_choice_triplet(choices_ids)  # [B,4,Dc]
        ch = self.proj_choice(ch)                     # [B,4,D]

        pf_expand = pf.unsqueeze(1).expand_as(ch)     # [B,4,D]
        feat = torch.cat(
            [pf_expand, ch, torch.abs(pf_expand - ch), pf_expand * ch],
            dim=-1,
        )  # [B,4,4D]

        logits = self.mlp(feat).squeeze(-1)          # [B,4]
        return logits

class ToMNetSNN(nn.Module):
    """
    Overall SNN version of ToMNet:
      - vision: SpikingCharNet → e_char
      - head:   PredHeadSNN    → 4 logits
    """
    def __init__(
        self,
        in_ch: int,
        hidden_dim: int = 256,
        e_char_dim: int = 64,
        d_model: int = 256,
        dropout: float = 0.1,
        cell_vocab: int = 24 * 24,
        input_hw: int = 24, 
    ):
        super().__init__()
        self.vision = SpikingCharNet(
            in_ch=in_ch,
            hidden_dim=hidden_dim,
            e_char_dim=e_char_dim,
            input_hw=input_hw,
        )
        self.head = PredHeadSNN(
            e_char_dim=e_char_dim,
            d_model=d_model,
            cell_vocab=cell_vocab,
            dropout=dropout,
        )

    @torch.no_grad()
    def num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        grid_seq: torch.Tensor,
        tmask: torch.Tensor,
        choices_ids: torch.Tensor,
    ) -> torch.Tensor:
        e_char = self.vision(grid_seq, tmask)   # [B,e_char_dim]
        logits = self.head(e_char, choices_ids) # [B,4]
        return logits


In [None]:
# ============================================================
# Training Utilities
# ============================================================

def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def make_loader(cfg: TrainConfigSNN, split: str, mode: str) -> Tuple[DataLoader, int, int]:
    """
    create DataLoader, and also return
      - in_ch(last dimension C)
      - input_hw (here use H, assuming H=W)
    """
    ds = GridSeqDataset(cfg.data_root, cfg.sim_name, split, mode)
    one = ds[0]["grid_seq"]  # [T,H,W,C]
    _, H, W, C = one.shape
    in_ch = int(C)

    ld = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(split == "train"),
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_grid,
    )
    assert H == W, f"Currently assume inner grid is square, but got H={H}, W={W}"
    return ld, in_ch, H


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Tuple[float, float]:
    """
    Evaluate on a given DataLoader:
      return: (accuracy, avg_loss)
    """
    model.eval()
    n, corr, tot_loss = 0, 0, 0.0

    for batch in loader:
        x = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        logits = model(x["grid_seq"], x["tmask"], x["choices_ids"])
        loss = F.cross_entropy(logits, x["labels"])

        pred = logits.argmax(dim=-1)
        corr += (pred == x["labels"]).sum().item()
        tot_loss += loss.item() * logits.size(0)
        n += logits.size(0)

    if n == 0:
        return 0.0, 0.0

    return corr / n, tot_loss / n

def build_model(in_ch: int, cfg: TrainConfigSNN, input_hw: int) -> ToMNetSNN:
    """
    Build a ToMNetSNN based on in_ch and input_hw.
    """
    model = ToMNetSNN(
        in_ch=in_ch,
        hidden_dim=cfg.hidden_dim,
        e_char_dim=cfg.e_char_dim,
        d_model=cfg.d_model,
        dropout=cfg.dropout,
        cell_vocab=24 * 24,   # vocab 開大一點沒關係，只要 id < vocab
        input_hw=input_hw,    # ★ 傳進去
    )
    return model


def freeze_encoder(model: ToMNetSNN, flag: bool = True) -> None:
    """
    freeze/unfreeze vision encoder (SpikingCharNet).
    """
    for p in model.vision.parameters():
        p.requires_grad = not flag

def train_one_mode(
    cfg: TrainConfigSNN,
    mode: str,
    model: Optional[nn.Module] = None,
    *,
    load_ckpt: Optional[Path] = None,
    freeze_encoder_flag: bool = False,
):
    print(f"\n=== [SNN] Training mode = {mode} ===")
    cfg.save_dir.mkdir(parents=True, exist_ok=True)

    # Note: now also return input_hw
    train_ld, in_ch, input_hw = make_loader(cfg, "train", mode)
    val_ld, _, _ = make_loader(cfg, "val", mode)
    test_ld, _, _ = make_loader(cfg, "test", mode)

    if model is None:
        model = build_model(in_ch, cfg, input_hw=input_hw).to(cfg.device)

    if load_ckpt is not None:
        print(f"[SNN] Load checkpoint: {load_ckpt}")
        state = torch.load(load_ckpt, map_location="cpu")
        model.load_state_dict(state)

    if freeze_encoder_flag:
        print("[SNN] Freeze encoder (vision) for fine-tuning.")
        freeze_encoder(model, True)

    opt = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.max_epochs)

    best_va, best_state = 0.0, None

    for ep in range(1, cfg.max_epochs + 1):
        model.train()
        n = corr = 0
        tot_loss = 0.0

        for batch in train_ld:
            x = {
                k: (v.to(cfg.device) if isinstance(v, torch.Tensor) else v)
                for k, v in batch.items()
            }
            logits = model(x["grid_seq"], x["tmask"], x["choices_ids"])
            loss = F.cross_entropy(logits, x["labels"])

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            pred = logits.argmax(dim=-1)
            corr += (pred == x["labels"]).sum().item()
            tot_loss += loss.item() * logits.size(0)
            n += logits.size(0)

        tr_acc = corr / max(n, 1)
        tr_loss = tot_loss / max(n, 1)
        va_acc, va_loss = evaluate(model, val_ld, cfg.device)
        sch.step()

        print(
            f"[{mode}][Ep {ep:02d}] "
            f"train acc={tr_acc:.3f} loss={tr_loss:.3f} | "
            f"val acc={va_acc:.3f} loss={va_loss:.3f}"
        )

        if va_acc > best_va:
            best_va = va_acc
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    ckpt_path = cfg.save_dir / f"snn_{cfg.sim_name}_{mode}.pt"
    torch.save(model.state_dict(), ckpt_path)

    te_acc, te_loss = evaluate(model, test_ld, cfg.device)
    print(
        f"[{mode}][TEST] acc={te_acc:.3f} loss={te_loss:.3f} | "
        f"saved: {ckpt_path}"
    )
    return model, {"val_acc": best_va, "test_acc": te_acc, "ckpt": ckpt_path}



In [None]:

_ALL_MODES = ["random", "rulemap", "logic", "intermediate_case1", "intermediate_case2"]

def _model_name(sim_name: str, mode: str, kind: str = "ToMNet") -> str:
    # kind 可以填 "ToMNet", "ToMNetSNN", "ToMNetConv"
    return f"{kind}_{sim_name}_{mode}"

def _save_model(model: nn.Module, cfg: TrainConfigSNN, mode: str, kind: str = "ToMNet") -> Path:
    cfg.save_dir.mkdir(parents=True, exist_ok=True)
    name = _model_name(cfg.sim_name, mode, kind)
    path = cfg.save_dir / f"{name}.pt"
    torch.save(model.state_dict(), path)
    print(f"[SAVE] {name} -> {path}")
    return path

def run(cfg: TrainConfigSNN, kind: str = "ToMNet") -> Dict[str, nn.Module]:
    """
    kind: "ToMNet" / "ToMNetSNN" / "ToMNetConv", for naming purpose
    Return {model_name: model_instance}
    """
    models: Dict[str, nn.Module] = {}

    if cfg.mode == "all":
        modes = _ALL_MODES
    else:
        modes = [cfg.mode]

    for mode in modes:
        print(f"[TRAIN] sim={cfg.sim_name} mode={mode}")
        model, best_acc = train_one_mode(cfg, mode)

        name = _model_name(cfg.sim_name, mode, kind)
        _save_model(model, cfg, mode, kind=kind)
        models[name] = model
        print(f"[DONE] {name} (best_acc={best_acc})")

    return models


## Train

In [8]:
cfg = TrainConfigSNN(
    data_root=Path(".."),
    sim_name="3_12",
    mode="random",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)

models = run(cfg, kind="ToMNetSNN")
TomNet_Random_Model = models["ToMNetSNN_3_12_random"]

[TRAIN] sim=3_12 mode=random

=== [SNN] Training mode = random ===




[random][Ep 01] train acc=0.307 loss=1.326 | val acc=0.320 loss=1.259
[random][Ep 02] train acc=0.385 loss=1.219 | val acc=0.370 loss=1.214
[random][Ep 03] train acc=0.403 loss=1.169 | val acc=0.380 loss=1.197
[random][Ep 04] train acc=0.438 loss=1.147 | val acc=0.420 loss=1.190
[random][Ep 05] train acc=0.439 loss=1.133 | val acc=0.420 loss=1.186
[random][Ep 06] train acc=0.455 loss=1.116 | val acc=0.390 loss=1.180
[random][Ep 07] train acc=0.459 loss=1.112 | val acc=0.410 loss=1.185
[random][Ep 08] train acc=0.470 loss=1.096 | val acc=0.400 loss=1.187
[random][Ep 09] train acc=0.494 loss=1.082 | val acc=0.410 loss=1.190
[random][Ep 10] train acc=0.515 loss=1.069 | val acc=0.390 loss=1.199
[random][Ep 11] train acc=0.509 loss=1.063 | val acc=0.360 loss=1.203
[random][Ep 12] train acc=0.507 loss=1.049 | val acc=0.380 loss=1.212
[random][Ep 13] train acc=0.529 loss=1.043 | val acc=0.380 loss=1.217
[random][Ep 14] train acc=0.526 loss=1.034 | val acc=0.370 loss=1.223
[random][Ep 15] trai

In [22]:
cfg = TrainConfigSNN(
    data_root=Path(".."),
    sim_name="3_12",
    mode="rulemap",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNetSNN") 
TomNet_Rulemap_Model = models["ToMNetSNN_3_12_rulemap"]

[TRAIN] sim=3_12 mode=rulemap

=== [SNN] Training mode = rulemap ===
[rulemap][Ep 01] train acc=0.639 loss=1.151 | val acc=0.780 loss=0.918
[rulemap][Ep 02] train acc=0.800 loss=0.762 | val acc=0.820 loss=0.609
[rulemap][Ep 03] train acc=0.820 loss=0.501 | val acc=0.840 loss=0.443
[rulemap][Ep 04] train acc=0.850 loss=0.369 | val acc=0.880 loss=0.359
[rulemap][Ep 05] train acc=0.891 loss=0.297 | val acc=0.920 loss=0.306
[rulemap][Ep 06] train acc=0.916 loss=0.239 | val acc=0.940 loss=0.257
[rulemap][Ep 07] train acc=0.935 loss=0.199 | val acc=0.940 loss=0.223
[rulemap][Ep 08] train acc=0.949 loss=0.168 | val acc=0.950 loss=0.188
[rulemap][Ep 09] train acc=0.956 loss=0.146 | val acc=0.950 loss=0.168
[rulemap][Ep 10] train acc=0.956 loss=0.127 | val acc=0.950 loss=0.152
[rulemap][Ep 11] train acc=0.965 loss=0.113 | val acc=0.950 loss=0.142
[rulemap][Ep 12] train acc=0.970 loss=0.097 | val acc=0.950 loss=0.134
[rulemap][Ep 13] train acc=0.973 loss=0.091 | val acc=0.950 loss=0.132
[rulemap

In [23]:
cfg = TrainConfigSNN(
    data_root=Path(".."),
    sim_name="3_12",
    mode="logic",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNetSNN") 
TomNet_Logic_Model = models["ToMNetSNN_3_12_logic"]

[TRAIN] sim=3_12 mode=logic

=== [SNN] Training mode = logic ===
[logic][Ep 01] train acc=0.601 loss=1.192 | val acc=0.770 loss=0.954
[logic][Ep 02] train acc=0.756 loss=0.806 | val acc=0.780 loss=0.631
[logic][Ep 03] train acc=0.806 loss=0.541 | val acc=0.790 loss=0.447
[logic][Ep 04] train acc=0.835 loss=0.402 | val acc=0.830 loss=0.369
[logic][Ep 05] train acc=0.864 loss=0.333 | val acc=0.840 loss=0.327
[logic][Ep 06] train acc=0.882 loss=0.285 | val acc=0.880 loss=0.295
[logic][Ep 07] train acc=0.897 loss=0.248 | val acc=0.900 loss=0.267
[logic][Ep 08] train acc=0.914 loss=0.220 | val acc=0.920 loss=0.239
[logic][Ep 09] train acc=0.938 loss=0.192 | val acc=0.930 loss=0.212
[logic][Ep 10] train acc=0.960 loss=0.167 | val acc=0.960 loss=0.185
[logic][Ep 11] train acc=0.974 loss=0.145 | val acc=0.960 loss=0.166
[logic][Ep 12] train acc=0.975 loss=0.128 | val acc=0.970 loss=0.150
[logic][Ep 13] train acc=0.979 loss=0.109 | val acc=0.970 loss=0.130
[logic][Ep 14] train acc=0.985 loss=0.

In [24]:
cfg = TrainConfigSNN(
    data_root=Path(".."),
    sim_name="3_12",
    mode="intermediate_case1",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNetSNN") 
TomNet_Intermediate1_Model = models["ToMNetSNN_3_12_intermediate_case1"]

[TRAIN] sim=3_12 mode=intermediate_case1

=== [SNN] Training mode = intermediate_case1 ===
[intermediate_case1][Ep 01] train acc=0.319 loss=1.334 | val acc=0.410 loss=1.268
[intermediate_case1][Ep 02] train acc=0.364 loss=1.221 | val acc=0.430 loss=1.217
[intermediate_case1][Ep 03] train acc=0.360 loss=1.180 | val acc=0.410 loss=1.199
[intermediate_case1][Ep 04] train acc=0.414 loss=1.151 | val acc=0.400 loss=1.191
[intermediate_case1][Ep 05] train acc=0.439 loss=1.135 | val acc=0.380 loss=1.186
[intermediate_case1][Ep 06] train acc=0.456 loss=1.117 | val acc=0.410 loss=1.183
[intermediate_case1][Ep 07] train acc=0.466 loss=1.104 | val acc=0.420 loss=1.178
[intermediate_case1][Ep 08] train acc=0.469 loss=1.096 | val acc=0.430 loss=1.174
[intermediate_case1][Ep 09] train acc=0.497 loss=1.078 | val acc=0.490 loss=1.173
[intermediate_case1][Ep 10] train acc=0.511 loss=1.071 | val acc=0.450 loss=1.169
[intermediate_case1][Ep 11] train acc=0.492 loss=1.060 | val acc=0.490 loss=1.165
[interm

In [25]:
cfg = TrainConfigSNN(
    data_root=Path(".."),
    sim_name="3_12",
    mode="intermediate_case2",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNetSNN") 
TomNet_Intermediate2_Model = models["ToMNetSNN_3_12_intermediate_case2"]

[TRAIN] sim=3_12 mode=intermediate_case2

=== [SNN] Training mode = intermediate_case2 ===
[intermediate_case2][Ep 01] train acc=0.330 loss=1.303 | val acc=0.380 loss=1.205
[intermediate_case2][Ep 02] train acc=0.389 loss=1.196 | val acc=0.390 loss=1.150
[intermediate_case2][Ep 03] train acc=0.415 loss=1.160 | val acc=0.400 loss=1.143
[intermediate_case2][Ep 04] train acc=0.439 loss=1.150 | val acc=0.390 loss=1.145
[intermediate_case2][Ep 05] train acc=0.451 loss=1.120 | val acc=0.420 loss=1.143
[intermediate_case2][Ep 06] train acc=0.479 loss=1.111 | val acc=0.410 loss=1.148
[intermediate_case2][Ep 07] train acc=0.476 loss=1.100 | val acc=0.430 loss=1.153
[intermediate_case2][Ep 08] train acc=0.491 loss=1.088 | val acc=0.440 loss=1.160
[intermediate_case2][Ep 09] train acc=0.486 loss=1.077 | val acc=0.410 loss=1.158
[intermediate_case2][Ep 10] train acc=0.506 loss=1.064 | val acc=0.400 loss=1.168
[intermediate_case2][Ep 11] train acc=0.497 loss=1.057 | val acc=0.420 loss=1.173
[interm

## Evaluation

In [20]:
from model_evaluate import evaluate_model_performance

# 根據實際的批次鍵名定義forward函數
def forward_tomnet_SNN(m, b):
    return m(b["grid_seq"], b["tmask"], b["choices_ids"])

In [None]:
# random
val_loader, in_ch, input_hw = make_loader(cfg, split="val", mode="random")
metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="mps",
    forward_fn=forward_tomnet_SNN,
    energy_forward_fn=forward_tomnet_SNN
)

print("=== ToMNetSNN_3_12_random model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNetSNN_3_12_random model evaluation ===
accuracy: 0.3600
precision: 0.5030
recall rate: 0.3600
F1 score: 0.3683
ROC AUC: 0.6452
R2 Score: -0.7594
Prediction Entropy: 0.7121

=== Energy and Efficiency Evaluation ===
MACs per sample: 7,393,792
Energy per sample: 3.40e-05 J
Energy per accuracy unit: 9.45e-05 J/acc
MACs per accuracy unit: 20,538,311 MACs/acc




In [28]:
# rulemap
val_loader, in_ch, input_hw = make_loader(cfg, split="val", mode="rulemap")
metrics = evaluate_model_performance(
    TomNet_Rulemap_Model,
    val_loader,
    device="mps",
    forward_fn=forward_tomnet_SNN,
    energy_forward_fn=forward_tomnet_SNN
)

print("=== ToMNetSNN_3_12_rulemap model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNetSNN_3_12_rulemap model evaluation ===
accuracy: 0.9500
precision: 0.9506
recall rate: 0.9500
F1 score: 0.9501
ROC AUC: 0.9974
R2 Score: 0.8558
Prediction Entropy: 0.0717

=== Energy and Efficiency Evaluation ===
MACs per sample: 7,393,792
Energy per sample: 3.40e-05 J
Energy per accuracy unit: 3.58e-05 J/acc
MACs per accuracy unit: 7,782,939 MACs/acc




In [26]:
# logic
val_loader, in_ch, input_hw = make_loader(cfg, split="val", mode="logic")
metrics = evaluate_model_performance(
    TomNet_Logic_Model,
    val_loader,
    device="mps",
    forward_fn=forward_tomnet_SNN,
    energy_forward_fn=forward_tomnet_SNN
)   

print("=== ToMNetSNN_3_12_logic model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc") 

=== ToMNetSNN_3_12_logic model evaluation ===
accuracy: 0.9800
precision: 0.9803
recall rate: 0.9800
F1 score: 0.9799
ROC AUC: 0.9992
R2 Score: 0.9798
Prediction Entropy: 0.0453

=== Energy and Efficiency Evaluation ===
MACs per sample: 7,393,792
Energy per sample: 3.40e-05 J
Energy per accuracy unit: 3.47e-05 J/acc
MACs per accuracy unit: 7,544,686 MACs/acc




In [29]:
# intermediate_case1
val_loader, in_ch, input_hw = make_loader(cfg, split="val", mode="intermediate_case1")
metrics = evaluate_model_performance(
    TomNet_Intermediate1_Model,
    val_loader,
    device="mps",
    forward_fn=forward_tomnet_SNN,
    energy_forward_fn=forward_tomnet_SNN
)   

print("=== ToMNetSNN_3_12_intermediate_case1 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc") 

=== ToMNetSNN_3_12_intermediate_case1 model evaluation ===
accuracy: 0.4300
precision: 0.5161
recall rate: 0.4300
F1 score: 0.4507
ROC AUC: 0.7312
R2 Score: -0.3609
Prediction Entropy: 0.7282

=== Energy and Efficiency Evaluation ===
MACs per sample: 7,393,792
Energy per sample: 3.40e-05 J
Energy per accuracy unit: 7.91e-05 J/acc
MACs per accuracy unit: 17,194,865 MACs/acc




In [27]:
# intermediate_case2
val_loader, in_ch, input_hw = make_loader(cfg, split="val", mode="intermediate_case2")
metrics = evaluate_model_performance(
    TomNet_Intermediate2_Model,
    val_loader,
    device="mps",
    forward_fn=forward_tomnet_SNN,
    energy_forward_fn=forward_tomnet_SNN
)   

print("=== ToMNetSNN_3_12_intermediate_case2 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc") 

=== ToMNetSNN_3_12_intermediate_case2 model evaluation ===
accuracy: 0.4200
precision: 0.5368
recall rate: 0.4200
F1 score: 0.4356
ROC AUC: 0.7078
R2 Score: -0.6316
Prediction Entropy: 0.7243

=== Energy and Efficiency Evaluation ===
MACs per sample: 7,393,792
Energy per sample: 3.40e-05 J
Energy per accuracy unit: 8.10e-05 J/acc
MACs per accuracy unit: 17,604,267 MACs/acc


