In [None]:
# 完整：重构全部轨迹 -> 保存缓存 -> 读取缓存 -> 随机展示一条(states/actions/rewards)
import os
import sys
import pickle
from pathlib import Path
import random

import numpy as np
import torch

# ========== 1) import 路径 & pickle shim ==========
repo_root = Path(os.getcwd()).resolve().parents[1]
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# from ISE_Transformer.DT_test.DT_train import _install_go_explore_pickle_shim
# _install_go_explore_pickle_shim()

from ISE_Transformer.experiment.il_2_room.environment import GridWorld

# ========== 10_e) 配置 ==========
PKL_IN = Path("../data_room/2_room/optimal_trajectory_archive2.pkl").resolve()
assert PKL_IN.exists(), f"not found: {PKL_IN}"

# 缓存文件（你要“先保存下来，再读取后随机展示一条”，就写到这个新文件里）
CACHE_OUT = Path("./reconstructed_cache_all_trajs.pkl").resolve()

# DT动作(数据集/go_explore_env): 0=up,1=down,10_e=left,3=right
# GridWorld动作(environment.GridWorld): 1=right,10_e=up,3=left,4=down
DT2GW = torch.tensor([2, 4, 3, 1], dtype=torch.long)

ALG_TYPE = "NM"      # weighted_traj_return 的 type（对齐 DT_test.ipynb）
RNG_SEED_SHOW = None    # 控制“展示哪一条”的随机种子

# ========== 3) helper ==========
def make_params(batch_size: int = 1, horizon: int = 80, node_weight: str = "constant", initial: int = 80):
    return {
        "env": {
            "start": 1,
            "step_size": 0.1,
            "shape": {"x": 11, "y": 18},
            "horizon": int(horizon),
            "node_weight": str(node_weight),
            "disc_size": "small",
            "n_players": 3,
            "Cx_lengthscale": 2,
            "Cx_noise": 0.001,
            "Fx_lengthscale": 1,
            "Fx_noise": 0.001,
            "Cx_beta": 1.5,
            "Fx_beta": 1.5,
            "generate": False,
            "env_file_name": "env_data.pkl",
            "cov_module": "Matern",
            "stochasticity": 0.0,
            "domains": "two_room_2",
            "num": 1,
            "initial": int(initial),
        },
        "common": {
            "a": 1,
            "subgrad": "greedy",
            "grad": "pytorch",
            "algo": "both",
            "init": "deterministic",
            "batch_size": int(batch_size),
        },
        "visu": {"wb": "disabled", "a": 1},
        "alg": {"type": str(ALG_TYPE), "gamma": 1},
    }

def make_gridworld_env(params):
    env_load_path = (
        Path("../..")
        / "experiment"
        / "il_2_room"
        / "2r198"
        / "environments"
        / params["env"]["node_weight"]
        / "env_1"
    ).resolve()

    env = GridWorld(
        env_params=params["env"],
        common_params=params["common"],
        visu_params=params["visu"],
        env_file_path=str(env_load_path),
    )
    # 对齐脚本初始化流程
    env.common_params["batch_size"] = int(params["common"]["batch_size"])
    env.initialize(params["env"]["initial"])
    env.get_horizon_transition_matrix()
    return env

def get_state_tensor(env) -> torch.LongTensor:
    s = env.state
    if not torch.is_tensor(s):
        s = torch.as_tensor(s)
    return s.long().view(-1)  # (B,)

def weighted_prefix_rewards(env, mat_state, alg_type: str) -> tuple[np.ndarray, float]:
    """
    用 weighted_traj_return 的“前缀增量”构造逐步 rewards：
      R_t = return(s0..s_t)
      rewards[t] = (R_{t+1}-R_t)
    并做基线平移：R_t -= R_0，保证 sum(rewards) == traj_return
    """
    T = len(mat_state) - 1  # s0..sT，共 T+1 个
    prefix = np.empty((T + 1,), dtype=np.float32)
    for t_end in range(T + 1):
        obj = env.weighted_traj_return(mat_state[: t_end + 1], type=alg_type).float()  # (B=1,)
        prefix[t_end] = float(obj.view(-1)[0].item())
    prefix = prefix - prefix[0]
    rewards = np.diff(prefix).astype(np.float32)  # (T,)
    traj_return = float(prefix[-1])
    return rewards, traj_return

# ========== 4) 如果缓存不存在：重构全部并保存 ==========
if not CACHE_OUT.exists():
    print("[INFO] cache not found, reconstructing ALL trajectories...")
    with open(PKL_IN, "rb") as f:
        archive = pickle.load(f)
    if not isinstance(archive, dict):
        raise TypeError(f"Unexpected pkl type: {type(archive)} (expected dict)")

    items = sorted(list(archive.items()), key=lambda kv: str(kv[0]))

    trajectories = []
    skipped = 0
    processed = 0

    for idx, (k, node) in enumerate(items):
        actions_dt_list = list(getattr(node, "path_actions", []) or [])
        if len(actions_dt_list) == 0:
            continue

        a_dt_np = np.asarray(actions_dt_list, dtype=np.int64)
        if a_dt_np.min() < 0 or a_dt_np.max() > 3:
            skipped += 1
            continue

        a_dt = torch.tensor(a_dt_np.tolist(), dtype=torch.long)
        a_gw = DT2GW[a_dt]  # (T,)
        T = int(a_dt.numel())

        params = make_params(batch_size=1, horizon=T + 1)
        env = make_gridworld_env(params)

        # replay 重构 mat_state: [s0..sT]
        mat_state = [get_state_tensor(env).clone()]  # s0
        for t in range(T):
            env.step(int(t), torch.tensor([int(a_gw[t].item())], dtype=torch.long))
            mat_state.append(get_state_tensor(env).clone())

        states = torch.cat(mat_state, dim=0).detach().cpu().numpy().astype(np.int64)  # (T+1,)
        actions_dt = a_dt.detach().cpu().numpy().astype(np.int64)                    # (T,)
        actions_gw = a_gw.detach().cpu().numpy().astype(np.int64)                    # (T,)

        rewards, traj_return = weighted_prefix_rewards(env, mat_state, alg_type=params["alg"]["type"])

        trajectories.append(
            {
                "key": str(k),
                "states": states,              # list/np array, len=T+1
                "actions_dt": actions_dt,      # len=T, 0..3
                "actions_gw": actions_gw,      # len=T, 1..4
                "rewards": rewards,            # len=T, sum == traj_return
                "traj_return": np.float32(traj_return),
            }
        )

        processed += 1
        if processed % 200 == 0:
            print(f"[INFO] reconstructed {processed} trajectories... (scanned {idx+1}/{len(items)})")

    payload = {
        "source_pkl": str(PKL_IN),
        "cache_out": str(CACHE_OUT),
        "alg_type": ALG_TYPE,
        "DT2GW": DT2GW.cpu().numpy().tolist(),
        "format": {
            "states": "GridWorld state_id (int64), len=T+1",
            "actions_dt": "dataset action (int64) in [0,3], len=T",
            "actions_gw": "GridWorld action (int64) in [1,4], len=T",
            "rewards": "per-step rewards from weighted_traj_return prefix-diff (float32), len=T",
            "traj_return": "float32, equals sum(rewards)",
        },
        "stats": {
            "num_items_in_archive": len(items),
            "num_reconstructed": len(trajectories),
            "num_skipped_invalid_actions": skipped,
        },
        "trajectories": trajectories,
    }

    with open(CACHE_OUT, "wb") as f:
        pickle.dump(payload, f)

    print("[OK] saved cache:", CACHE_OUT)
else:
    print("[INFO] cache exists, skip reconstruction:", CACHE_OUT)

# ========== 5) 读取缓存并随机展示一条 ==========
with open(CACHE_OUT, "rb") as f:
    cache = pickle.load(f)

trajs = cache.get("trajectories", [])
if not trajs:
    raise RuntimeError("cache has no trajectories")

rng = random.Random(RNG_SEED_SHOW)
one = rng.choice(trajs)

states = np.asarray(one["states"], dtype=np.int64)
actions_dt = np.asarray(one["actions_dt"], dtype=np.int64)
actions_gw = np.asarray(one["actions_gw"], dtype=np.int64)
rewards = np.asarray(one["rewards"], dtype=np.float32)
traj_return = float(one.get("traj_return", rewards.sum()))

print("\n[SUMMARY]")
print("cache:", str(CACHE_OUT))
print("num_trajectories:", len(trajs))
print("alg_type:", cache.get("alg_type"))

print("\n[RANDOM TRAJ]")
print("key =", one.get("key"))
print("T =", int(actions_dt.shape[0]))
print("traj_return =", traj_return)
print("check sum(rewards) =", float(rewards.sum()))

print("\nstates(list) =")
print(states.tolist())

print("\nactions_dt(list, 0..3) =")
print(actions_dt.tolist())

print("\nactions_gw(list, 1..4) =")
print(actions_gw.tolist())

print("\nrewards(list, per-step; sum == traj_return) =")
print(rewards.tolist())

In [102]:
import pickle, random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ISE_Transformer.experiment.il_2_room.environment import GridWorld

# ===================== 配置 =====================
CACHE_OUT = Path("./reconstructed_cache_all_trajs.pkl").resolve()
assert CACHE_OUT.exists(), f"cache not found: {CACHE_OUT}"

# 数据动作空间：actions_dt in [0,3]
NUM_ACTIONS = 4
ACTION_PAD = NUM_ACTIONS  # PAD id = 4

# dt动作(0..3) -> GridWorld动作(1..4)
DT2GW = torch.tensor([2, 4, 3, 1], dtype=torch.long)

K = 30                    # context length（每条轨迹随机截取 K）
BATCH_SIZE = 8            # 每 step 采样多少条“随机轨迹窗口”（要严格每次一条就设 1）
TRAIN_STEPS = 3000
LR = 3e-4
WEIGHT_DECAY = 1e-4
D_MODEL = 128
N_HEAD = 4
N_LAYER = 4
DROPOUT = 0.1
RTG_SCALE = 100.0         # RTG 缩放
SEED = None               # 设整数可复现

# 评估：每隔多少 step rollout 一次
EVAL_EVERY = 50
EVAL_HORIZON = 80         # rollout 最长步数（含初始状态约为 T+1 状态）

# weighted_traj_return 的 type（与你重构时一致）
ALG_TYPE = "NM"

SAVE_PATH = Path("./dt_full_ckpt.pt").resolve()

# ===================== 随机种子 =====================
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] device =", device)

# ===================== 读取缓存 =====================
with open(CACHE_OUT, "rb") as f:
    cache = pickle.load(f)
trajs = cache.get("trajectories", [])
assert len(trajs) > 0, "cache has no trajectories"
print("[INFO] num_trajs =", len(trajs), "alg_type(cache) =", cache.get("alg_type"))

# state_vocab / max_timestep
max_state_id = 0
max_T = 0
traj_returns = []
for tr in trajs:
    s = np.asarray(tr["states"], dtype=np.int64)
    a = np.asarray(tr["actions_dt"], dtype=np.int64)
    max_state_id = max(max_state_id, int(s.max()))
    max_T = max(max_T, int(a.shape[0]))
    if "traj_return" in tr:
        traj_returns.append(float(tr["traj_return"]))
state_vocab = max_state_id + 1
max_timestep = max(max_T, EVAL_HORIZON, K) + 5
print("[INFO] state_vocab =", state_vocab, "max_timestep =", max_timestep)

# 评估时的目标 return（用数据集中较高的 return 做 conditioning）
if len(traj_returns) > 0:
    DESIRED_RETURN = float(np.percentile(np.array(traj_returns, dtype=np.float32), 90))
else:
    DESIRED_RETURN = 1.0
print("[INFO] DESIRED_RETURN =", DESIRED_RETURN)

# ===================== 环境构造（用于评估）=====================
def make_params(batch_size: int = 1, horizon: int = 80, node_weight: str = "constant", initial: int = 80):
    return {
        "env": {
            "start": 1,
            "step_size": 0.1,
            "shape": {"x": 11, "y": 18},
            "horizon": int(horizon),
            "node_weight": str(node_weight),
            "disc_size": "small",
            "n_players": 3,
            "Cx_lengthscale": 2,
            "Cx_noise": 0.001,
            "Fx_lengthscale": 1,
            "Fx_noise": 0.001,
            "Cx_beta": 1.5,
            "Fx_beta": 1.5,
            "generate": False,
            "env_file_name": "env_data.pkl",
            "cov_module": "Matern",
            "stochasticity": 0.0,
            "domains": "two_room_2",
            "num": 1,
            "initial": int(initial),
        },
        "common": {
            "a": 1,
            "subgrad": "greedy",
            "grad": "pytorch",
            "algo": "both",
            "init": "deterministic",
            "batch_size": int(batch_size),
        },
        "visu": {"wb": "disabled", "a": 1},
        "alg": {"type": str(ALG_TYPE), "gamma": 1},
    }

def make_gridworld_env(params):
    env_load_path = (
        Path("../..")
        / "experiment"
        / "il_2_room"
        / "2r198"
        / "environments"
        / params["env"]["node_weight"]
        / "env_1"
    ).resolve()

    env = GridWorld(
        env_params=params["env"],
        common_params=params["common"],
        visu_params=params["visu"],
        env_file_path=str(env_load_path),
    )
    env.common_params["batch_size"] = int(params["common"]["batch_size"])
    env.initialize(params["env"]["initial"])
    env.get_horizon_transition_matrix()
    return env

def get_state_id(env) -> int:
    s = env.state
    if torch.is_tensor(s):
        s = int(s.view(-1)[0].item())
    else:
        s = int(np.asarray(s).reshape(-1)[0])
    return s

def weighted_return_baselined(env, mat_state, alg_type: str) -> float:
    # return([s0..sT]) - return([s0])
    objT = env.weighted_traj_return(mat_state, type=alg_type).float().view(-1)[0].item()
    obj0 = env.weighted_traj_return(mat_state[:1], type=alg_type).float().view(-1)[0].item()
    return float(objT - obj0)

# ===================== 轨迹采样（每次随机取一条轨迹窗口）=====================
def make_rtg(rewards: np.ndarray) -> np.ndarray:
    return np.cumsum(rewards[::-1], dtype=np.float32)[::-1].astype(np.float32)

def sample_batch(trajs, batch_size: int, K: int, rtg_scale: float, device: torch.device):
    states_b, prev_a_b, target_a_b, rtg_b, t_b, keymask_b, lossmask_b = [], [], [], [], [], [], []

    for _ in range(batch_size):
        tr = random.choice(trajs)  # 每次随机取一条轨迹
        states_all = np.asarray(tr["states"], dtype=np.int64)       # (T+1,)
        actions = np.asarray(tr["actions_dt"], dtype=np.int64)      # (T,)
        rewards = np.asarray(tr["rewards"], dtype=np.float32)       # (T,)

        T = int(actions.shape[0])
        states = states_all[:-1]  # (T,)

        rtg = make_rtg(rewards) / float(rtg_scale)  # (T,)

        if T >= K:
            start = random.randint(0, T - K)
            end = start + K
            pad = 0
            s = states[start:end]
            a = actions[start:end]
            r = rtg[start:end]
            tt = np.arange(start, end, dtype=np.int64)
            valid_len = K
        else:
            pad = K - T
            s = np.pad(states, (0, pad), constant_values=0)
            a = np.pad(actions, (0, pad), constant_values=0)
            r = np.pad(rtg, (0, pad), constant_values=0.0)
            tt = np.pad(np.arange(0, T, dtype=np.int64), (0, pad), constant_values=0)
            valid_len = T

        # prev_actions: [PAD] + a[:-1]
        prev = np.empty((K,), dtype=np.int64)
        prev[0] = ACTION_PAD
        prev[1:] = a[:-1]
        if pad > 0:
            prev[-pad:] = ACTION_PAD

        key_padding_mask = np.zeros((K,), dtype=np.bool_)
        if pad > 0:
            key_padding_mask[-pad:] = True

        loss_mask = np.zeros((K,), dtype=np.bool_)
        loss_mask[:valid_len] = True

        states_b.append(s)
        prev_a_b.append(prev)
        target_a_b.append(a)
        rtg_b.append(r)
        t_b.append(tt)
        keymask_b.append(key_padding_mask)
        lossmask_b.append(loss_mask)

    batch = {
        "states": torch.tensor(np.stack(states_b), dtype=torch.long, device=device),            # (B,K)
        "prev_actions": torch.tensor(np.stack(prev_a_b), dtype=torch.long, device=device),      # (B,K)
        "actions": torch.tensor(np.stack(target_a_b), dtype=torch.long, device=device),         # (B,K)
        "rtg": torch.tensor(np.stack(rtg_b), dtype=torch.float32, device=device),               # (B,K)
        "timesteps": torch.tensor(np.stack(t_b), dtype=torch.long, device=device),              # (B,K)
        "key_padding_mask": torch.tensor(np.stack(keymask_b), dtype=torch.bool, device=device), # (B,K)
        "loss_mask": torch.tensor(np.stack(lossmask_b), dtype=torch.bool, device=device),       # (B,K)
    }
    return batch

# ===================== 完整版 Decision Transformer（3 tokens / timestep）=====================
class CausalTransformer(nn.Module):
    def __init__(self, d_model: int, n_head: int, n_layer: int, dropout: float):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layer)

    def forward(self, x, key_padding_mask=None):
        L = x.size(1)
        causal_mask = torch.triu(torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1)
        return self.enc(x, mask=causal_mask, src_key_padding_mask=key_padding_mask)

class DecisionTransformerFull(nn.Module):
    """
    每个时间步 3 tokens: [rtg_t, state_t, prev_action_t]
    在 state token 位置输出 action logits（预测 action_t）
    """
    def __init__(self, state_vocab: int, num_actions: int, max_timestep: int,
                 d_model=128, n_head=4, n_layer=4, dropout=0.1):
        super().__init__()
        self.num_actions = num_actions
        self.action_pad = num_actions

        self.state_emb = nn.Embedding(state_vocab, d_model)
        self.action_emb = nn.Embedding(num_actions + 1, d_model)  # +PAD
        self.time_emb = nn.Embedding(max_timestep + 1, d_model)
        self.type_emb = nn.Embedding(3, d_model)  # 0=rtg,1=state,10_e=action

        self.rtg_proj = nn.Linear(1, d_model)
        self.drop = nn.Dropout(dropout)

        self.tr = CausalTransformer(d_model=d_model, n_head=n_head, n_layer=n_layer, dropout=dropout)
        self.head = nn.Linear(d_model, num_actions)

    def build_tokens(self, states, prev_actions, rtg, timesteps):
        # states/prev_actions/rtg/timesteps: (B,K)
        B, K = states.shape
        d = self.state_emb.embedding_dim

        # (B,K,D)
        rtg_tok = self.rtg_proj(rtg.unsqueeze(-1))
        s_tok = self.state_emb(states)
        a_tok = self.action_emb(prev_actions)

        # time embedding: 对三个 token 都加同一个 time_emb(t)
        t_emb = self.time_emb(timesteps)  # (B,K,D)

        rtg_tok = rtg_tok + t_emb + self.type_emb(torch.zeros((B, K), device=states.device, dtype=torch.long))
        s_tok   = s_tok   + t_emb + self.type_emb(torch.ones((B, K), device=states.device, dtype=torch.long))
        a_tok   = a_tok   + t_emb + self.type_emb(torch.full((B, K), 2, device=states.device, dtype=torch.long))

        # interleave => (B, 3K, D)
        x = torch.stack([rtg_tok, s_tok, a_tok], dim=2).reshape(B, 3 * K, d)
        return x

    def forward(self, states, prev_actions, rtg, timesteps, key_padding_mask=None):
        # key_padding_mask: (B,K) -> expand to (B,3K)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(-1).expand(-1, -1, 3).reshape(key_padding_mask.size(0), -1)

        x = self.build_tokens(states, prev_actions, rtg, timesteps)
        x = self.drop(x)
        h = self.tr(x, key_padding_mask=key_padding_mask)  # (B,3K,D)

        # state token positions: 1,4,7,... => index 1 + 3*t
        h_state = h[:, 1::3, :]  # (B,K,D)
        logits = self.head(h_state)  # (B,K,A)
        return logits

model = DecisionTransformerFull(
    state_vocab=state_vocab,
    num_actions=NUM_ACTIONS,
    max_timestep=max_timestep,
    d_model=D_MODEL,
    n_head=N_HEAD,
    n_layer=N_LAYER,
    dropout=DROPOUT,
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# ===================== 环境 Rollout 评估（每次输出一轮测试回报）=====================
@torch.no_grad()
def rollout_one_episode(model: DecisionTransformerFull, desired_return: float, horizon: int, K: int, rtg_scale: float):
    model.eval()

    params = make_params(batch_size=1, horizon=horizon + 1)
    env = make_gridworld_env(params)

    mat_state = [torch.tensor([get_state_id(env)], dtype=torch.long)]  # list of (1,)
    baseline0 = env.weighted_traj_return(mat_state[:1], type=ALG_TYPE).float().view(-1)[0].item()

    actions_dt = []
    prev_action_dt = ACTION_PAD  # t=0 的 prev_action 用 PAD

    for t in range(horizon):
        # 当前 baselined return
        cur_obj = env.weighted_traj_return(mat_state, type=ALG_TYPE).float().view(-1)[0].item()
        cur_return = float(cur_obj - baseline0)
        rtg_remaining = float(desired_return - cur_return)

        # 构造 context（右侧 padding）
        window_states = [int(x.view(-1)[0].item()) for x in mat_state]  # len=t+1
        window_prev_actions = [ACTION_PAD] + actions_dt  # len=t+1，对齐 state_t 的 prev_action_t
        window_len = min(len(window_states), K)

        s_seq = np.array(window_states[-window_len:], dtype=np.int64)
        pa_seq = np.array(window_prev_actions[-window_len:], dtype=np.int64)
        rtg_seq = np.full((window_len,), rtg_remaining / float(rtg_scale), dtype=np.float32)
        tt_seq = np.arange(t - window_len + 1, t + 1, dtype=np.int64)

        pad = K - window_len
        if pad > 0:
            s_seq = np.pad(s_seq, (0, pad), constant_values=0)
            pa_seq = np.pad(pa_seq, (0, pad), constant_values=ACTION_PAD)
            rtg_seq = np.pad(rtg_seq, (0, pad), constant_values=0.0)
            tt_seq = np.pad(tt_seq, (0, pad), constant_values=0)

        key_padding_mask = np.zeros((K,), dtype=np.bool_)
        if pad > 0:
            key_padding_mask[-pad:] = True

        states_t = torch.tensor(s_seq, dtype=torch.long, device=device).unsqueeze(0)
        prev_a_t = torch.tensor(pa_seq, dtype=torch.long, device=device).unsqueeze(0)
        rtg_t = torch.tensor(rtg_seq, dtype=torch.float32, device=device).unsqueeze(0)
        ts_t = torch.tensor(tt_seq, dtype=torch.long, device=device).unsqueeze(0)
        kpm_t = torch.tensor(key_padding_mask, dtype=torch.bool, device=device).unsqueeze(0)

        logits = model(states_t, prev_a_t, rtg_t, ts_t, key_padding_mask=kpm_t)  # (1,K,A)
        # 取最后一个有效 timestep 的预测
        action_dt = int(torch.argmax(logits[0, window_len - 1], dim=-1).item())
        actions_dt.append(action_dt)

        action_gw = int(DT2GW[torch.tensor(action_dt)].item())
        env.step(int(t), torch.tensor([action_gw], dtype=torch.long))

        mat_state.append(torch.tensor([get_state_id(env)], dtype=torch.long))

    test_return = weighted_return_baselined(env, mat_state, alg_type=ALG_TYPE)
    model.train()
    return test_return

# ===================== 训练（每 step 随机取一条轨迹；并周期性打印测试回报）=====================
model.train()
for step in range(1, TRAIN_STEPS + 1):
    batch = sample_batch(trajs, BATCH_SIZE, K, RTG_SCALE, device)

    logits = model(
        states=batch["states"],
        prev_actions=batch["prev_actions"],
        rtg=batch["rtg"],
        timesteps=batch["timesteps"],
        key_padding_mask=batch["key_padding_mask"],
    )  # (B,K,A)

    # 只在有效 token 上算 loss（排除 padding）
    mask = batch["loss_mask"] & (~batch["key_padding_mask"])  # (B,K)
    logits_flat = logits[mask]                 # (N,A)
    targets_flat = batch["actions"][mask]      # (N,)

    loss = F.cross_entropy(logits_flat, targets_flat)

    opt.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    if step % EVAL_EVERY == 0:
        test_ret = rollout_one_episode(model, desired_return=DESIRED_RETURN, horizon=EVAL_HORIZON, K=K, rtg_scale=RTG_SCALE)
        print(f"[TRAIN] step={step:5d} loss={float(loss.item()):.6f}  test_return={test_ret:.6f}")

# 保存
ckpt = {
    "model": model.state_dict(),
    "cfg": {
        "CACHE_OUT": str(CACHE_OUT),
        "NUM_ACTIONS": NUM_ACTIONS,
        "K": K,
        "BATCH_SIZE": BATCH_SIZE,
        "TRAIN_STEPS": TRAIN_STEPS,
        "LR": LR,
        "WEIGHT_DECAY": WEIGHT_DECAY,
        "D_MODEL": D_MODEL,
        "N_HEAD": N_HEAD,
        "N_LAYER": N_LAYER,
        "DROPOUT": DROPOUT,
        "RTG_SCALE": RTG_SCALE,
        "SEED": SEED,
        "ALG_TYPE": ALG_TYPE,
        "EVAL_EVERY": EVAL_EVERY,
        "EVAL_HORIZON": EVAL_HORIZON,
        "DESIRED_RETURN": DESIRED_RETURN,
    },
    "meta": {"alg_type(cache)": cache.get("alg_type"), "state_vocab": state_vocab, "max_timestep": max_timestep},
}
torch.save(ckpt, SAVE_PATH)
print("[OK] saved:", SAVE_PATH)

[INFO] device = cuda
[INFO] num_trajs = 112 alg_type(cache) = NM
[INFO] state_vocab = 198 max_timestep = 85
[INFO] DESIRED_RETURN = 144.0
[TRAIN] step=   50 loss=0.651229  test_return=65.000000
[TRAIN] step=  100 loss=0.450928  test_return=106.000000
[TRAIN] step=  150 loss=0.550896  test_return=144.000000
[TRAIN] step=  200 loss=0.298260  test_return=137.000000
[TRAIN] step=  250 loss=0.312202  test_return=133.000000
[TRAIN] step=  300 loss=0.241559  test_return=134.000000
[TRAIN] step=  350 loss=0.323345  test_return=144.000000
[TRAIN] step=  400 loss=0.152591  test_return=144.000000
[TRAIN] step=  450 loss=0.140766  test_return=140.000000
[TRAIN] step=  500 loss=0.160854  test_return=142.000000
[TRAIN] step=  550 loss=0.159200  test_return=137.000000
[TRAIN] step=  600 loss=0.133808  test_return=144.000000
[TRAIN] step=  650 loss=0.156491  test_return=140.000000
[TRAIN] step=  700 loss=0.085428  test_return=144.000000
[TRAIN] step=  750 loss=0.074581  test_return=144.000000
[TRAIN] 