In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import json, math, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from grid_dataset import GridSeqDataset
from collate_grid import collate_grid

# === Unify root / sim_name ===
PROJECT_ROOT = Path("/Users/Jer_ry/Desktop/script_tom")
SIM_NAME = "3_12"
N_AGENTS = 3

_MODE_ALIASES = {
    "rulemap":"rulemap", "random":"random", "logic":"logic",
    "intermediate_case1":"intermediate_case1", "intermediate_case2":"intermediate_case2",
    "intermediate_1":"intermediate_case1", "intermediate_2":"intermediate_case2",
    "intemediete_1":"intermediate_case1", "intemediete_2":"intermediate_case2",
    "intermediat_case1":"intermediate_case1", "intermediat_case2":"intermediate_case2",
    "all":"all",
}
_ALL_MODES = ["rulemap","random","intermediate_case1","intermediate_case2"]


@dataclass
class TrainConfig:
    data_root: Path = PROJECT_ROOT
    sim_name: str = SIM_NAME
    mode: str = "all"    # or single mode
    # model
    d_model: int = 256
    nhead: int = 8
    n_layers: int = 2
    stem_ch: int = 32
    n_blocks: int = 4
    e_char_dim: int = 64
    dropout: float = 0.1
    last_k: int = 3
    # train
    batch_size: int = 32
    lr: float = 3e-4
    weight_decay: float = 1e-2
    max_epochs: int = 20
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
# vision_tomnet_transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        self.c1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(ch)
        self.c2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(ch)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        h = self.act(self.bn1(self.c1(x)))
        h = self.bn2(self.c2(h))
        return self.act(x + h)

class VisionCharNet(nn.Module):
    """
    input: grid_seq [B, T, H, W, C]  (C 可為 18/30 ...)
    steps:
      1) per-frame 1x1 conv 把 C -> stem_ch
      2) 3x3 conv + n 個 residual blocks
      3) global average pool => [B, T, d_vis]
      4) TransformerEncoder over time => [B, T, d_model]
      5) 取 CLS（可選）或 masked mean => e_char [B, e_char_dim]
    """
    def __init__(self, in_ch: int, stem_ch: int = 32, n_blocks: int = 4,
                 d_model: int = 256, nhead: int = 8, n_layers: int = 2,
                 e_char_dim: int = 64, use_cls_token: bool = True, last_k: int = 3):
        super().__init__()
        self.use_cls_token = use_cls_token
        self.last_k = last_k
        self.stem1x1 = nn.Conv2d(in_ch, stem_ch, kernel_size=1)
        self.c3 = nn.Conv2d(stem_ch, stem_ch, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(stem_ch)
        self.blocks = nn.Sequential(*[BasicBlock(stem_ch) for _ in range(n_blocks)])
        self.act = nn.ReLU(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d(1)                 # [B*T, ch,1,1]
        # 時序 Transformer
        self.in_proj = nn.Linear(stem_ch, d_model)         # 視覺特徵 -> d_model
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=4*d_model, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.to_e = nn.Linear(d_model, e_char_dim)

    def forward(self, grid_seq: torch.Tensor, tmask: torch.Tensor):
        """
        grid_seq: [B,T,H,W,C]; tmask: [B,T] True=有效
        return:
          e_char [B,e_char_dim],   # 角色 embedding
          h_time [B,T,d_model]     # 每步時序特徵（給 q_vec 用）
        """
        B,T,H,W,C = grid_seq.shape
        x = grid_seq.permute(0,1,4,2,3).contiguous()     # [B,T,C,H,W]
        x = x.view(B*T, C, H, W)
        x = self.stem1x1(x)
        x = self.act(self.bn3(self.c3(x)))
        x = self.blocks(x)
        x = self.gap(x).squeeze(-1)                      # [B*T, stem_ch]
        x = x.view(B, T, -1)                             # [B,T,stem_ch]
        h = self.in_proj(x)                              # [B,T,d_model]
        key_pad = ~tmask                                 # True=PAD
        h = self.encoder(h, src_key_padding_mask=key_pad)
        h = self.ln(h)                                   # [B,T,d_model]

        if self.use_cls_token:
            # 取最後一個有效步當「CLS」
            last = tmask.sum(dim=1) - 1                  # [B]
            e_src = h.gather(1, last.view(B,1,1).expand(B,1,h.size(-1))).squeeze(1)
        else:
            m = tmask.unsqueeze(-1)                      # [B,T,1]
            e_src = (h * m).sum(dim=1) / m.sum(dim=1).clamp_min(1.0)
        e_char = self.to_e(e_src)                        # [B,e_char_dim]
        return e_char, h

class PredHead(nn.Module):
    """
    與你原本 PredNet 類似，但不再用 token 序列的 q_vec，
    而是用時序特徵 h_time 的最後 K 步平均作為 query。
    choices 仍用 cell-id 嵌入。
    """
    def __init__(self, d_model: int, e_char_dim: int, dropout: float = 0.1,
                 cell_vocab: int = 24*24, choice_emb_dim: int = 256):
        super().__init__()
        self.proj_e = nn.Linear(e_char_dim, d_model)
        self.choice_emb = nn.Embedding(cell_vocab, choice_emb_dim)
        self.proj_choice = nn.Linear(choice_emb_dim, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(4*d_model, 2*d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2*d_model, 1)
        )

    def _embed_choice_triplet(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: [B,4,3]
        emb = self.choice_emb(ids)         # [B,4,3,Dc]
        return emb.mean(dim=2)             # [B,4,Dc]

    def forward(self, e_char: torch.Tensor, h_time: torch.Tensor, tmask: torch.Tensor,
                choices_ids: torch.Tensor, last_k: int = 3) -> torch.Tensor:
        """
        e_char:   [B,e]
        h_time:   [B,T,d_model]
        tmask:    [B,T]
        choices:  [B,4,3]
        return logits: [B,4]
        """
        B, T, D = h_time.shape
        L = tmask.sum(dim=1)                     # [B]，每個樣本有效長度
        k = max(int(last_k), 1)                  # ← 修正：用 Python int 取最大至少 1

        # 取最後 k 步（不足就重複最後一個有效步；如果 T==0 就用 0）
        # 這段做法向量化、避免 Python for 迴圈慢：
        # ex: L=[14,10] & k=3 → 索引 [[11,12,13],[7,8,9]]（0-based）
        base = (L - 1).clamp(min=0)              # [B]
        offs = torch.arange(k, device=h_time.device)  # [k] = [0,1,...,k-1]
        idxs = (base.unsqueeze(1) - (k - 1 - offs)).clamp(min=0)  # [B,k]
        gather = h_time.gather(1, idxs.unsqueeze(-1).expand(B, k, D))  # [B,k,D]
        q_vec = gather.mean(dim=1)               # [B,D]

        e_proj = self.proj_e(e_char)             # [B,D]
        prefix_fused = self.ln(e_proj + q_vec)   # [B,D]

        # choices：三元組 cell-id → 平均後投影到 D
        ch = self._embed_choice_triplet(choices_ids)   # [B,4,Dc]
        ch = self.proj_choice(ch)                      # [B,4,D]

        pf = prefix_fused.unsqueeze(1).expand_as(ch)   # [B,4,D]
        feat = torch.cat([pf, ch, torch.abs(pf - ch), pf * ch], dim=-1)  # [B,4,4D]
        logits = self.mlp(feat).squeeze(-1)           # [B,4]
        return logits

class ToMNetVisionTransformer(nn.Module):
    """
    整合：VisionCharNet（ToMNet 風格特徵抽取 + Transformer 時序） + PredHead
    """
    def __init__(self, in_ch: int, d_model: int=256, nhead: int=8, n_layers: int=2,
                 stem_ch: int=32, n_blocks: int=4, e_char_dim: int=64, dropout: float=0.1,
                 cell_vocab: int = 24*24, use_cls_token: bool=True, last_k: int=3):
        super().__init__()
        self.vision = VisionCharNet(in_ch=in_ch, stem_ch=stem_ch, n_blocks=n_blocks,
                                    d_model=d_model, nhead=nhead, n_layers=n_layers,
                                    e_char_dim=e_char_dim, use_cls_token=use_cls_token, last_k=last_k)
        self.head = PredHead(d_model=d_model, e_char_dim=e_char_dim, dropout=dropout,
                             cell_vocab=cell_vocab)

        self.last_k = last_k

    def forward(self, grid_seq: torch.Tensor, tmask: torch.Tensor, choices_ids: torch.Tensor) -> torch.Tensor:
        # grid_seq: [B,T,H,W,C]
        C = grid_seq.size(-1)
        assert C % 6 == 0, f"Expect channels = 6 × n_agents (no obstacle). Got C={C}"
        e_char, h_time = self.vision(grid_seq, tmask)
        logits = self.head(e_char, h_time, tmask, choices_ids, last_k=self.last_k)
        return logits


In [None]:
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple
import random, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch.utils.data import Dataset 
from grid_dataset import GridSeqDataset
from collate_grid import collate_grid
from pathlib import Path
import json
import torch
import numpy as np


_MODE_ALIASES = {
    "rulemap":"rulemap", "random":"random", "logic":"logic",
    "intermediate_case1":"intermediate_case1", "intermediate_case2":"intermediate_case2",
    "intermediate_1":"intermediate_case1", "intermediate_2":"intermediate_case2",
    "intemediete_1":"intermediate_case1", "intemediete_2":"intermediate_case2",
    "intermediat_case1":"intermediate_case1", "intermediat_case2":"intermediate_case2",
    "all":"all",
}
_ALL_MODES = ["rulemap","random","intermediate_case1","intermediate_case2"]  # 先不含 logic

@dataclass
class TrainConfig:
    data_root: Path
    sim_name: str = "sim1"
    mode: str = "all"   # or single mode
    # model
    d_model: int = 256
    nhead: int = 8
    n_layers: int = 2
    stem_ch: int = 32
    n_blocks: int = 4
    e_char_dim: int = 64
    dropout: float = 0.1
    last_k: int = 3
    # train
    batch_size: int = 32
    lr: float = 3e-4
    weight_decay: float = 1e-2
    max_epochs: int = 20
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    save_dir: Path = Path("../checkpoints")

    def canonical_mode(self) -> str:
        key = self.mode.strip().lower()
        if key not in _MODE_ALIASES: 
            raise ValueError(f"Unknown mode: {self.mode}")
        return _MODE_ALIASES[key]

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def make_loader(cfg: TrainConfig, split: str, mode: str) -> Tuple[DataLoader, int]:
    ds = GridSeqDataset(cfg.data_root, cfg.sim_name, split, mode)
    one = ds[0]["grid_seq"]
    in_ch = int(one.shape[-1])
    ld = DataLoader(ds, batch_size=cfg.batch_size, shuffle=(split=="train"),
                    num_workers=0, pin_memory=True, collate_fn=collate_grid)
    return ld, in_ch

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Tuple[float, float]:
    model.eval()
    n, corr, tot_loss = 0, 0, 0.0
    for batch in loader:
        x = {k:(v.to(device) if isinstance(v, torch.Tensor) else v) for k,v in batch.items()}
        logits = model(x["grid_seq"], x["tmask"], x["choices_ids"])
        loss = F.cross_entropy(logits, x["labels"])
        pred = logits.argmax(dim=-1)
        corr += (pred == x["labels"]).sum().item()
        tot_loss += loss.item() * logits.size(0)
        n += logits.size(0)
    return corr/max(n,1), tot_loss/max(n,1)

def train_one_mode(cfg: TrainConfig, mode: str):
    print(f"\n=== Training mode = {mode} ===")
    train_ld, in_ch = make_loader(cfg, "train", mode)
    val_ld, _ = make_loader(cfg, "val", mode)
    test_ld, _ = make_loader(cfg, "test", mode)

    model = ToMNetVisionTransformer(
        in_ch=in_ch, d_model=cfg.d_model, nhead=cfg.nhead,
        n_layers=cfg.n_layers, stem_ch=cfg.stem_ch, n_blocks=cfg.n_blocks,
        e_char_dim=cfg.e_char_dim, dropout=cfg.dropout,
        use_cls_token=True, last_k=cfg.last_k
    ).to(cfg.device)
    
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.max_epochs)

    best_va, best_state = 0.0, None
    for ep in range(1, cfg.max_epochs+1):
        model.train()
        n = corr = 0
        tot_loss = 0.0
        
        for batch in train_ld:
            x = {k:(v.to(cfg.device) if isinstance(v, torch.Tensor) else v) for k,v in batch.items()}
            logits = model(x["grid_seq"], x["tmask"], x["choices_ids"])
            loss = F.cross_entropy(logits, x["labels"])
            
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            
            pred = logits.argmax(dim=-1)
            corr += (pred == x["labels"]).sum().item()
            tot_loss += loss.item() * logits.size(0)
            n += logits.size(0)
            
        tr_acc, tr_loss = corr/max(n,1), tot_loss/max(n,1)
        va_acc, va_loss = evaluate(model, val_ld, cfg.device)
        sch.step()
        
        print(f"[{mode}][Ep {ep:02d}] train acc={tr_acc:.3f} loss={tr_loss:.3f} | val acc={va_acc:.3f} loss={va_loss:.3f}")
        
        if va_acc > best_va:
            best_va, best_state = va_acc, {k:v.cpu() for k,v in model.state_dict().items()}

    if best_state is not None: 
        model.load_state_dict(best_state)
    te_acc, te_loss = evaluate(model, test_ld, cfg.device)
    print(f"[{mode}][TEST] acc={te_acc:.3f} loss={te_loss:.3f}")
    return model, {"val_acc": best_va, "test_acc": te_acc}


In [None]:
from typing import Dict
import torch

_ALL_MODES = ["random", "rulemap", "logic", "intermediate_case1", "intermediate_case2"]

def _model_name(sim_name: str, mode: str, kind: str = "ToMNet") -> str:
    return f"{kind}_{sim_name}_{mode}"

def _save_model(model: nn.Module, cfg: TrainConfig, mode: str, kind: str = "ToMNet") -> Path:
    cfg.save_dir.mkdir(parents=True, exist_ok=True)
    name = _model_name(cfg.sim_name, mode, kind)
    path = cfg.save_dir / f"{name}.pt"
    torch.save(model.state_dict(), path)
    print(f"[SAVE] {name} -> {path}")
    return path

def run(cfg: TrainConfig, kind: str = "ToMNet") -> Dict[str, nn.Module]:
    """
    kind: "ToMNet" / "ToMNetSNN" / "ToMNetConv", for filename
    Returns: {model_name: model_instance}
    """
    models: Dict[str, nn.Module] = {}

    if cfg.mode == "all":
        modes = _ALL_MODES
    else:
        modes = [cfg.mode]

    for mode in modes:
        print(f"[TRAIN] sim={cfg.sim_name} mode={mode}")
        model, best_acc = train_one_mode(cfg, mode)

        name = _model_name(cfg.sim_name, mode, kind)
        _save_model(model, cfg, mode, kind=kind)
        models[name] = model
        print(f"[DONE] {name} (best_acc={best_acc})")

    return models


## Train

In [None]:
cfg = TrainConfig(
    data_root=Path(".."),
    sim_name="3_12",
    mode="random",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)

models = run(cfg, kind="ToMNet_Transformer")  
TomNet_Random_Model = models["ToMNet_Transformer_3_12_random"]

[TRAIN] sim=3_12 mode=random

=== Training mode = random ===


  output = torch._nested_tensor_from_mask(


[random][Ep 01] train acc=0.320 loss=1.309 | val acc=0.370 loss=1.268
[random][Ep 02] train acc=0.374 loss=1.201 | val acc=0.370 loss=1.215
[random][Ep 03] train acc=0.414 loss=1.158 | val acc=0.340 loss=1.201
[random][Ep 04] train acc=0.434 loss=1.133 | val acc=0.370 loss=1.207
[random][Ep 05] train acc=0.435 loss=1.124 | val acc=0.370 loss=1.210
[random][Ep 06] train acc=0.458 loss=1.110 | val acc=0.360 loss=1.213
[random][Ep 07] train acc=0.480 loss=1.101 | val acc=0.340 loss=1.225
[random][Ep 08] train acc=0.492 loss=1.086 | val acc=0.360 loss=1.224
[random][Ep 09] train acc=0.507 loss=1.079 | val acc=0.360 loss=1.224
[random][Ep 10] train acc=0.516 loss=1.067 | val acc=0.390 loss=1.222
[random][Ep 11] train acc=0.522 loss=1.051 | val acc=0.380 loss=1.246
[random][Ep 12] train acc=0.526 loss=1.031 | val acc=0.340 loss=1.250
[random][Ep 13] train acc=0.535 loss=1.002 | val acc=0.340 loss=1.256
[random][Ep 14] train acc=0.580 loss=0.977 | val acc=0.360 loss=1.258
[random][Ep 15] trai

In [8]:
cfg = TrainConfig(
    data_root=Path(".."),
    sim_name="3_12",
    mode="rulemap",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNet_Transformer")   
TomNet_Rulemap_Model = models["ToMNet_Transformer_3_12_rulemap"]

[TRAIN] sim=3_12 mode=rulemap

=== Training mode = rulemap ===
[rulemap][Ep 01] train acc=0.649 loss=1.123 | val acc=0.840 loss=0.988
[rulemap][Ep 02] train acc=0.811 loss=0.695 | val acc=0.860 loss=0.645
[rulemap][Ep 03] train acc=0.844 loss=0.462 | val acc=0.860 loss=0.448
[rulemap][Ep 04] train acc=0.874 loss=0.345 | val acc=0.900 loss=0.342
[rulemap][Ep 05] train acc=0.904 loss=0.272 | val acc=0.900 loss=0.272
[rulemap][Ep 06] train acc=0.917 loss=0.228 | val acc=0.910 loss=0.233
[rulemap][Ep 07] train acc=0.936 loss=0.186 | val acc=0.920 loss=0.203
[rulemap][Ep 08] train acc=0.948 loss=0.158 | val acc=0.920 loss=0.173
[rulemap][Ep 09] train acc=0.961 loss=0.138 | val acc=0.920 loss=0.152
[rulemap][Ep 10] train acc=0.965 loss=0.124 | val acc=0.940 loss=0.141
[rulemap][Ep 11] train acc=0.966 loss=0.109 | val acc=0.960 loss=0.131
[rulemap][Ep 12] train acc=0.968 loss=0.104 | val acc=0.970 loss=0.118
[rulemap][Ep 13] train acc=0.974 loss=0.094 | val acc=0.960 loss=0.118
[rulemap][Ep 1

In [9]:
cfg = TrainConfig(
    data_root=Path(".."),
    sim_name="3_12",
    mode="logic",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNet_Transformer")   
TomNet_Logic_Model = models["ToMNet_Transformer_3_12_logic"]

[TRAIN] sim=3_12 mode=logic

=== Training mode = logic ===
[logic][Ep 01] train acc=0.676 loss=1.126 | val acc=0.820 loss=0.938
[logic][Ep 02] train acc=0.797 loss=0.697 | val acc=0.840 loss=0.569
[logic][Ep 03] train acc=0.830 loss=0.454 | val acc=0.890 loss=0.383
[logic][Ep 04] train acc=0.885 loss=0.336 | val acc=0.910 loss=0.297
[logic][Ep 05] train acc=0.919 loss=0.274 | val acc=0.930 loss=0.248
[logic][Ep 06] train acc=0.922 loss=0.235 | val acc=0.940 loss=0.210
[logic][Ep 07] train acc=0.935 loss=0.205 | val acc=0.950 loss=0.181
[logic][Ep 08] train acc=0.954 loss=0.170 | val acc=0.960 loss=0.156
[logic][Ep 09] train acc=0.961 loss=0.150 | val acc=0.970 loss=0.133
[logic][Ep 10] train acc=0.968 loss=0.131 | val acc=0.970 loss=0.115
[logic][Ep 11] train acc=0.984 loss=0.111 | val acc=0.970 loss=0.101
[logic][Ep 12] train acc=0.988 loss=0.093 | val acc=0.990 loss=0.089
[logic][Ep 13] train acc=0.988 loss=0.083 | val acc=0.990 loss=0.078
[logic][Ep 14] train acc=0.991 loss=0.065 | 

In [10]:
cfg = TrainConfig(
    data_root=Path(".."),
    sim_name="3_12",
    mode="intermediate_case1",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNet_Transformer")   
TomNet_Intermediate1_Model = models["ToMNet_Transformer_3_12_intermediate_case1"]

[TRAIN] sim=3_12 mode=intermediate_case1

=== Training mode = intermediate_case1 ===
[intermediate_case1][Ep 01] train acc=0.328 loss=1.300 | val acc=0.330 loss=1.264
[intermediate_case1][Ep 02] train acc=0.347 loss=1.195 | val acc=0.350 loss=1.224
[intermediate_case1][Ep 03] train acc=0.393 loss=1.160 | val acc=0.360 loss=1.215
[intermediate_case1][Ep 04] train acc=0.438 loss=1.135 | val acc=0.360 loss=1.206
[intermediate_case1][Ep 05] train acc=0.430 loss=1.119 | val acc=0.430 loss=1.201
[intermediate_case1][Ep 06] train acc=0.454 loss=1.106 | val acc=0.420 loss=1.197
[intermediate_case1][Ep 07] train acc=0.477 loss=1.089 | val acc=0.370 loss=1.193
[intermediate_case1][Ep 08] train acc=0.479 loss=1.083 | val acc=0.390 loss=1.188
[intermediate_case1][Ep 09] train acc=0.477 loss=1.081 | val acc=0.360 loss=1.188
[intermediate_case1][Ep 10] train acc=0.489 loss=1.071 | val acc=0.400 loss=1.187
[intermediate_case1][Ep 11] train acc=0.482 loss=1.056 | val acc=0.370 loss=1.194
[intermediate

In [11]:
cfg = TrainConfig(
    data_root=Path(".."),
    sim_name="3_12",
    mode="intermediate_case2",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)
models = run(cfg, kind="ToMNet_Transformer")   
TomNet_Intermediate2_Model = models["ToMNet_Transformer_3_12_intermediate_case2"]

[TRAIN] sim=3_12 mode=intermediate_case2

=== Training mode = intermediate_case2 ===
[intermediate_case2][Ep 01] train acc=0.326 loss=1.300 | val acc=0.400 loss=1.259
[intermediate_case2][Ep 02] train acc=0.417 loss=1.196 | val acc=0.340 loss=1.191
[intermediate_case2][Ep 03] train acc=0.415 loss=1.159 | val acc=0.320 loss=1.169
[intermediate_case2][Ep 04] train acc=0.424 loss=1.138 | val acc=0.340 loss=1.171
[intermediate_case2][Ep 05] train acc=0.454 loss=1.122 | val acc=0.340 loss=1.177
[intermediate_case2][Ep 06] train acc=0.465 loss=1.108 | val acc=0.340 loss=1.181
[intermediate_case2][Ep 07] train acc=0.484 loss=1.096 | val acc=0.380 loss=1.191
[intermediate_case2][Ep 08] train acc=0.481 loss=1.076 | val acc=0.400 loss=1.189
[intermediate_case2][Ep 09] train acc=0.492 loss=1.076 | val acc=0.380 loss=1.193
[intermediate_case2][Ep 10] train acc=0.482 loss=1.074 | val acc=0.360 loss=1.201
[intermediate_case2][Ep 11] train acc=0.491 loss=1.046 | val acc=0.360 loss=1.209
[intermediate

## Evaluation

In [None]:
from model_evaluate import evaluate_model_performance

def forward_tomnet(m: nn.Module, b: Dict[str, torch.Tensor]) -> torch.Tensor:
    _ = m(b["grid_seq"], b["tmask"], b["choices_ids"])
    return _  # logits

In [None]:
# random
val_loader, in_ch = make_loader(cfg, split="val", mode="random")

ds = GridSeqDataset(cfg.data_root, cfg.sim_name, "val", "random")
one_sample = ds[0]["grid_seq"]
input_hw = one_sample.shape[2:4]

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet,
    energy_forward_fn=forward_tomnet)

print("=== ToMNet_Transformer_3_12_random model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNet_Transformer_3_12_random model evaluation ===
accuracy: 0.4000
precision: 0.4845
recall rate: 0.4000
F1 score: 0.4137
ROC AUC: 0.7123
R2 Score: -0.4060
Prediction Entropy: 0.3449

=== Energy and Efficiency Evaluation ===
MACs per sample: 133,230,592
Energy per sample: 6.13e-04 J
Energy per accuracy unit: 1.53e-03 J/acc
MACs per accuracy unit: 333,076,480 MACs/acc




In [16]:
# rulemap
val_loader, in_ch = make_loader(cfg, split="val", mode="rulemap")
metrics = evaluate_model_performance(
    TomNet_Rulemap_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet,
    energy_forward_fn=forward_tomnet)

print("=== ToMNetSNN_3_12_rulemap model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")



=== ToMNetSNN_3_12_rulemap model evaluation ===
accuracy: 0.9500
precision: 0.9510
recall rate: 0.9500
F1 score: 0.9503
ROC AUC: 0.9975
R2 Score: 0.9009
Prediction Entropy: 0.0554

=== Energy and Efficiency Evaluation ===
MACs per sample: 131,706,880
Energy per sample: 6.06e-04 J
Energy per accuracy unit: 6.38e-04 J/acc
MACs per accuracy unit: 138,638,821 MACs/acc


In [18]:
# logic
val_loader, in_ch = make_loader(cfg, split="val", mode="logic")
metrics = evaluate_model_performance(
    TomNet_Logic_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet,
    energy_forward_fn=forward_tomnet)

print("=== ToMNetSNN_3_12_logic model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")



=== ToMNetSNN_3_12_logic model evaluation ===
accuracy: 0.9700
precision: 0.9705
recall rate: 0.9700
F1 score: 0.9701
ROC AUC: 0.9995
R2 Score: 0.9697
Prediction Entropy: 0.0774

=== Energy and Efficiency Evaluation ===
MACs per sample: 131,674,112
Energy per sample: 6.06e-04 J
Energy per accuracy unit: 6.24e-04 J/acc
MACs per accuracy unit: 135,746,507 MACs/acc


In [19]:
# intermediate_case1
val_loader, in_ch = make_loader(cfg, split="val", mode="intermediate_case1")
metrics = evaluate_model_performance(
    TomNet_Intermediate1_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet,
    energy_forward_fn=forward_tomnet)

print("=== ToMNetSNN_3_12_intermediate_case1 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNetSNN_3_12_intermediate_case1 model evaluation ===
accuracy: 0.2300
precision: 0.3100
recall rate: 0.2300
F1 score: 0.2553
ROC AUC: 0.5454
R2 Score: -0.7368
Prediction Entropy: 0.3855

=== Energy and Efficiency Evaluation ===
MACs per sample: 133,230,592
Energy per sample: 6.13e-04 J
Energy per accuracy unit: 2.66e-03 J/acc
MACs per accuracy unit: 579,263,443 MACs/acc




In [20]:
# intermediate_case2
val_loader, in_ch = make_loader(cfg, split="val", mode="intermediate_case2")
metrics = evaluate_model_performance(
    TomNet_Intermediate2_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet,
    energy_forward_fn=forward_tomnet)

print("=== ToMNetSNN_3_12_intermediate_case2 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")



=== ToMNetSNN_3_12_intermediate_case2 model evaluation ===
accuracy: 0.3500
precision: 0.4400
recall rate: 0.3500
F1 score: 0.3696
ROC AUC: 0.6114
R2 Score: -0.5414
Prediction Entropy: 0.3304

=== Energy and Efficiency Evaluation ===
MACs per sample: 133,230,592
Energy per sample: 6.13e-04 J
Energy per accuracy unit: 1.75e-03 J/acc
MACs per accuracy unit: 380,658,834 MACs/acc
