In [None]:
from __future__ import annotations

from typing import Dict, List, Literal, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate, utils as snn_utils

In [None]:
def id_to_rc(idx, H=24, W=24):
    """Convert flat (row-major) cell ids to (row, col) indices on an H x W grid.

    Out-of-range ids are clamped onto the grid instead of being rejected.
    NOTE(review): with the default 24x24 grid this maps special tokens such as
    576 -> (23, 0) and 577 -> (23, 1), i.e. onto real cells; callers should
    filter pad/CLS ids out before encoding.
    """
    r = (idx // W).clamp(min=0, max=H - 1)
    c = (idx % W).clamp(min=0, max=W - 1)
    return r, c


def triple_to_map3(ids_3, H=24, W=24):
    """One-hot encode a batch of cell-id triples, one channel per triple element.

    Args:
        ids_3: [B, 3] integer cell ids.
        H, W: grid size.

    Returns:
        float tensor [B, 3, H, W] (on ``ids_3.device``) with exactly one 1.0 per
        (sample, channel), at the (clamped) cell of that id.
    """
    B = ids_3.shape[0]
    maps = torch.zeros(B, 3, H, W, device=ids_3.device)
    r, c = id_to_rc(ids_3[:, :3], H, W)                        # each [B, 3]
    # Build the index tensors on the same device as ``maps`` and scatter all three
    # channels in a single advanced-indexing assignment.
    b_idx = torch.arange(B, device=ids_3.device).unsqueeze(1)  # [B, 1]
    a_idx = torch.arange(3, device=ids_3.device).unsqueeze(0)  # [1, 3]
    maps[b_idx, a_idx, r, c] = 1.0
    return maps
def prefix_to_spike_seq_batched(prefix_ids, attn_mask, pad_id=576, cls_id=577, H=24, W=24):
    """Turn a padded batch of prefix token ids into a time-major spike sequence.

    For each sample, the valid tokens (``attn_mask`` True, with a leading CLS token
    dropped) are read as consecutive triples; time step t is the 3-channel one-hot
    map of triple t (see ``triple_to_map3``). Samples are zero-padded in time to
    the longest one.

    Args:
        prefix_ids: [B, L] long token ids.
        attn_mask:  [B, L] bool mask, True on real tokens.
        pad_id: not used here (padding is identified through ``attn_mask``);
            kept for API compatibility.
        cls_id: id of an optional leading CLS token, which is skipped.
        H, W: grid size.

    Returns:
        spike_seq: [T_max, B, 3, H, W] float, zero after each sample's last step.
        T_lens:    [B] long, number of valid steps per sample. A sample without
                   tokens is given one all-zero step, so every entry is >= 1.

    Raises:
        ValueError: if a sample's token count (after dropping CLS) is not a
            multiple of 3.
    """
    device = prefix_ids.device
    B = prefix_ids.shape[0]
    seq_per_b = []

    for b in range(B):
        keep_b = attn_mask[b].clone()
        if prefix_ids[b, 0] == cls_id:
            keep_b[0] = False  # ignore CLS

        ids_b = prefix_ids[b, keep_b]           # [Li]
        Li = ids_b.numel()
        if Li % 3 != 0:
            raise ValueError(f"sample {b}: prefix_ids length {Li} is not a multiple of 3")

        if Li == 0:
            # Empty prefix: one all-zero step so each sample has a "last valid step".
            seq_per_b.append(torch.zeros(1, 3, H, W, device=device))
        else:
            # Encode all steps at once: [T_i, 3] triples -> [T_i, 3, H, W].
            seq_per_b.append(triple_to_map3(ids_b.view(-1, 3), H, W))

    T_lens = torch.tensor([seq.shape[0] for seq in seq_per_b], device=device, dtype=torch.long)
    # pad_sequence (batch_first=False) zero-pads the time axis -> [T_max, B, 3, H, W].
    out = torch.nn.utils.rnn.pad_sequence(seq_per_b)
    return out, T_lens  # [T_max,B,3,H,W], [B]


def last_step_query(prefix_ids, attn_mask, H=24, W=24):
    """Encode the last observed step (the final 3 valid tokens) of each prefix.

    Assumes the valid tokens are left-aligned, i.e. the mask is True on positions
    0..len-1. Positions before 0 are clamped to 0, so a prefix with fewer than 3
    valid tokens re-uses position 0 (which may be the CLS token).

    Returns:
        [B, 3, H, W] one-hot maps (see ``triple_to_map3``).
    """
    last_pos = attn_mask.sum(dim=1) - 1                        # [B] index of last valid token
    offsets = torch.arange(-2, 1, device=last_pos.device)      # [-2, -1, 0]
    positions = (last_pos.unsqueeze(1) + offsets).clamp(min=0)  # [B, 3]
    last_triple = prefix_ids.gather(1, positions)              # [B, 3]
    return triple_to_map3(last_triple, H, W)

def choices_to_maps(choices_ids, H=24, W=24):
    """One-hot encode every answer choice.

    Args:
        choices_ids: [B, K, 3] cell-id triples, one per choice (K = 4 in this notebook).
        H, W: grid size.

    Returns:
        [B, K, 3, H, W]; entry [b, k] is ``triple_to_map3`` of choice k of sample b.
    """
    B, K = choices_ids.shape[:2]
    # Encode all B*K triples in a single call, then restore the [B, K] leading dims.
    flat = triple_to_map3(choices_ids.reshape(B * K, 3), H, W)  # [B*K, 3, H, W]
    return flat.view(B, K, 3, H, W)


In [None]:
def grid_adj(H, W, device=None):
    """Adjacency matrix of an H x W 4-connected grid, including self-loops.

    Node ``i`` is the cell at (i // W, i % W). The result is a float
    [H*W, H*W] tensor with 1 between a cell and each of its up/down/left/right
    neighbours, 1 on the diagonal, and 0 elsewhere.
    """
    num_nodes = H * W
    adj = torch.zeros(num_nodes, num_nodes, device=device)
    for node in range(num_nodes):
        row, col = divmod(node, W)
        adj[node, node] = 1
        if row > 0:
            adj[node, node - W] = 1
        if row < H - 1:
            adj[node, node + W] = 1
        if col > 0:
            adj[node, node - 1] = 1
        if col < W - 1:
            adj[node, node + 1] = 1
    return adj

class SGA(nn.Module):
    """Spiking graph attention over grid nodes.

    Q, K and V are bias-free linear projections of the node features, each turned
    into spikes by its own LIF neuron. Attention scores are Q·Kᵀ of the spike
    tensors, masked by ``adj`` and row-normalised (with a 1e-6 guard). The output
    is the attention-weighted sum of the V spikes.

    NOTE(review): if none of the Q/K/V neurons spike, the scores, the output and
    every gradient through this layer are exactly zero.
    """
    def __init__(self, d_in, d_out, beta=0.95):
        super().__init__()
        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)
        surrogate_fn = surrogate.fast_sigmoid(slope=25)
        lif_kwargs = dict(beta=beta, spike_grad=surrogate_fn, init_hidden=True)
        self.lif_q = snn.Leaky(**lif_kwargs)
        self.lif_k = snn.Leaky(**lif_kwargs)
        self.lif_v = snn.Leaky(**lif_kwargs)

    def forward(self, X, adj):
        """X: [B, N, d_in] node features; adj: [N, N] mask. Returns [B, N, d_out]."""
        spk_q = self.lif_q(self.Wq(X))
        spk_k = self.lif_k(self.Wk(X))
        spk_v = self.lif_v(self.Wv(X))
        scores = (spk_q @ spk_k.transpose(1, 2)) * adj             # [B, N, N]
        normaliser = scores.sum(dim=-1, keepdim=True) + 1e-6
        return (scores / normaliser) @ spk_v                        # [B, N, d_out]

class SNN_CharNet(nn.Module):
    """Spiking character network: trajectory spike sequence -> e_char spikes.

    Each time step goes through Conv-LIF -> Conv-LIF -> average pool over the
    whole H x W grid -> spiking LSTM. The SLSTM output at each sample's last
    valid step is projected by ``fc`` and passed through an output LIF, giving
    ``[B, LENGTH_E_CHAR]`` spikes.

    ``phase`` only controls which parameters are trainable:
      * "stdp": SLSTM, fc and the output LIF are frozen.
      * anything else (e.g. "finetune"): conv1/lif1/conv2/lif2 are frozen instead.
    """
    def __init__(self, p: dict, phase: Literal["stdp","finetune"]="stdp"):
        super().__init__()
        self.p = p; self.phase = phase
        beta = p["BETA"]; th = p["THRESHOLD"]; sgrad = surrogate.atan()

        C_in = p["TRAJECTORY_CHANNELS"]; C = p["CONV_CHANNELS"]
        H, W = p["MAZE_HEIGHT"], p["MAZE_WIDTH"]
        self.conv1 = nn.Conv2d(C_in, C, 3, padding=1, bias=False)
        self.lif1  = snn.Leaky(beta=beta, threshold=th, spike_grad=sgrad, learn_beta=True, learn_threshold=True)
        self.conv2 = nn.Conv2d(C, C, 3, padding=1, bias=False)
        self.lif2  = snn.Leaky(beta=beta, threshold=th, spike_grad=sgrad, learn_beta=True, learn_threshold=True)
        self.pool  = nn.AvgPool2d(kernel_size=(H, W))      # → [B,C,1,1]

        self.slstm = snn.SLSTM(input_size=C, hidden_size=p["LSTM_HIDDEN_SIZE"], spike_grad=sgrad)
        self.fc    = nn.Linear(p["LSTM_HIDDEN_SIZE"], p["LENGTH_E_CHAR"], bias=False)
        self.lif_o = snn.Leaky(beta=beta, threshold=th, spike_grad=sgrad, learn_beta=True, learn_threshold=True)

        if phase == "stdp":
            frozen = [self.slstm, self.fc, self.lif_o]
        else:
            frozen = [self.conv1, self.lif1, self.conv2, self.lif2]
        for module in frozen:
            for p_ in module.parameters():
                p_.requires_grad_(False)

    def forward(self, spike_seq, t_lens=None):
        """Encode a batch of spike trajectories.

        Args:
            spike_seq: [T_max, B, C_in, H, W] spike maps, zero-padded after each
                sample's last step.
            t_lens: optional [B] long tensor of valid steps per sample. The SLSTM
                output is read at step ``t_lens - 1`` (clamped at 0). If None, the
                final step ``T_max - 1`` is used for every sample.

        Returns:
            [B, LENGTH_E_CHAR] output spikes.
        """
        T, B = spike_seq.shape[:2]
        # All hidden states start from explicit zeros. Passing ``None`` instead makes
        # snnTorch fall back to the state stored on the layer, which can still hold
        # the previous forward call's (i.e. previous batch's) state and graph.
        mem1 = mem2 = None
        syn = mem = None
        spk_hist = []

        for t in range(T):
            cur1 = self.conv1(spike_seq[t])
            if mem1 is None:
                mem1 = torch.zeros_like(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.conv2(spk1)
            if mem2 is None:
                mem2 = torch.zeros_like(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            pooled = self.pool(spk2).flatten(1)             # [B,C]
            if syn is None:
                syn = pooled.new_zeros(B, self.p["LSTM_HIDDEN_SIZE"])
                mem = torch.zeros_like(syn)
            spk_l, syn, mem = self.slstm(pooled, syn, mem)  # [B,hidden]
            spk_hist.append(spk_l)

        spk_hist = torch.stack(spk_hist, dim=0)             # [T,B,hidden]
        if t_lens is not None:
            # Read each sample at its own last valid step (t_lens - 1).
            idx = (t_lens - 1).clamp_min(0)                 # [B]
            spk_bt = spk_hist.permute(1, 0, 2)              # [B,T,hidden]
            idx_exp = idx.view(B, 1, 1).expand(B, 1, spk_bt.size(2))
            spk_last = spk_bt.gather(1, idx_exp).squeeze(1) # [B,hidden]
        else:
            spk_last = spk_hist[-1]                         # no lengths given: use the final step

        cur_o = self.fc(spk_last)
        spk_o, _ = self.lif_o(cur_o, torch.zeros_like(cur_o))  # [B,E]
        return spk_o

class SNN_PredNetChoiceScorer(nn.Module):
    """
    Scores a single choice. Input: query(3) + choice(3) + tiled e_char(E) channels
    → Conv-LIF×2 → SGA×2 → readout Leaky → RLeaky (one class) → scalar logit.

    Every spiking layer uses ``init_hidden=True``, so state lives on the layers;
    call ``reset()`` before scoring a new batch.
    """
    def __init__(self, p: dict):
        super().__init__()
        self.p = p
        beta = p["BETA"]; sgrad = surrogate.fast_sigmoid(slope=25)
        H, W = p["MAZE_HEIGHT"], p["MAZE_WIDTH"]
        C_in = p["QUERY_STATE_CHANNELS"] + p["CHOICE_CHANNELS"] + p["LENGTH_E_CHAR"]
        C = p["CONV_CHANNELS"]

        self.conv1 = nn.Conv2d(C_in, C, 3, padding=1, bias=False)
        self.lif1  = snn.Leaky(beta=beta, spike_grad=sgrad, init_hidden=True)
        self.conv2 = nn.Conv2d(C, C, 3, padding=1, bias=False)
        self.lif2  = snn.Leaky(beta=beta, spike_grad=sgrad, init_hidden=True)

        # Two graph-attention layers over the H*W grid nodes: C -> 2C -> C features.
        self.sga1  = SGA(C, 2*C, beta=beta)
        self.sga2  = SGA(2*C, C, beta=beta)

        self.read  = snn.Leaky(beta=beta, spike_grad=sgrad, init_hidden=True)
        self.proj  = nn.Linear(C, 1, bias=False)      # → single class
        self.rlif  = snn.RLeaky(beta=beta, linear_features=1, spike_grad=sgrad, init_hidden=True)

        # Fixed 4-neighbour grid adjacency (with self-loops) used as the attention mask.
        self.register_buffer("ADJ", grid_adj(H, W))

    # NOTE(review): which nested layers snn_utils.reset actually clears depends on the
    # snnTorch version — confirm it also reaches the LIFs inside sga1/sga2.
    def reset(self) -> None: snn_utils.reset(self)

    def forward(self, e_char_spk: torch.Tensor, query_spk: torch.Tensor, choice_map: torch.Tensor) -> torch.Tensor:
        """
        e_char_spk: [B,E], query_spk: [B,3,H,W], choice_map: [B,3,H,W]
        Returns scalar logits: [B] (the RLeaky membrane after ATTR_STEPS steps).
        """
        B, _, H, W = query_spk.shape
        # Broadcast the character embedding over every grid cell as E extra channels.
        e_tile = e_char_spk.unsqueeze(-1).unsqueeze(-1).expand(B, -1, H, W)  # [B,E,H,W]
        x = torch.cat([query_spk, choice_map, e_tile], dim=1)                # [B,3+3+E,H,W]

        spk1 = self.lif1(self.conv1(x))
        spk2 = self.lif2(self.conv2(spk1))                                   # [B,C,H,W]

        nodes = spk2.flatten(2).permute(0,2,1)                               # [B,N,C], N=H*W
        out1 = self.sga1(nodes, self.ADJ)
        out2 = self.sga2(out1,  self.ADJ)

        g = out2.mean(dim=1)                                                 # [B,C] mean over nodes
        g_spk = self.read(g)                                                 # [B,C]
        I_bias = self.proj(g_spk)                                            # [B,1]
        # Iterate the recurrent (attractor) neuron for a few steps with a constant
        # input current; a small number of steps is enough.
        for _ in range(self.p.get("ATTR_STEPS", 6)):
            _ = self.rlif(I_bias)
        # NOTE(review): the logged run in this notebook has loss fixed at 1.386 ≈ ln 4
        # for every epoch, i.e. the four choice logits came out equal. A plausible
        # cause: these single-step LIFs start from a zero membrane and never cross
        # threshold, so g_spk/I_bias are all zero and rlif.mem no longer depends on
        # the input. Verify by logging spike rates of lif1/lif2/sga*/read.
        return self.rlif.mem.squeeze(-1)                                     # [B]


In [None]:

import torch, torch.nn as nn, torch.nn.functional as F
# from utils_encoding import prefix_to_spike_seq, last_step_query, choices_to_maps
# from snn_blocks import SNN_CharNet, SNN_PredNetChoiceScorer
from snntorch import utils as snn_utils

class ToMNet2SNN(nn.Module):
    """Spiking ToMNet-style multiple-choice model.

    ``charnet`` (SNN_CharNet) turns the observed prefix into a spiking character
    embedding e_char. ``scorer`` (SNN_PredNetChoiceScorer) then gives every answer
    choice a scalar logit, conditioned on e_char and on the last observed step
    (the query).

    ``snn_params`` overrides any of the default hyper-parameters below.
    """
    def __init__(self, cfg, snn_params: dict):
        super().__init__()
        self.cfg = cfg
        defaults = dict(
            MAZE_HEIGHT=24, MAZE_WIDTH=24,
            TRAJECTORY_CHANNELS=3, QUERY_STATE_CHANNELS=3, CHOICE_CHANNELS=3,
            LENGTH_E_CHAR=cfg.e_char_dim,
            CONV_CHANNELS=64,
            LSTM_HIDDEN_SIZE=128,
            BETA=0.95, THRESHOLD=1.0,
            ATTR_STEPS=6,
        )
        self.p = {**defaults, **(snn_params or {})}

        self.charnet = SNN_CharNet(self.p, phase="finetune")
        self.scorer  = SNN_PredNetChoiceScorer(self.p)

    def _encode_query_choice(self, prefix_ids, attn_mask, choices_ids):
        """Return (query maps [B,3,H,W], choice maps [B,K,3,H,W])."""
        q = last_step_query(prefix_ids, attn_mask, self.p["MAZE_HEIGHT"], self.p["MAZE_WIDTH"])
        cm = choices_to_maps(choices_ids, self.p["MAZE_HEIGHT"], self.p["MAZE_WIDTH"])
        return q, cm

    def forward(self, prefix_ids, attn_mask, choices_ids):
        """Score every answer choice.

        Args:
            prefix_ids:  [B, L] long token ids (optionally starting with CLS).
            attn_mask:   [B, L] bool, True on real tokens.
            choices_ids: [B, K, 3] cell-id triples, one per choice.

        Returns:
            logits [B, K], where K is the number of encoded choices (4 here).
        """
        H, W = self.p["MAZE_HEIGHT"], self.p["MAZE_WIDTH"]
        # 1) prefix -> spike sequence -> character embedding (CharNet)
        spike_seq, t_lens = prefix_to_spike_seq_batched(
            prefix_ids, attn_mask,
            pad_id=self.cfg.pad_id, cls_id=self.cfg.cls_id,
            H=H, W=W,
        )  # [T_max,B,3,H,W], [B]
        e_char_spk = self.charnet(spike_seq, t_lens=t_lens)                  # [B,E]

        # 2) query/choices -> logits
        q, cm = self._encode_query_choice(prefix_ids, attn_mask, choices_ids)  # [B,3,H,W], [B,K,3,H,W]
        B, K = cm.shape[:2]
        # Flatten the K choices into the batch dimension and score them in parallel.
        q_rep  = q.repeat_interleave(K, dim=0)                               # [B*K,3,H,W]
        e_rep  = e_char_spk.repeat_interleave(K, dim=0)                      # [B*K,E]
        c_flat = cm.reshape(B * K, *cm.shape[2:])                            # [B*K,3,H,W]
        # Clear the scorer's hidden state so it does not leak across batches.
        self.scorer.reset()
        logits_flat = self.scorer(e_rep, q_rep, c_flat)                      # [B*K]
        return logits_flat.view(B, K)                                        # [B,K]

# ---- 訓練流程重點 ----
def train_epoch_snn(model, loader, device, optim, spike_reg_lambda=1e-4):
    """Run one optimisation epoch over ``loader``.

    Each batch dict must provide "prefix_ids", "attn_mask", "choices_ids" and
    "labels". Tensor values are moved to ``device``; other values (e.g. "meta")
    are passed through unchanged. Gradients are clipped to a global norm of 1.0.

    ``spike_reg_lambda`` is currently unused (placeholder for an optional
    spike-sparsity penalty).

    Returns:
        (accuracy, mean cross-entropy loss) over all samples seen.
    """
    model.train()
    n_seen = n_correct = 0
    running_loss = 0.0
    for raw in loader:
        batch = {key: (val.to(device) if torch.is_tensor(val) else val) for key, val in raw.items()}
        # SNN modules may keep hidden state between calls; clear it before every batch.
        snn_utils.reset(model)
        logits = model(batch["prefix_ids"], batch["attn_mask"], batch["choices_ids"])  # [B,4]
        labels = batch["labels"]
        # Optional extension: add a spike-sparsity regulariser to this loss.
        loss = F.cross_entropy(logits, labels)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        batch_size = logits.size(0)
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        n_seen += batch_size
        running_loss += loss.item() * batch_size
    denom = max(n_seen, 1)
    return n_correct / denom, running_loss / denom

@torch.no_grad()
def eval_snn(model, loader, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        (accuracy, mean cross-entropy loss) over all samples in ``loader``.
    """
    model.eval()
    seen, hits, loss_total = 0, 0, 0.0
    for raw in loader:
        batch = {key: (val.to(device) if torch.is_tensor(val) else val) for key, val in raw.items()}
        snn_utils.reset(model)  # clear snnTorch hidden state before each batch
        logits = model(batch["prefix_ids"], batch["attn_mask"], batch["choices_ids"])
        labels = batch["labels"]
        n = logits.size(0)
        loss_total += F.cross_entropy(logits, labels).item() * n
        hits += (logits.argmax(dim=-1) == labels).sum().item()
        seen += n
    denom = max(seen, 1)
    return hits / denom, loss_total / denom


In [None]:

from pathlib import Path
from dataclasses import dataclass
# Canonical dataset modes, in the order they are concatenated for mode="all".
_ALL_MODES = ["rulemap","random","logic","intermediate_case1","intermediate_case2"]

# Every accepted spelling of a mode -> its canonical name.
_MODE_ALIASES = {name: name for name in _ALL_MODES}
_MODE_ALIASES.update({
    # common alternative spellings / typos
    "intermediate_1": "intermediate_case1", "intermediate_2": "intermediate_case2",
    "intemediete_1": "intermediate_case1", "intemediete_2": "intermediate_case2",
    "intermediat_case1": "intermediate_case1", "intermediat_case2": "intermediate_case2",
})
_MODE_ALIASES["all"] = "all"

# Split name -> directory name under data_root.
_SPLIT_DIR = dict(train="training_data", val="validation_data", test="testing_data")

@dataclass
class TrainConfig:
    """Data-location, model and training settings for the maze MCQ experiments."""
    data_root: Path
    sim_name: str = "sim1"
    mode: str = "all"
    pad_id: int = 576
    cls_id: int = 577
    add_cls: bool = True

    # model
    vocab_size: int = 578
    d_model: int = 128
    nhead: int = 8
    n_layers: int = 4
    dim_ff: int = 1024
    dropout: float = 0.4
    e_char_dim: int = 64

    # train
    batch_size: int = 64
    lr: float = 3e-4
    weight_decay: float = 1e-2
    max_epochs: int = 30
    seed: int = 42
    num_workers: int = 4
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def canonical_mode(self) -> str:
        """Normalise ``mode`` (trimmed, case-insensitive, typo-tolerant) to a canonical name.

        Raises:
            ValueError: if the mode is not a known spelling.
        """
        canonical = _MODE_ALIASES.get(self.mode.strip().lower())
        if canonical is None:
            raise ValueError(f"Unknown mode: {self.mode}")
        return canonical
    
import os

# Root folder holding training_data/, validation_data/ and testing_data/.
# Set TOMNET_DATA_ROOT instead of editing a machine-specific absolute path;
# the previous hard-coded location remains the fallback.
DATA_ROOT = Path(os.environ.get("TOMNET_DATA_ROOT", "/Users/Jer_ry/Desktop/scripts"))

cfg = TrainConfig(
    data_root=DATA_ROOT,
    sim_name="sim1",
    mode="logic",                 #  'rulemap' / 'random' / 'logic' / 'intermediate_case1' / 'intermediate_case2' / 'all'
    batch_size=64,
    max_epochs=30,
    d_model=256, nhead=8, n_layers=4, dim_ff=1024,
    e_char_dim=64,
    add_cls=True, pad_id=576, cls_id=577,
)

print("device:", cfg.device)


device: cpu


In [None]:
# SNN hyper-parameters; keys omitted here fall back to the defaults in ToMNet2SNN.
snn_params = dict(
    CONV_CHANNELS=64,
    LSTM_HIDDEN_SIZE=128,
    BETA=0.95, THRESHOLD=1.0,
    ATTR_STEPS=6,
)
model = ToMNet2SNN(cfg, snn_params).to(cfg.device)
optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)


In [None]:
# ---- 2) 準備資料載入器（用你現有的 build_loaders）----
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import functools
from functools import partial
import json, math, random

def csv_paths(cfg: TrainConfig, split: str, mode: str) -> Tuple[Path, Path]:
    """Return the (trajectory CSV, answers CSV) paths for one split and mode.

    Layout: <data_root>/<split dir>/<sim_name>/csv/{mode}_train.csv and
    .../{mode}_answers.csv.
    """
    csv_dir = cfg.data_root / _SPLIT_DIR[split] / cfg.sim_name / "csv"
    train_csv = csv_dir / f"{mode}_train.csv"
    answers_csv = csv_dir / f"{mode}_answers.csv"
    return train_csv, answers_csv

def build_loaders(cfg: TrainConfig):
    """Build the (train, val, test) DataLoaders for ``cfg``'s mode.

    Only the train loader shuffles. Batches are assembled by ``collate_maze``.
    """
    # Page-locked memory only speeds up host->CUDA copies; on a CPU run it is
    # pointless and DataLoader warns that no accelerator is available.
    pin = str(cfg.device).startswith("cuda")

    def _mk(split):
        ds = MazeMCQDataset.load(cfg, split)
        return DataLoader(
            ds,
            batch_size=cfg.batch_size,
            shuffle=(split == "train"),
            # NOTE(review): cfg.num_workers is deliberately ignored here; 0 keeps
            # loading in-process (presumably so notebook-defined classes need not
            # be importable by worker processes) — confirm before changing.
            num_workers=0,
            pin_memory=pin,
            collate_fn=functools.partial(collate_maze, pad_id=cfg.pad_id),
        )
    return _mk("train"), _mk("val"), _mk("test")

class MazeMCQDataset(Dataset):
    """Maze multiple-choice dataset with one row per trial.

    Each item is a dict with
      prefix_ids  [L] long    observed cell ids, prefixed with ``cls_id`` if ``add_cls``
      choices_ids [4, 3] long cell-id triples for answers A-D
      label       scalar long index of the correct answer (A=0 ... D=3)
      meta        dict with trial_id, mode, seq_len_T and case_type
    """
    # Largest accepted cell id (575 = 24*24 - 1 for the default 24x24 maze).
    MAX_CELL_ID = 575

    def __init__(self, df: pd.DataFrame, pad_id: int, cls_id: int, add_cls: bool):
        self.df = df.reset_index(drop=True)
        self.pad_id, self.cls_id, self.add_cls = pad_id, cls_id, add_cls
        self.letter2idx = {"A":0,"B":1,"C":2,"D":3}

    @staticmethod
    def _load_one(cfg: "TrainConfig", split: str, mode: str) -> "MazeMCQDataset":
        """Load and join the trajectory and answer CSVs of a single mode.

        Raises:
            FileNotFoundError: if either CSV is missing.
        """
        t_csv, a_csv = csv_paths(cfg, split, mode)
        if not t_csv.exists() or not a_csv.exists():
            raise FileNotFoundError(f"Missing: {t_csv} / {a_csv}")
        t = pd.read_csv(t_csv)
        a = pd.read_csv(a_csv)
        # Keep only trials present in both files, then join on (trial_id, mode).
        common = sorted(set(t["trial_id"]).intersection(set(a["trial_id"])))
        df = t[t["trial_id"].isin(common)].merge(a[a["trial_id"].isin(common)], on=["trial_id","mode"], how="inner")
        return MazeMCQDataset(df, cfg.pad_id, cfg.cls_id, cfg.add_cls)

    @staticmethod
    def load(cfg: "TrainConfig", split: str) -> "MazeMCQDataset":
        """Load ``split`` for cfg's mode; mode "all" concatenates every mode."""
        cmode = cfg.canonical_mode()
        if cmode != "all":
            return MazeMCQDataset._load_one(cfg, split, cmode)
        parts = [MazeMCQDataset._load_one(cfg, split, m).df for m in _ALL_MODES]
        df_all = pd.concat(parts, axis=0, ignore_index=True)
        return MazeMCQDataset(df_all, cfg.pad_id, cfg.cls_id, cfg.add_cls)

    def __len__(self): return len(self.df)

    def _parse_cells(self, s: str) -> List[int]:
        """Parse a JSON list of cell ids and validate each is an int in 0..MAX_CELL_ID.

        Raises:
            ValueError: on a non-list payload or any invalid id. Explicit checks are
                used instead of ``assert``, which ``python -O`` strips.
        """
        arr = json.loads(s)
        if not isinstance(arr, list):
            raise ValueError(f"expected a JSON list of cell ids, got {type(arr).__name__}")
        bad = [x for x in arr
               if isinstance(x, bool) or not isinstance(x, int) or not 0 <= x <= self.MAX_CELL_ID]
        if bad:
            raise ValueError(f"cid out of range [0, {self.MAX_CELL_ID}]: {bad[:5]}")
        return arr

    def __getitem__(self, i: int) -> Dict:
        r = self.df.iloc[i]
        # prefix
        prefix = self._parse_cells(r["cell_json"])
        if self.add_cls:
            prefix = [self.cls_id] + prefix

        # choices A-D, each a triple of cell ids
        ch = []
        for c in "ABCD":
            vec = self._parse_cells(r[f"choice{c}_cell"])
            if len(vec) != 3:
                raise ValueError(f"choice{c}_cell: expected 3 cell ids, got {len(vec)}")
            ch.append(torch.tensor(vec, dtype=torch.long))
        choices = torch.stack(ch, dim=0)  # [4,3]

        label = torch.tensor(self.letter2idx[r["correct"]], dtype=torch.long)
        meta = {
            "trial_id": r["trial_id"], "mode": r["mode"],
            "seq_len_T": int(r["seq_len_T"]),
            "case_type": "" if pd.isna(r.get("case_type","")) else str(r.get("case_type","")),
        }
        return {
            "prefix_ids": torch.tensor(prefix, dtype=torch.long),
            "choices_ids": choices,
            "label": label, "meta": meta,
        }
    
def collate_maze(batch: List[Dict], pad_id: int) -> Dict[str, torch.Tensor]:
    """Pad variable-length prefixes and stack dataset items into one batch.

    Returns a dict with
      prefix_ids  [B, L_max] long, right-padded with ``pad_id``
      attn_mask   [B, L_max] bool, True on real tokens
      choices_ids [B, 4, 3]
      labels      [B]
      meta        dict of per-sample lists, keyed like each item's "meta"
    """
    seqs = [item["prefix_ids"] for item in batch]
    max_len = max(len(seq) for seq in seqs)
    prefix = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(batch), max_len), dtype=torch.bool)  # True = keep
    for row, seq in enumerate(seqs):
        n = len(seq)
        prefix[row, :n] = seq
        mask[row, :n] = True
    meta_keys = batch[0]["meta"].keys()
    return {
        "prefix_ids": prefix,
        "attn_mask": mask,
        "choices_ids": torch.stack([item["choices_ids"] for item in batch]),
        "labels": torch.stack([item["label"] for item in batch]),
        "meta": {key: [item["meta"][key] for item in batch] for key in meta_keys},
    }

In [None]:
# Unfreeze the CharNet conv front-end (frozen by phase="finetune") so the whole
# network trains end-to-end. The optimizer already holds these parameters.
for module in (model.charnet.conv1, model.charnet.lif1, model.charnet.conv2, model.charnet.lif2):
    for p_ in module.parameters():
        p_.requires_grad_(True)

train_loader, val_loader, test_loader = build_loaders(cfg)

EPOCHS = 3
best_va, best_state = 0.0, None
for ep in range(1, EPOCHS+1):
    tr_acc, tr_loss = train_epoch_snn(model, train_loader, cfg.device, optim)
    va_acc, va_loss = eval_snn(model, val_loader, cfg.device)
    print(f"[Epoch {ep:02d}] train acc={tr_acc:.3f} loss={tr_loss:.3f} | val acc={va_acc:.3f} loss={va_loss:.3f}")
    if va_acc > best_va:
        # Snapshot by value: on CPU, Tensor.cpu() returns the very same tensor, so
        # without clone() the "best" weights would keep changing as training goes on.
        best_va = va_acc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

if best_state is not None:
    model.load_state_dict(best_state)

te_acc, te_loss = eval_snn(model, test_loader, cfg.device)
print(f"[TEST] acc={te_acc:.3f} loss={te_loss:.3f}")


[Epoch 01] train acc=0.264 loss=1.386 | val acc=0.208 loss=1.386
[Epoch 02] train acc=0.264 loss=1.386 | val acc=0.208 loss=1.386
[Epoch 03] train acc=0.264 loss=1.386 | val acc=0.208 loss=1.386
[TEST] acc=0.234 loss=1.386
