In [None]:
# scripts/gen_grid.py
from pathlib import Path
import numpy as np
from collate_grid import collate_grid


PROJECT_ROOT = Path("/Users/Jer_ry/Desktop/script_tom")
SIM_NAME     = "3_12"
SPLITS       = ["training", "validation", "testing"]
H = W = 12
N_AGENTS = 3

In [None]:
# A_fixed_model.py
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """
    Basic Residual Block。
      in:  [B, ch, H, W]
      out: [B, ch, H, W]
    """
    def __init__(self, ch: int):
        super().__init__()
        self.c1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(ch)
        self.c2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.act(self.bn1(self.c1(x)))
        h = self.bn2(self.c2(h))
        return self.act(x + h)


class CharacterNetwork(nn.Module):
    """
    Character Network: from τ_j compute e_char in all sequences

    Input:
      - grid_seq: [B, T, H, W, C]
      - tmask:    [B, T]  True=valiation

    Output:
      - e_char:   [B, e_char_dim]
    """
    def __init__(
        self,
        in_ch: int,
        stem_ch: int = 32,
        n_blocks: int = 4,
        d_model: int = 256,
        nhead: int = 8,
        n_layers: int = 2,
        e_char_dim: int = 64,
        use_cls_token: bool = True,
    ):
        super().__init__()
        self.use_cls_token = use_cls_token

        # per-frame CNN
        self.stem1x1 = nn.Conv2d(in_ch, stem_ch, 1)
        self.c3 = nn.Conv2d(stem_ch, stem_ch, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(stem_ch)
        self.blocks = nn.Sequential(*[BasicBlock(stem_ch) for _ in range(n_blocks)])
        self.act = nn.ReLU(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d(1)  # -> [B*T, stem_ch, 1, 1]

        # temporal Transformer
        self.in_proj = nn.Linear(stem_ch, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)

        self.to_e = nn.Linear(d_model, e_char_dim)

    def forward(self, grid_seq: torch.Tensor, tmask: torch.Tensor) -> torch.Tensor:
        B, T, H, W, C = grid_seq.shape

        x = grid_seq.permute(0, 1, 4, 2, 3).contiguous()
        x = x.view(B * T, C, H, W)

        x = self.stem1x1(x)
        x = self.act(self.bn3(self.c3(x)))
        x = self.blocks(x)
        x = self.gap(x).squeeze(-1).squeeze(-1)  # [B*T, stem_ch]
        x = x.view(B, T, -1)  # [B,T,stem_ch]

        # 時序 Transformer
        h = self.in_proj(x)                  # [B,T,d_model]
        key_pad = ~tmask                     # True=padding
        h = self.encoder(h, src_key_padding_mask=key_pad)
        h = self.ln(h)

        # 時間維度聚合
        if self.use_cls_token:
            last_idx = tmask.sum(dim=1) - 1        # [B]
            e_src = h.gather(
                1, last_idx.view(B, 1, 1).expand(B, 1, h.size(-1))
            ).squeeze(1)                            # [B,d_model]
        else:
            m = tmask.unsqueeze(-1).float()
            e_src = (h * m).sum(dim=1) / m.sum(dim=1).clamp_min(1.0)

        e_char = self.to_e(e_src)                   # [B,e_char_dim]
        return e_char


class OptionScoringHead(nn.Module):
    """
    Four-choice head:

    - e_char + τ_k first frame q_k(0) → ctx
    - Each option is a set of cell ids (N cells) → embedding + masked mean → ch_vec
    - Score each (ctx, ch_vec) pair

    Input:
      e_char:       [B, e_char_dim]
      q0:           [B, H, W, C]
      choices_ids:  [B, 4, Nmax]
      choices_mask: [B, 4, Nmax]  True=valid cell

    Output:
      logits:       [B, 4]
    """
    def __init__(
        self,
        in_ch: int,
        stem_ch: int = 32,
        n_blocks: int = 3,
        e_char_dim: int = 64,
        cell_vocab: int = 24 * 24,
        choice_emb_dim: int = 128,
        d_ctx: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        # CNN for q_k(0)
        self.query_stem = nn.Conv2d(in_ch, stem_ch, 1)
        self.query_c3 = nn.Conv2d(stem_ch, stem_ch, 3, padding=1)
        self.query_bn3 = nn.BatchNorm2d(stem_ch)
        self.blocks = nn.Sequential(*[BasicBlock(stem_ch) for _ in range(n_blocks)])
        self.act = nn.ReLU(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d(1)

        # ctx
        self.proj_ctx = nn.Sequential(
            nn.Linear(stem_ch + e_char_dim, d_ctx),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.LayerNorm(d_ctx),
        )

        # cell embedding
        self.choice_emb = nn.Embedding(cell_vocab, choice_emb_dim)
        self.proj_choice = nn.Sequential(
            nn.Linear(choice_emb_dim, d_ctx),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.LayerNorm(d_ctx),
        )

        # Dual-tower + scoring
        self.scorer = nn.Sequential(
            nn.Linear(d_ctx * 4, d_ctx),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ctx, 1),
        )

    def forward(
        self,
        e_char: torch.Tensor,
        q0: torch.Tensor,
        choices_ids: torch.Tensor,
        choices_mask: torch.Tensor,
    ) -> torch.Tensor:
        B, H, W, C = q0.shape

        # q_k 第一幀
        q = q0.permute(0, 3, 1, 2).contiguous()   # [B,C,H,W]
        q = self.query_stem(q)
        q = self.act(self.query_bn3(self.query_c3(q)))
        q = self.blocks(q)
        q_vec = self.gap(q).squeeze(-1).squeeze(-1)  # [B,stem_ch]

        # ctx
        ctx = self.proj_ctx(torch.cat([q_vec, e_char], dim=-1))  # [B,d_ctx]

        # 選項嵌入 + masked mean
        B2, K, Nmax = choices_ids.shape  # K=4
        emb = self.choice_emb(choices_ids)           # [B,4,Nmax,D_emb]
        m = choices_mask.unsqueeze(-1).float()       # [B,4,Nmax,1]
        ch_vec = (emb * m).sum(dim=2) / m.sum(dim=2).clamp_min(1.0)  # [B,4,D_emb]
        ch_vec = self.proj_choice(ch_vec)           # [B,4,d_ctx]

        # 雙塔打分
        ctx4 = ctx.unsqueeze(1).expand_as(ch_vec)   # [B,4,d_ctx]
        feat = torch.cat(
            [ctx4, ch_vec, torch.abs(ctx4 - ch_vec), ctx4 * ch_vec],
            dim=-1,
        )                                           # [B,4,4*d_ctx]
        logits = self.scorer(feat).squeeze(-1)      # [B,4]
        return logits


class ToMNet(nn.Module):
    """
    Overall ToMNet v2:
      - CharacterNetwork: from τ_j compute e_char
      - OptionScoringHead: (e_char, τ_k(0), choices_cell) → 4 logits
    """
    def __init__(
        self,
        in_ch: int,
        d_model: int = 256,
        nhead: int = 8,
        n_layers: int = 2,
        stem_ch: int = 32,
        n_blocks: int = 4,
        e_char_dim: int = 64,
        cell_vocab: int = 24 * 24,
        dropout: float = 0.1,
        use_cls_token: bool = True,
    ):
        super().__init__()
        self.char_net = CharacterNetwork(
            in_ch=in_ch,
            stem_ch=stem_ch,
            n_blocks=n_blocks,
            d_model=d_model,
            nhead=nhead,
            n_layers=n_layers,
            e_char_dim=e_char_dim,
            use_cls_token=use_cls_token,
        )
        self.head = OptionScoringHead(
            in_ch=in_ch,
            stem_ch=stem_ch,
            n_blocks=n_blocks,
            e_char_dim=e_char_dim,
            cell_vocab=cell_vocab,
            choice_emb_dim=128,
            d_ctx=256,
            dropout=dropout,
        )

    def forward(
        self,
        grid_seq_j: torch.Tensor,
        tmask_j: torch.Tensor,
        grid_seq_k: torch.Tensor,
        tmask_k: torch.Tensor,
        choices_ids: torch.Tensor,
        choices_mask: torch.Tensor,
    ) -> torch.Tensor:
        e_char = self.char_net(grid_seq_j, tmask_j)      # [B,e_char]
        q0 = grid_seq_k[:, 0, ...]                       # [B,H,W,C]
        logits = self.head(e_char, q0, choices_ids, choices_mask)  # [B,4]
        return logits


In [None]:
# B_fixed_dataset.py
import json
import random
from pathlib import Path
from typing import Dict, Any, List, Optional
import re

import numpy as np
from torch.utils.data import Dataset

_num_pat = re.compile(r"-?\d+(?:\.\d+)?")


def _infer_n_agents_from_len(K: int) -> int:
    if K % N_AGENTS != 0:
        raise ValueError(f" Length of trail {K} is not divisible by N_AGENTS={N_AGENTS} (possibly mixed with 5-agent data)")
    
    return N_AGENTS

def _read_trail_xy(step_dir: Path) -> np.ndarray:
    trail = Path(step_dir) / "trail.txt"
    if not trail.exists():
        raise FileNotFoundError(f"trail.txt not found at {trail}")

    try:
        text = trail.read_text(encoding="utf-8")
        nums = list(map(float, _num_pat.findall(text)))
        if len(nums) == 0 or len(nums) % 2 != 0:
            raise ValueError(f"Cannot parse into (x,y) sequence, data length: {len(nums)}")

        arr = np.asarray(nums, dtype=float).reshape(-1, 2)

        # Simple check: if world coordinates (e.g. 7..18), shift back to 1..12
        mn, mx = float(arr.min()), float(arr.max())
        if mn >= 6.5:
            arr = arr - 6.0

        # Modification: clip to 1..12
        arr[..., 0] = np.clip(arr[..., 0], 1.0, float(W))
        arr[..., 1] = np.clip(arr[..., 1], 1.0, float(H))
        return arr
    except Exception as e:
        raise ValueError(f"[Error reading] {trail}: {e}") from e

def _load_grid_seq(step_dir: Path) -> np.ndarray:
    step_dir = Path(step_dir)

    # 1. Prefer loading npy
    p = step_dir / "grid_seq.npy"
    if p.exists():
        arr = np.load(p)
        return arr.astype(np.float32)

    # 2. Load npz
    pz = step_dir / "grid_seq.npz"
    if pz.exists():
        arr = np.load(pz)["grid_seq"]
        return arr.astype(np.float32)

    # 3. [Modification] If neither exists, raise an error instead of trying to generate from txt
    # Because the logic for generating from txt (3 channels) is incompatible with npy (18 channels)
    raise FileNotFoundError(
        f"Missing grid_seq.npy in {step_dir}. "
        "Please run the gridizer script (F_run_gridizer.py) to generate it first."
    )

def _make_tmask(T: int) -> np.ndarray:
    return np.ones(T, dtype=bool)

class ToMNet2JsonlDataset(Dataset):
    def __init__(self, split_root: Path, sim_name: str = "sim5_24", use_modes: Optional[List[str]] = None, seed: int = 42):
        super().__init__()
        base_root = Path(split_root).resolve()
        cand1 = base_root / "tomnet2_format.jsonl"
        
        if cand1.exists():
            self.split_root = base_root
            self.jsonl = cand1
        else:
            raise FileNotFoundError(f"jsonl not found at {cand1}")

        self.rng = random.Random(seed)
        self.recs: List[Dict[str, Any]] = []

        for line in self.jsonl.read_text(encoding="utf-8").splitlines():
            rec = json.loads(line)
            if use_modes and rec.get("mode") not in use_modes:
                continue
            self.recs.append(rec)

        if not self.recs:
            raise RuntimeError(f"No records in {self.jsonl} for modes={use_modes}")

        self.in_ch: Optional[int] = None 

    def __len__(self) -> int:
        return len(self.recs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r = self.recs[idx]
        hdir = self.split_root / r["history_dir"]
        qdir = self.split_root / r["query_dir"]

        # If an error occurs during loading, it will be caught by the try-except in _load_grid_seq and the path will be displayed
        gs_j = _load_grid_seq(hdir)     
        gs_k = _load_grid_seq(qdir)     
        
        Tj, Tk = gs_j.shape[0], gs_k.shape[0]
        tm_j = _make_tmask(Tj)          
        tm_k = _make_tmask(Tk)          

        choices = r["choices_cell"]    
        N_max = max(len(x) for x in choices)

        pad_choices: List[List[int]] = []
        pad_masks:   List[List[bool]] = []

        for arr in choices:
            m = [True] * len(arr) + [False] * (N_max - len(arr))   
            a = arr + [0] * (N_max - len(arr))                     
            pad_choices.append(a)
            pad_masks.append(m)

        item = {
            "grid_seq_j": gs_j,
            "tmask_j": tm_j,
            "grid_seq_k": gs_k,
            "tmask_k": tm_k,
            "choices_ids": np.asarray(pad_choices, dtype=np.int64),  
            "choices_mask": np.asarray(pad_masks, dtype=bool),       
            "label_idx": int(r["label_idx"]),
        }

        if self.in_ch is None:
            self.in_ch = int(gs_j.shape[-1])

        return item

In [None]:
# C_fixed_collate.py
from typing import List, Dict, Any, Tuple

import numpy as np
import torch


def _pad_time(batch_np: List[np.ndarray]) -> torch.Tensor:
    """
    batch_np: list of (T, H, W, C_i)，C_i can be different
    Pads to (B, T_max, H, W, C_max), with smaller C
    """
    B = len(batch_np)
    T_max = max(x.shape[0] for x in batch_np)
    H, W = batch_np[0].shape[1], batch_np[0].shape[2]
    C_max = max(x.shape[3] for x in batch_np)

    out = np.zeros((B, T_max, H, W, C_max), dtype=np.float32)
    for i, x in enumerate(batch_np):
        t, h, w, c = x.shape
        out[i, :t, :h, :w, :c] = x  # fewer channels naturally only overwrite the first c channels

    return torch.from_numpy(out)

def _pad_mask(batch_mask: List[np.ndarray]) -> torch.Tensor:
    """
    [T_i] -> [B,T_max] bool
    """
    T_max = max(x.shape[0] for x in batch_mask)
    B = len(batch_mask)
    out = np.zeros((B, T_max), dtype=bool)
    for i, m in enumerate(batch_mask):
        t = m.shape[0]
        out[i, :t] = m
    return torch.from_numpy(out)


def collate_tomnet2(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    # τ_j
    gs_j = _pad_time([b["grid_seq_j"] for b in batch])
    tm_j = _pad_mask([b["tmask_j"] for b in batch])

    # τ_k
    gs_k = _pad_time([b["grid_seq_k"] for b in batch])
    tm_k = _pad_mask([b["tmask_k"] for b in batch])

    choices_ids = torch.stack([torch.from_numpy(b["choices_ids"]) for b in batch], dim=0)
    choices_mask = torch.stack([torch.from_numpy(b["choices_mask"]) for b in batch], dim=0)
    
    labels = torch.tensor([b["label_idx"] for b in batch], dtype=torch.long)

    return dict(
        grid_seq_j=gs_j,
        tmask_j=tm_j,
        grid_seq_k=gs_k,
        tmask_k=tm_k,
        choices_ids=choices_ids,
        choices_mask=choices_mask,
        labels=labels,
    )

In [None]:
# D_build_tomnet2_pairs.py
import json
import math
import random
import re
from pathlib import Path
from typing import List, Optional, Tuple, Set

import numpy as np


# split： training / validation / testing
SPLITS = ["training", "validation", "testing"]

MODE_DIRS = {
    "logic":              "logic",
    "rulemap":            "rulemap",
    "random":             "random",
    "intermediate_case1": "intermediate_case1",
    "intermediate_case2": "intermediate_case2",
}


RNG_SEED = 1234

NUM_DISTRACTORS = 3
MIN_CENTROID_L2_FROM_GOLD = 5.0
MAX_JACCARD_OVERLAP = 0.0
MAX_RETRY_PER_DISTRACTOR = 500

_num_pat = re.compile(r"-?\d+(?:\.\d+)?")

# ---------- tools---------

def _step_id(p: Path) -> Optional[int]:
    m = re.search(r"step_plot_\w+_(\d+)$", p.name)
    return int(m.group(1)) if m else None


def _read_trail(trail_txt: Path) -> np.ndarray:
    text = trail_txt.read_text(encoding="utf-8")
    nums = list(map(float, _num_pat.findall(text)))
    if len(nums) == 0 or len(nums) % 2 != 0:
        raise ValueError(f"{trail_txt} cannot be parsed into (x,y) sequence")
    arr = np.asarray(nums, dtype=float).reshape(-1, 2)

    # Simple check: if world coordinates (e.g. 7..18), shift back to 1..12
    mn, mx = float(arr.min()), float(arr.max())
    if mn >= 6.5: 
        arr = arr - 6.0
        
    arr = np.clip(arr, 1.0, float(W))
    return arr


def _infer_n_agents_from_len(K: int) -> int:
    """
    For the 3_12 dataset, only 3 agents are allowed.
    """
    if K % N_AGENTS != 0:
        raise ValueError(f"trail length {K} is not divisible by N_AGENTS={N_AGENTS}")
    return N_AGENTS


def _coord_to_cid(x: float, y: float, W: int = 12) -> int: # 修改 W 預設值
    r = int(round(y)) - 1
    c = int(round(x)) - 1
    r = min(max(r, 0), W - 1)
    c = min(max(c, 0), W - 1)
    return r * W + c


def _xy_to_cells(arr: np.ndarray) -> List[int]:
    return [_coord_to_cid(float(x), float(y), W=W) for x, y in arr]


def _centroid(arr: np.ndarray) -> Tuple[float, float]:
    return float(arr[:, 0].mean()), float(arr[:, 1].mean())


def _l2(p: Tuple[float, float], q: Tuple[float, float]) -> float:
    return math.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2)


def _jaccard(a: Set[int], b: Set[int]) -> float:
    if not a and not b:
        return 1.0
    return len(a & b) / max(1, len(a | b))


def _sample_cluster_far_from_gold(
    N: int,
    gold_cells: Set[int],
    gold_centroid: Tuple[float, float],
    min_centroid_l2: float,
    max_jaccard_overlap: float,
    max_retry: int,
) -> np.ndarray:
    for _ in range(max_retry):
        cells = random.sample(range(W * H), k=N)
        if len(set(cells)) != N:
            continue
        if _jaccard(set(cells), gold_cells) > max_jaccard_overlap:
            continue

        xy = []
        for cid in cells:
            r, c = divmod(cid, W)
            xy.append([c + 1, r + 1])   # (x,y)
        arr = np.asarray(xy, dtype=float)

        if _l2(_centroid(arr), gold_centroid) < min_centroid_l2:
            continue

        return arr

    raise RuntimeError(
        "Cant satisfy the distance / overlap constraints after MAX_RETRY attempts"
    )


def _mode_from_dir(step_dir: Path) -> str:
    parts = [p.name.lower() for p in step_dir.parents] + [step_dir.name.lower()]
    if "logic" in parts: return "logic"
    if "rulemap" in parts: return "rulemap"
    if "random" in parts: return "random"
    if "intermediate_case1" in parts: return "intermediate_case1"
    if "intermediate_case2" in parts: return "intermediate_case2"
    return "unknown"


def _list_odd_even(split_root: Path):
    odd, even = [], []
    for sub in MODE_DIRS.values():
        base = split_root / sub
        if not base.exists():
            continue
        for p in base.glob("step_plot_*"):
            if not p.is_dir():
                continue
            sid = _step_id(p)
            if sid is None:
                continue
            (odd if sid % 2 == 1 else even).append(p)
    odd.sort(key=lambda x: _step_id(x))
    even.sort(key=lambda x: _step_id(x))
    return odd, even

# ---------- build one split ----------

def build_one_split(split_root: Path) -> int:
    split_root = split_root.resolve()
    print(f"[INFO] build split at {split_root}")

    random.seed(RNG_SEED)
    np.random.seed(RNG_SEED)

    out_jsonl = split_root / "tomnet2_format.jsonl"
    if out_jsonl.exists():
        out_jsonl.unlink() 

    out_jsonl.parent.mkdir(parents=True, exist_ok=True)

    odd_dirs, even_dirs = _list_odd_even(split_root)

    if not even_dirs:
        print(f"[WARN] {split_root}: no even steps, skipping")
        return 0
    if not odd_dirs:
        print(f"[WARN] {split_root}: no odd steps (history), skipping")
        return 0

    total = 0
    with out_jsonl.open("w", encoding="utf-8") as fout:
        for qdir in even_dirs:
            trail = qdir / "trail.txt"
            if not trail.exists():
                continue

            try:
                xy = _read_trail(trail)
                N = _infer_n_agents_from_len(xy.shape[0]) 
                T = xy.shape[0] // N
                gold = xy[-N:, :].copy()
                gold_cells = set(_xy_to_cells(gold))
                gold_cent = _centroid(gold)
            except Exception as e:
                print(f"[skip] {qdir.name}: {e}")
                continue

            distractors: List[np.ndarray] = []
            seen_keys = {tuple(map(int, gold.flatten()))}

            for _ in range(NUM_DISTRACTORS):
                try:
                    arr = _sample_cluster_far_from_gold(
                        N=N,
                        gold_cells=gold_cells,
                        gold_centroid=gold_cent,
                        min_centroid_l2=MIN_CENTROID_L2_FROM_GOLD,
                        max_jaccard_overlap=MAX_JACCARD_OVERLAP,
                        max_retry=MAX_RETRY_PER_DISTRACTOR,
                    )
                except RuntimeError:
                    break
                k = tuple(map(int, arr.flatten()))
                if k in seen_keys:
                    continue
                seen_keys.add(k)
                distractors.append(arr)

            if len(distractors) < NUM_DISTRACTORS:
                need = NUM_DISTRACTORS - len(distractors)
                for _ in range(need):
                    try:
                        arr = _sample_cluster_far_from_gold(
                            N=N,
                            gold_cells=gold_cells,
                            gold_centroid=gold_cent,
                            min_centroid_l2=max(1.0, MIN_CENTROID_L2_FROM_GOLD * 0.6),
                            max_jaccard_overlap=MAX_JACCARD_OVERLAP,
                            max_retry=MAX_RETRY_PER_DISTRACTOR * 2,
                        )
                    except RuntimeError:
                        break
                    k = tuple(map(int, arr.flatten()))
                    if k in seen_keys:
                        continue
                    seen_keys.add(k)
                    distractors.append(arr)

            if len(distractors) != NUM_DISTRACTORS:
                print(f"[skip] {qdir.name}: insufficient distractors, skipping")
                continue

            all_choices = [gold] + distractors
            cell_choices = [[int(c) for c in _xy_to_cells(a)] for a in all_choices]

            order = list(range(4))
            random.shuffle(order)
            shuffled = [cell_choices[i] for i in order]
            label_idx = order.index(0)

            hdir = random.choice(odd_dirs)

            rec = {
                "history_dir": str(hdir.relative_to(split_root)),
                "query_dir":   str(qdir.relative_to(split_root)),
                "mode":        _mode_from_dir(qdir),
                "n_agents":    N,
                "T":           T,
                "choices_cell": shuffled,
                "label_idx":   label_idx,
                "label_letter": ["A","B","C","D"][label_idx],
            }
            fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
            total += 1

    print(f"✓ {split_root}: wrote {total} pairs -> {out_jsonl.name}")
    return total

In [None]:

BASE_3_12 = Path("/Users/Jer_ry/Desktop/script_tom/data/3_12")

def main() -> None:
    base = BASE_3_12.resolve()
    print(f"[INFO] use base = {base}")

    for name in SPLITS:
        p = base / name
        print(f"[DEBUG] check split dir: {p} -> exists={p.is_dir()}")
        if not p.is_dir():
            print(f"[WARN] split {name} does not exist: {p}")

    tot = 0
    for sp in SPLITS:
        split_root = base / sp      # e.g. /.../3_12/training
        if not split_root.exists():
            print(f"[INFO] skip {split_root} (not found)")
            continue
        tot += build_one_split(split_root)

    print(f"Done. total={tot}")


if __name__ == "__main__":
    main()


[INFO] use base = /Users/Jer_ry/Desktop/script_tom/data/3_12
[DEBUG] check split dir: /Users/Jer_ry/Desktop/script_tom/data/3_12/training -> exists=True
[DEBUG] check split dir: /Users/Jer_ry/Desktop/script_tom/data/3_12/validation -> exists=True
[DEBUG] check split dir: /Users/Jer_ry/Desktop/script_tom/data/3_12/testing -> exists=True
[INFO] build split at /Users/Jer_ry/Desktop/script_tom/data/3_12/training
✓ /Users/Jer_ry/Desktop/script_tom/data/3_12/training: wrote 4000 pairs -> tomnet2_format.jsonl
[INFO] build split at /Users/Jer_ry/Desktop/script_tom/data/3_12/validation
✓ /Users/Jer_ry/Desktop/script_tom/data/3_12/validation: wrote 500 pairs -> tomnet2_format.jsonl
[INFO] build split at /Users/Jer_ry/Desktop/script_tom/data/3_12/testing
✓ /Users/Jer_ry/Desktop/script_tom/data/3_12/testing: wrote 500 pairs -> tomnet2_format.jsonl
Done. total=5000


In [None]:
# E_fixed_train.py
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

_MODE_ALIASES = {
    "rulemap": "rulemap",
    "random": "random",
    "logic": "logic",
    "intermediate_case1": "intermediate_case1",
    "intermediate_case2": "intermediate_case2",
    "all": "all",
}

_ALL_MODES: List[str] = [
    "rulemap",
    "random",
    "intermediate_case1",
    "intermediate_case2",
    "logic",
]


@dataclass
class TrainConfig:
    data_root: Path = Path(".")    
    sim_name: str = "sim5_24"
    mode: str = "all"              

    d_model: int = 256
    nhead: int = 8
    n_layers: int = 2
    stem_ch: int = 32
    n_blocks: int = 4
    e_char_dim: int = 64
    dropout: float = 0.1

    batch_size: int = 32
    lr: float = 3e-4
    weight_decay: float = 1e-2
    max_epochs: int = 30

    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    save_dir: Path = Path("../checkpoints")

    def canonical_mode(self) -> str:
        key = self.mode.strip().lower()
        if key not in _MODE_ALIASES:
            raise ValueError(f"Unknown mode: {self.mode}")
        return _MODE_ALIASES[key]


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _probe_in_ch(ds: ToMNet2JsonlDataset) -> int:
    one = ds[0]
    return int(one["grid_seq_j"].shape[-1])


def make_loader(cfg: TrainConfig, split: str, mode: str) -> Tuple[DataLoader, int]:
    split_map = {
        "train": "training",
        "val":   "validation",
        "test":  "testing",
    }
    split_dir = cfg.data_root / split_map[split]
    use_modes = None if mode == "all" else [mode]

    ds = ToMNet2JsonlDataset(
        split_root=split_dir,
        sim_name=cfg.sim_name,
        use_modes=use_modes,
        seed=cfg.seed,
    )
    in_ch = ds.in_ch if ds.in_ch is not None else _probe_in_ch(ds)

    loader = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=(split == "train"),
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_tomnet2,
    )
    return loader, in_ch


@torch.no_grad()
def evaluate(
    model: nn.Module, loader: DataLoader, device: str
) -> Tuple[float, float]:
    model.eval()
    n = 0
    corr = 0
    tot_loss = 0.0

    for batch in loader:
        x = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        logits = model(
            x["grid_seq_j"],
            x["tmask_j"],
            x["grid_seq_k"],
            x["tmask_k"],
            x["choices_ids"],
            x["choices_mask"],
        )
        loss = F.cross_entropy(logits, x["labels"])

        pred = logits.argmax(dim=-1)
        corr += (pred == x["labels"]).sum().item()
        tot_loss += loss.item() * logits.size(0)
        n += logits.size(0)

    acc = corr / max(n, 1)
    avg_loss = tot_loss / max(n, 1)
    return acc, avg_loss


def train_one_mode(cfg: TrainConfig, mode: str):
    print(f"\n=== Training mode = {mode} ===")

    
    train_ld, in_ch = make_loader(cfg, "train", mode)
    val_ld, _ = make_loader(cfg, "val", mode)
    test_ld, _ = make_loader(cfg, "test", mode)

    model = ToMNet(
        in_ch=in_ch,
        d_model=cfg.d_model,
        nhead=cfg.nhead,
        n_layers=cfg.n_layers,
        stem_ch=cfg.stem_ch,
        n_blocks=cfg.n_blocks,
        e_char_dim=cfg.e_char_dim,
        dropout=cfg.dropout,
    ).to(cfg.device)

    opt = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=cfg.max_epochs
    )

    best_va = 0.0
    best_state = None

    for ep in range(1, cfg.max_epochs + 1):
        model.train()
        n = 0
        corr = 0
        tot_loss = 0.0

        for batch in train_ld:
            x = {
                k: (v.to(cfg.device) if isinstance(v, torch.Tensor) else v)
                for k, v in batch.items()
            }
            logits = model(
                x["grid_seq_j"],
                x["tmask_j"],
                x["grid_seq_k"],
                x["tmask_k"],
                x["choices_ids"],
                x["choices_mask"],
            )
            loss = F.cross_entropy(logits, x["labels"])

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            pred = logits.argmax(dim=-1)
            corr += (pred == x["labels"]).sum().item()
            tot_loss += loss.item() * logits.size(0)
            n += logits.size(0)

        tr_acc = corr / max(n, 1)
        tr_loss = tot_loss / max(n, 1)
        va_acc, va_loss = evaluate(model, val_ld, cfg.device)
        sch.step()

        print(
            f"[{mode}][Ep {ep:02d}] "
            f"train acc={tr_acc:.3f} loss={tr_loss:.3f} | "
            f"val acc={va_acc:.3f} loss={va_loss:.3f}"
        )

        if va_acc > best_va:
            best_va = va_acc
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    te_acc, te_loss = evaluate(model, test_ld, cfg.device)
    print(f"[{mode}][TEST] acc={te_acc:.3f} loss={te_loss:.3f}")

    return model, {"val_acc": best_va, "test_acc": te_acc}


_ALL_MODES = ["random", "rulemap", "logic", "intermediate_case1", "intermediate_case2"]

def _model_name(sim_name: str, mode: str, kind: str = "ToMNet") -> str:
    return f"{kind}_{sim_name}_{mode}"

def _save_model(model: nn.Module, cfg: TrainConfig, mode: str, kind: str = "ToMNet") -> Path:
    cfg.save_dir.mkdir(parents=True, exist_ok=True)
    name = _model_name(cfg.sim_name, mode, kind)
    path = cfg.save_dir / f"{name}.pt"
    torch.save(model.state_dict(), path)
    print(f"[SAVE] {name} -> {path}")
    return path

def run(cfg: TrainConfig, kind: str = "ToMNet") -> Dict[str, nn.Module]:
    models: Dict[str, nn.Module] = {}

    if cfg.mode == "all":
        modes = _ALL_MODES
    else:
        modes = [cfg.mode]

    for mode in modes:
        print(f"[TRAIN] sim={cfg.sim_name} mode={mode}")
        model, best_acc = train_one_mode(cfg, mode)

        name = _model_name(cfg.sim_name, mode, kind)
        _save_model(model, cfg, mode, kind=kind)
        models[name] = model
        print(f"[DONE] {name} (best_acc={best_acc})")

    return models


In [None]:
# F_run_gridizer.py
from pathlib import Path
import numpy as np
import sys

sys.path.append("..") 

try:
    from scripts.gridizer_multitrack import coords_to_grid_seq_multitrack
except ImportError as e:
    raise ImportError(f"can not find gridizer_multitrack.py, please check {e}")


DATA_ROOT = Path("../data/3_12") 
N_AGENTS = 3
MAZE_SIZE = 12

def process_missing_npy():
    print(f"Scanning {DATA_ROOT} for missing .npy files...")
    count = 0
    for trail_txt in DATA_ROOT.rglob("trail.txt"):
        step_dir = trail_txt.parent
        npy_path = step_dir / "grid_seq.npy"
        
        # If .npy does not exist, generate it
        if not npy_path.exists():
            try:
                # Call the standard gridizer
                # Note: This assumes the coordinate format in trail.txt is compatible with the gridizer 
                grid = coords_to_grid_seq_multitrack(
                    trail_txt_path=trail_txt,
                    n_agents=N_AGENTS,
                    H=MAZE_SIZE,
                    W=MAZE_SIZE,
                    keep_init_every_frame=True
                )
                
                np.save(npy_path, grid)
                count += 1
                if count % 100 == 0:
                    print(f"Generated {count} npy files...")
                    
            except Exception as e:
                print(f"[Error] Failed to gridize {step_dir}: {e}")

    if count == 0:
        print("All looks good! No missing .npy files.")
    else:
        print(f"Done. Generated {count} missing .npy files.")

if __name__ == "__main__":
    process_missing_npy()

Scanning ../data/3_12 for missing .npy files...
All looks good! No missing .npy files.


## Train

In [29]:
cfg = TrainConfig(
    data_root=Path("../data/3_12"),
    sim_name="3_12",
    mode="random",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)

models = run(cfg, kind="ToMNet2")
TomNet_Random_Model = models["ToMNet2_3_12_random"]

[TRAIN] sim=3_12 mode=random

=== Training mode = random ===




[random][Ep 01] train acc=0.280 loss=1.386 | val acc=0.470 loss=1.320
[random][Ep 02] train acc=0.398 loss=1.327 | val acc=0.460 loss=1.281
[random][Ep 03] train acc=0.444 loss=1.278 | val acc=0.450 loss=1.236
[random][Ep 04] train acc=0.455 loss=1.243 | val acc=0.480 loss=1.191
[random][Ep 05] train acc=0.521 loss=1.190 | val acc=0.520 loss=1.161
[random][Ep 06] train acc=0.535 loss=1.152 | val acc=0.540 loss=1.141
[random][Ep 07] train acc=0.554 loss=1.118 | val acc=0.560 loss=1.114
[random][Ep 08] train acc=0.562 loss=1.087 | val acc=0.540 loss=1.106
[random][Ep 09] train acc=0.568 loss=1.067 | val acc=0.590 loss=1.099
[random][Ep 10] train acc=0.588 loss=1.036 | val acc=0.600 loss=1.083
[random][Ep 11] train acc=0.611 loss=1.009 | val acc=0.550 loss=1.100
[random][Ep 12] train acc=0.620 loss=0.967 | val acc=0.560 loss=1.091
[random][Ep 13] train acc=0.645 loss=0.932 | val acc=0.580 loss=1.095
[random][Ep 14] train acc=0.647 loss=0.922 | val acc=0.600 loss=1.066
[random][Ep 15] trai

In [34]:
cfg = TrainConfig(
    data_root=Path("../data/3_12"),
    sim_name="3_12",
    mode="rulemap",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)

models = run(cfg, kind="ToMNet2")
TomNet_Rulemap_Model = models["ToMNet2_3_12_rulemap"]

[TRAIN] sim=3_12 mode=rulemap

=== Training mode = rulemap ===
[rulemap][Ep 01] train acc=0.676 loss=1.060 | val acc=0.860 loss=0.759
[rulemap][Ep 02] train acc=0.839 loss=0.620 | val acc=0.900 loss=0.403
[rulemap][Ep 03] train acc=0.882 loss=0.396 | val acc=0.920 loss=0.268
[rulemap][Ep 04] train acc=0.894 loss=0.311 | val acc=0.940 loss=0.199
[rulemap][Ep 05] train acc=0.917 loss=0.246 | val acc=0.960 loss=0.147
[rulemap][Ep 06] train acc=0.938 loss=0.196 | val acc=0.970 loss=0.118
[rulemap][Ep 07] train acc=0.951 loss=0.153 | val acc=0.970 loss=0.114
[rulemap][Ep 08] train acc=0.955 loss=0.136 | val acc=0.970 loss=0.108
[rulemap][Ep 09] train acc=0.963 loss=0.117 | val acc=0.960 loss=0.109
[rulemap][Ep 10] train acc=0.979 loss=0.093 | val acc=0.970 loss=0.082
[rulemap][Ep 11] train acc=0.983 loss=0.077 | val acc=0.970 loss=0.077
[rulemap][Ep 12] train acc=0.984 loss=0.065 | val acc=0.970 loss=0.056
[rulemap][Ep 13] train acc=0.986 loss=0.056 | val acc=0.970 loss=0.065
[rulemap][Ep 1

In [35]:
cfg = TrainConfig(
    data_root=Path("../data/3_12"),
    sim_name="3_12",
    mode="logic",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)

models = run(cfg, kind="ToMNet2")
TomNet_Logic_Model = models["ToMNet2_3_12_logic"]

[TRAIN] sim=3_12 mode=logic

=== Training mode = logic ===




[logic][Ep 01] train acc=0.644 loss=1.129 | val acc=0.850 loss=0.849
[logic][Ep 02] train acc=0.877 loss=0.677 | val acc=0.890 loss=0.433
[logic][Ep 03] train acc=0.922 loss=0.380 | val acc=0.970 loss=0.199
[logic][Ep 04] train acc=0.950 loss=0.206 | val acc=0.980 loss=0.091
[logic][Ep 05] train acc=0.961 loss=0.146 | val acc=0.990 loss=0.058
[logic][Ep 06] train acc=0.965 loss=0.111 | val acc=0.980 loss=0.049
[logic][Ep 07] train acc=0.976 loss=0.087 | val acc=0.970 loss=0.043
[logic][Ep 08] train acc=0.975 loss=0.079 | val acc=0.990 loss=0.036
[logic][Ep 09] train acc=0.980 loss=0.063 | val acc=0.990 loss=0.033
[logic][Ep 10] train acc=0.984 loss=0.052 | val acc=0.990 loss=0.024
[logic][Ep 11] train acc=0.986 loss=0.046 | val acc=1.000 loss=0.023
[logic][Ep 12] train acc=0.993 loss=0.037 | val acc=0.990 loss=0.024
[logic][Ep 13] train acc=0.994 loss=0.034 | val acc=0.990 loss=0.021
[logic][Ep 14] train acc=0.998 loss=0.024 | val acc=1.000 loss=0.015
[logic][Ep 15] train acc=0.993 los

In [36]:
cfg = TrainConfig(
    data_root=Path("../data/3_12"),
    sim_name="3_12",
    mode="intermediate_case1",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)   
models = run(cfg, kind="ToMNet2")
TomNet_Intermediate1_Model = models["ToMNet2_3_12_intermediate_case1"]

[TRAIN] sim=3_12 mode=intermediate_case1

=== Training mode = intermediate_case1 ===
[intermediate_case1][Ep 01] train acc=0.251 loss=1.392 | val acc=0.310 loss=1.353
[intermediate_case1][Ep 02] train acc=0.351 loss=1.342 | val acc=0.410 loss=1.323
[intermediate_case1][Ep 03] train acc=0.409 loss=1.309 | val acc=0.430 loss=1.313
[intermediate_case1][Ep 04] train acc=0.409 loss=1.284 | val acc=0.420 loss=1.314
[intermediate_case1][Ep 05] train acc=0.425 loss=1.259 | val acc=0.430 loss=1.311
[intermediate_case1][Ep 06] train acc=0.451 loss=1.240 | val acc=0.420 loss=1.308
[intermediate_case1][Ep 07] train acc=0.481 loss=1.218 | val acc=0.440 loss=1.309
[intermediate_case1][Ep 08] train acc=0.499 loss=1.186 | val acc=0.460 loss=1.301
[intermediate_case1][Ep 09] train acc=0.515 loss=1.167 | val acc=0.430 loss=1.298
[intermediate_case1][Ep 10] train acc=0.517 loss=1.145 | val acc=0.470 loss=1.298
[intermediate_case1][Ep 11] train acc=0.532 loss=1.129 | val acc=0.430 loss=1.302
[intermediate

In [37]:
cfg = TrainConfig(
    data_root=Path("../data/3_12"),
    sim_name="3_12",
    mode="intermediate_case2",
    seed=42,
    batch_size=64,
    lr=1e-4,
    dropout=0.2,
    max_epochs=80,
)   
models = run(cfg, kind="ToMNet2")
TomNet_Intermediate2_Model = models["ToMNet2_3_12_intermediate_case2"]

[TRAIN] sim=3_12 mode=intermediate_case2

=== Training mode = intermediate_case2 ===
[intermediate_case2][Ep 01] train acc=0.279 loss=1.379 | val acc=0.510 loss=1.317
[intermediate_case2][Ep 02] train acc=0.438 loss=1.301 | val acc=0.540 loss=1.254
[intermediate_case2][Ep 03] train acc=0.484 loss=1.248 | val acc=0.530 loss=1.213
[intermediate_case2][Ep 04] train acc=0.501 loss=1.191 | val acc=0.510 loss=1.185
[intermediate_case2][Ep 05] train acc=0.524 loss=1.139 | val acc=0.530 loss=1.150
[intermediate_case2][Ep 06] train acc=0.542 loss=1.112 | val acc=0.520 loss=1.135
[intermediate_case2][Ep 07] train acc=0.569 loss=1.081 | val acc=0.530 loss=1.123
[intermediate_case2][Ep 08] train acc=0.575 loss=1.057 | val acc=0.540 loss=1.109
[intermediate_case2][Ep 09] train acc=0.595 loss=1.042 | val acc=0.540 loss=1.111
[intermediate_case2][Ep 10] train acc=0.603 loss=0.984 | val acc=0.540 loss=1.099
[intermediate_case2][Ep 11] train acc=0.608 loss=0.965 | val acc=0.530 loss=1.075
[intermediate

## Evaluation

In [None]:
from model_evaluate import evaluate_model_performance

def forward_tomnet2(m: nn.Module, b: Dict[str, torch.Tensor]) -> torch.Tensor:
    """ToMNet2"""
    return m(
        b["grid_seq_j"],
        b["tmask_j"], 
        b["grid_seq_k"],
        b["tmask_k"],
        b["choices_ids"],
        b["choices_mask"]
    )


In [None]:

val_loader, in_ch = make_loader(cfg, split="val", mode="random")


dataset = ToMNet2JsonlDataset(
    split_root=cfg.data_root / "validation", 
    sim_name=cfg.sim_name,
    use_modes=["random"],
    seed=cfg.seed,
)
one_sample = dataset[0]["grid_seq_j"] 
input_hw = one_sample.shape[1:3]  

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet2,
    energy_forward_fn=forward_tomnet2
)

print("=== ToMNet2_3_12_random model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")



=== ToMNet2_3_12_random model evaluation ===
accuracy: 0.5300
precision: 0.5299
recall rate: 0.5300
F1 score: 0.5293
ROC AUC: 0.7805
R2 Score: -0.5215
Prediction Entropy: 0.4337

=== Energy and Efficiency Evaluation ===
MACs per sample: 155,808,768
Energy per sample: 7.17e-04 J
Energy per accuracy unit: 1.35e-03 J/acc
MACs per accuracy unit: 293,978,808 MACs/acc


In [38]:


val_loader, in_ch = make_loader(cfg, split="val", mode="rulemap")

dataset = ToMNet2JsonlDataset(
    split_root=cfg.data_root / "validation", 
    sim_name=cfg.sim_name,
    use_modes=["rulemap"],
    seed=cfg.seed,
)
one_sample = dataset[0]["grid_seq_j"]  # 使用 grid_seq_j 而不是 grid_seq
input_hw = one_sample.shape[1:3]  # 获取 H, W

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet2,
    energy_forward_fn=forward_tomnet2
)

print("=== ToMNet2_3_12_rulemap model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNet2_3_12_rulemap model evaluation ===
accuracy: 0.1900
precision: 0.1922
recall rate: 0.1900
F1 score: 0.1866
ROC AUC: 0.3764
R2 Score: -1.2249
Prediction Entropy: 0.5308

=== Energy and Efficiency Evaluation ===
MACs per sample: 155,628,544
Energy per sample: 7.16e-04 J
Energy per accuracy unit: 3.77e-03 J/acc
MACs per accuracy unit: 819,097,600 MACs/acc




In [39]:

val_loader, in_ch = make_loader(cfg, split="val", mode="logic")

dataset = ToMNet2JsonlDataset(
    split_root=cfg.data_root / "validation", 
    sim_name=cfg.sim_name,
    use_modes=["logic"],
    seed=cfg.seed,
)
one_sample = dataset[0]["grid_seq_j"] 
input_hw = one_sample.shape[1:3]  

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet2,
    energy_forward_fn=forward_tomnet2
)

print("=== ToMNet2_3_12_logic model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNet2_3_12_logic model evaluation ===
accuracy: 0.1700
precision: 0.1629
recall rate: 0.1700
F1 score: 0.1650
ROC AUC: 0.4153
R2 Score: -1.5187
Prediction Entropy: 0.5656

=== Energy and Efficiency Evaluation ===
MACs per sample: 155,776,000
Energy per sample: 7.17e-04 J
Energy per accuracy unit: 4.22e-03 J/acc
MACs per accuracy unit: 916,329,412 MACs/acc




In [40]:

val_loader, in_ch = make_loader(cfg, split="val", mode="intermediate_case1")

dataset = ToMNet2JsonlDataset(
    split_root=cfg.data_root / "validation", 
    sim_name=cfg.sim_name,
    use_modes=["intermediate_case1"],
    seed=cfg.seed,
)
one_sample = dataset[0]["grid_seq_j"]  
input_hw = one_sample.shape[1:3]

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet2,
    energy_forward_fn=forward_tomnet2
)

print("=== ToMNet2_3_12_intermediate_case1 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNet2_3_12_intermediate_case1 model evaluation ===
accuracy: 0.5300
precision: 0.5513
recall rate: 0.5300
F1 score: 0.5341
ROC AUC: 0.7775
R2 Score: -0.2288
Prediction Entropy: 0.4121

=== Energy and Efficiency Evaluation ===
MACs per sample: 155,972,608
Energy per sample: 7.17e-04 J
Energy per accuracy unit: 1.35e-03 J/acc
MACs per accuracy unit: 294,287,940 MACs/acc




In [41]:

# 使用 ToMNet2JsonlDataset 而不是 GridSeqDataset
val_loader, in_ch = make_loader(cfg, split="val", mode="intermediate_case2")

# 获取输入尺寸 - 从数据集中获取一个样本
dataset = ToMNet2JsonlDataset(
    split_root=cfg.data_root / "validation",  # 验证集路径
    sim_name=cfg.sim_name,
    use_modes=["intermediate_case2"],
    seed=cfg.seed,
)
one_sample = dataset[0]["grid_seq_j"]  # 使用 grid_seq_j 而不是 grid_seq
input_hw = one_sample.shape[1:3]  # 获取 H, W

metrics = evaluate_model_performance(
    TomNet_Random_Model,
    val_loader,
    device="cpu",
    forward_fn=forward_tomnet2,
    energy_forward_fn=forward_tomnet2
)

print("=== ToMNet2_3_12_intermediate_case2 model evaluation ===")
print(f"accuracy: {metrics.accuracy:.4f}")
print(f"precision: {metrics.precision:.4f}")
print(f"recall rate: {metrics.recall:.4f}")
print(f"F1 score: {metrics.f1_score:.4f}")
print(f"ROC AUC: {metrics.roc_auc:.4f}")
print(f"R2 Score: {metrics.r2_score:.4f}")
print(f"Prediction Entropy: {metrics.prediction_entropy:.4f}")
print("\n=== Energy and Efficiency Evaluation ===")
print(f"MACs per sample: {metrics.energy_report.macs_per_sample:,.0f}")
print(f"Energy per sample: {metrics.energy_report.ann_energy_per_sample:.2e} J")
print(f"Energy per accuracy unit: {metrics.energy_per_accuracy:.2e} J/acc")
print(f"MACs per accuracy unit: {metrics.macs_per_accuracy:,.0f} MACs/acc")

=== ToMNet2_3_12_intermediate_case2 model evaluation ===
accuracy: 0.6100
precision: 0.6146
recall rate: 0.6100
F1 score: 0.6112
ROC AUC: 0.8442
R2 Score: -0.0007
Prediction Entropy: 0.4194

=== Energy and Efficiency Evaluation ===
MACs per sample: 155,939,840
Energy per sample: 7.17e-04 J
Energy per accuracy unit: 1.18e-03 J/acc
MACs per accuracy unit: 255,639,082 MACs/acc


