In [1]:
import kagglehub
path = kagglehub.dataset_download("satyajeetrai/medical-cost")

In [2]:
print(path)

/Users/lixiaokang/.cache/kagglehub/datasets/satyajeetrai/medical-cost/versions/1


In [21]:
import trl

In [4]:
import os
print(os.listdir(path))

['dataset_.csv']


In [10]:
# !pip uninstall -y pandas
! pip install pandas

Collecting pandas
  Downloading pandas-2.3.3-cp310-cp310-macosx_11_0_arm64.whl.metadata (91 kB)
Downloading pandas-2.3.3-cp310-cp310-macosx_11_0_arm64.whl (10.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m118.5 kB/s[0m  [33m0:01:16[0mm0:00:01[0m00:03[0m
[?25hInstalling collected packages: pandas
Successfully installed pandas-2.3.3


In [None]:
import os
import pandas as pd
from typing import Dict, Optional, List
from dataclasses import dataclass
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("MIMIC-DataLoader")


@dataclass
class FileSpec:
    """File specification: name + dtype overrides"""
    filename: str
    dtypes: Optional[Dict[str, str]] = None

class MIMICDataLoader:
    """
    Elegant, extensible DataLoader for MIMIC-III/IV.
    Features:
    - automatic dtype handling
    - clean logging
    - modular structure
    - simple to extend to new tables
    """
    def __init__(self, data_path: str):
        self.data_path = data_path
        self.datasets: Dict[str, pd.DataFrame] = {}

        # Required core files
        self.required_files: List[FileSpec] = [
            FileSpec("DIAGNOSES_ICD.csv", {"icd9_code": str}),
            FileSpec("PATIENTS.csv"),
            FileSpec("PRESCRIPTIONS.csv", {"drug_name_generic": str}),
            FileSpec("LABEVENTS.csv"),
            FileSpec("CHARTEVENTS.csv"),
            FileSpec("ICUSTAYS.csv"),
            FileSpec("MICROBIOLOGYEVENTS.csv", {"org_itemid": str, "ab_itemid": str}),
        ]
    # ----------------------------------------------------------------------
    def _filepath(self, filename):
        return os.path.join(self.data_path, filename)

    # ----------------------------------------------------------------------
    def _load_one(self, spec: FileSpec):
        """Load a single MIMIC file with clean logging + dtype override."""
        path = self._filepath(spec.filename)

        if not os.path.exists(path):
            logger.warning(f"[SKIP] File not found: {spec.filename}")
            return pd.DataFrame()

        try:
            df = pd.read_csv(path, dtype=spec.dtypes, low_memory=False)
            logger.info(f"[OK] Loaded {spec.filename:<25} shape={df.shape}")
            return df
        except Exception as e:
            logger.error(f"[ERROR] Failed to load {spec.filename}: {e}")
            return pd.DataFrame()
    # ----------------------------------------------------------------------
    def load_all(self):
        """Load all required files into self.datasets."""
        for spec in self.required_files:
            key = spec.filename.split(".")[0].lower()
            self.datasets[key] = self._load_one(spec)

        self._validate()
        return self.datasets

    # ----------------------------------------------------------------------
    def _validate(self):
        """Ensure that critical datasets exist and are non-empty."""
        mandatory = ["diagnoses_icd", "patients", "prescriptions"]

        for key in mandatory:
            df = self.datasets.get(key)
            if df is None or df.empty:
                raise ValueError(f"[ERROR] Required dataset missing or empty: {key}")

        logger.info("[OK] All required datasets successfully validated.")



['home', 'usr', '.resolve', 'bin', 'sbin', '.file', 'etc', 'var', 'Library', 'System', '.VolumeIcon.icns', 'private', '.vol', 'Users', 'Applications', 'opt', 'dev', 'Volumes', '.nofollow', 'tmp', 'cores']


In [11]:
import pandas as pd
import os

csv_path = os.path.join(path, "dataset_.csv")
df = pd.read_csv(csv_path)
df.head()

AttributeError: partially initialized module 'pandas' has no attribute '_pandas_datetime_CAPI' (most likely due to a circular import)

In [15]:
from typing import List, Tuple, Optional, Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    y = x * weight / sqrt(mean(x^2) + eps)
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim)
        norm = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return x * self.weight


def _split_heads(x, num_heads):
    """将最后一维分解为多头：(B, T, H*d) -> (B, T, H, d)
    """
    B, T, D = x.shape
    assert D % num_heads == 0, f"Hidden size {D} not divisible by heads {num_heads}"
    d = D // num_heads
    return x.view(B, T, num_heads, d)


def _merge_heads(x):
    """(B, T, H, d) -> (B, T, H*d)
    """
    B, T, H, d = x.shape
    return x.contiguous().view(B, T, H * d)


def scaled_dot_product_attention(q,k,v,attn_mask,dropout_p=0.0,training=True):
    """基础 SDPA
      q: (B, T_q, H_q, d)
      k: (B, T_k, H_k, d)
      v: (B, T_k, H_k, d)
      attn_mask: (B, 1, T_q, T_k) or (1, 1, T_q, T_k)，-inf 位置被屏蔽
    """
    d = q.size(-1)
    # (B, H_q, T_q, d) @ (B, H_k, d, T_k) → (B, H_q, T_q, T_k)
    q_ = q.permute(0, 2, 1, 3)
    k_ = k.permute(0, 2, 3, 1)
    attn_scores = torch.matmul(q_, k_) / math.sqrt(d)

    if attn_mask is not None:
        attn_scores = attn_scores + attn_mask  # 预期 mask 已为 -inf/0 形式

    attn_probs = F.softmax(attn_scores, dim=-1)
    if dropout_p > 0 and training:
        attn_probs = F.dropout(attn_probs, p=dropout_p)

    # (B, H_q, T_q, T_k) @ (B, H_k, T_k, d) → (B, H_q, T_q, d)
    v_ = v.permute(0, 2, 1, 3)
    context = torch.matmul(attn_probs, v_)
    # -> (B, T_q, H_q, d)
    context = context.permute(0, 2, 1, 3)
    return context, attn_probs


class MultiHeadSelfAttention(nn.Module):
    """标准多头自注意力（支持可选的 causal mask）。"""
    def __init__(self, d_model, n_heads, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.norm_qkv = RMSNorm(d_model)

    def forward(self, x, causal=True):
        # x: (B, T, D)
        B, T, D = x.shape
        x = self.norm_qkv(x)
        qkv = self.qkv(x)  # (B, T, 3D)
        q, k, v = qkv.chunk(3, dim=-1)
        q = _split_heads(q, self.n_heads)
        k = _split_heads(k, self.n_heads)
        v = _split_heads(v, self.n_heads)

        # causal mask: (1, 1, T, T)
        attn_mask = None
        if causal:
            # 下三角为0，上三角为 -inf
            mask = torch.full((T, T), float('-inf'), device=x.device)
            mask = torch.triu(mask, diagonal=1)
            attn_mask = mask.unsqueeze(0).unsqueeze(0)  # (1,1,T,T)

        ctx, _ = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                              dropout_p=self.attn_drop, training=self.training)
        out = _merge_heads(ctx)  # (B, T, D)
        out = self.out_proj(out)
        if self.proj_drop > 0:
            out = F.dropout(out, p=self.proj_drop, training=self.training)
        return out


class LazyCrossAttentionGQA(nn.Module):
    """
    懒惰交叉注意力（GQA 版本）：
      - 仅对 Query 做线性映射；K/V 由 ContextProcessor 直接提供（已按 Gkv 分组，无需 Wk/Wv）
      - 支持 H_q 查询头数量与 G_kv 组数不同（H_q >= G_kv），通过 repeat_interleave 将 (K,V) 组扩展到 H_q
    输入：
      x_q: (B, T_q, D)
      k_ctx, v_ctx: (B, T_k, Gkv, d_head)
    参数：
      n_heads_q = H_q, d_head = D // H_q
    输出：
      (B, T_q, D)
    """
    def __init__(self, d_model, n_heads_q, gkv, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert n_heads_q % gkv == 0 # n_heads_q 必须能被 Gkv 整除（每组共享一份 K/V）
        self.d_model = d_model
        self.n_heads_q = n_heads_q
        self.gkv = gkv
        self.d_head = d_model // n_heads_q

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.norm_q = RMSNorm(d_model)

    def forward(self,x_q,k_ctx,v_ctx,attn_mask=None):
        # x_q: (B, T_q, D); k_ctx/v_ctx: (B, T_k, Gkv, d_head)
        B, Tq, D = x_q.shape
        _, Tk, Gkv, d = k_ctx.shape
        assert Gkv == self.gkv and d == self.d_head

        q = self.q_proj(self.norm_q(x_q))  # (B, Tq, D)
        q = _split_heads(q, self.n_heads_q)  # (B, Tq, Hq, d)

        # 将 (B, Tk, Gkv, d) 映射为 (B, Tk, Hq, d)，通过 repeat_interleave
        repeat = self.n_heads_q // self.gkv
        k = k_ctx.repeat_interleave(repeat, dim=2)
        v = v_ctx.repeat_interleave(repeat, dim=2)

        ctx, _ = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                              dropout_p=self.attn_drop, training=self.training)
        out = _merge_heads(ctx)  # (B, Tq, D)
        out = self.out_proj(out)
        if self.proj_drop > 0:
            out = F.dropout(out, p=self.proj_drop, training=self.training)
        return out


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, drop=0.0, activation="silu"):
        super().__init__()
        act = nn.SiLU() if activation == "silu" else nn.GELU()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            act,
            nn.Linear(d_ff, d_model, bias=False),
        )
        self.drop = drop
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        x_in = self.norm(x)
        y = self.net(x_in)
        if self.drop > 0:
            y = F.dropout(y, p=self.drop, training=self.training)
        return y


class MoEFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, drop=0.0):
        super().__init__()
        self.num_experts = num_experts
        self.drop = drop

        self.gate = nn.Linear(d_model, num_experts, bias=False)

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

        self.norm = RMSNorm(d_model)

    def forward(self, x):
        B, T, D = x.shape
        x_norm = self.norm(x)

        # -------- gating --------
        gate_logits = self.gate(x_norm)         # (B,T,E)
        gate_scores = torch.softmax(gate_logits, dim=-1)
        top1_idx = gate_scores.argmax(dim=-1)   # (B,T)
        top1_score = gate_scores.max(dim=-1).values  # (B,T)

        # -------- output buffer --------
        out = torch.zeros_like(x_norm)

        # -------- per expert routing --------
        for e in range(self.num_experts):
            # mask: (B,T)
            mask = (top1_idx == e)

            if mask.sum() == 0:
                continue

            # 取出所有属于 expert e 的 token → (Ne, D)
            xe = x_norm[mask]       # 直接在 2D flatten 上 mask

            # expert forward
            ye = self.experts[e](xe)   # (Ne, D)

            # 写回对应 token
            out[mask] = ye

        # gating score 缩放
        out = out * top1_score.unsqueeze(-1)

        if self.drop > 0:
            out = F.dropout(out, p=self.drop, training=self.training)

        return out



class ContextProcessor(nn.Module):
    """
    将多路上下文输入（用户静态、短期、长期）统一映射到 (k_l, v_l) 列表：
      - d_context = Skv * Lkv * Gkv * d_head
      - 将特征维度切成 Lkv 份（每份包含 Skv 个槽），得到每层的 (k_l, v_l)
      - 当 Skv=1：v_l = k_l（KV 共享）；当 Skv=2：第 0 槽为 k_l，第 1 槽为 v_l
    期望输入：
      user_static:  (B, Ns, D_in)
      short_term:   (B, Ts, D_in)
      long_term:    (B, Tl, D_in)
    输出：
      kv_list: List[(k_l, v_l)], 其中 k_l/v_l -> (B, T_ctx, Gkv, d_head)
    """
    def __init__(self,
                 d_in,
                 d_head,
                 gkv,
                 lkv=1,
                 skv=1,
                 use_norm_k=True,
                 use_norm_v=True):
        super().__init__()
        assert skv in (1, 2)
        self.d_in = d_in
        self.d_head = d_head
        self.gkv = gkv
        self.lkv = lkv
        self.skv = skv
        self.d_context = skv * lkv * gkv * d_head

        # 三路线性映射到相同维度后拼接（也可换成各自独立映射 + concat + 再线性）
        self.proj = nn.Linear(d_in, self.d_context, bias=False)
        self.norm_k_layers = nn.ModuleList([RMSNorm(gkv * d_head) if use_norm_k else nn.Identity()
                                            for _ in range(lkv)])
        self.norm_v_layers = nn.ModuleList([RMSNorm(gkv * d_head) if use_norm_v else nn.Identity()
                                            for _ in range(lkv)])

    def forward(self,user_static,short_term,long_term):
        ctx_parts = []
        for x in (user_static, short_term, long_term):
            if x is not None:
                # x: (B, T, D_in) → (B, T, d_context)
                ctx_parts.append(self.proj(x))
        assert len(ctx_parts) > 0 # 至少需要一条上下文输入
        # 沿时间维拼接： (B, T_ctx, d_context)
        ctx = torch.cat(ctx_parts, dim=1) if len(ctx_parts) > 1 else ctx_parts[0]

        B, Tctx, D = ctx.shape
        assert D == self.d_context

        # 切块：Lkv 份，每份大小 = skv * gkv * d_head
        chunk_size = self.skv * self.gkv * self.d_head
        chunks = ctx.split(chunk_size, dim=-1)  # 长度应为 Lkv
        assert len(chunks) == self.lkv, f"期望 {self.lkv} 份，得到 {len(chunks)}"

        kv_list = []
        for l, ch in enumerate(chunks):
            # ch: (B, Tctx, skv*gkv*d_head)
            if self.skv == 1:
                # 共享：v_l = k_l
                k = ch  # (B, Tctx, gkv*d_head)
                k = self.norm_k_layers[l](k)
                # reshape 到 (B,T,Gkv,d_head)
                k = k.view(B, Tctx, self.gkv, self.d_head)
                v = k
            else:
                # 独立：前半为 K，后半为 V
                mid = (self.gkv * self.d_head)
                k, v = ch[..., :mid], ch[..., mid:]
                k = self.norm_k_layers[l](k)
                v = self.norm_v_layers[l](v)
                k = k.view(B, Tctx, self.gkv, self.d_head)
                v = v.view(B, Tctx, self.gkv, self.d_head)
            kv_list.append((k, v))
        return kv_list


class LazyDecoderBlock(nn.Module):
    def __init__(self,
                 d_model,
                 n_heads_q,
                 gkv,
                 d_ff,
                 attn_drop=0.0,
                 resid_drop=0.0):
        super().__init__()
        self.cross_attn = LazyCrossAttentionGQA(d_model, n_heads_q, gkv,
                                               attn_drop=attn_drop, proj_drop=resid_drop)
        self.self_attn = MultiHeadSelfAttention(d_model, n_heads_q,
                                               attn_drop=attn_drop, proj_drop=resid_drop)

        # ⭐ FFN → MoE
        self.ffn = MoEFeedForward(
            d_model=d_model,
            d_ff=d_ff,
            num_experts=8,     # 或者 54，完全看你规模要求
            drop=resid_drop
        )

    def forward(self, x, k_ctx, v_ctx, causal=True):
        x = x + self.cross_attn(x, k_ctx, v_ctx, attn_mask=None)
        x = x + self.self_attn(x, causal=causal)
        x = x + self.ffn(x)
        return x


class LazyDecoder(nn.Module):
    """
    vocab_size：语义 ID 词表大小（例如 3 个 token 的共享词表）
    d_model：主通道维度（= n_heads_q * d_head）
    n_layers：解码层数 Nlayer
    n_heads_q：解码端 query 头数 Hq
    gkv：上下文 KV 组数（Gkv），需整除 Hq
    d_ff：前馈隐层维度

    Lkv：KV 层共享组数
    Skv：1 表示 v=k（共享表示），2 表示独立 K/V

    """
    def __init__(self,
                 vocab_size,
                 d_model = 768,
                 n_layers = 12,
                 n_heads_q = 12,
                 gkv = 3,
                 d_ff = 2048,
                 # Context Processor
                 d_ctx_in = 256,
                 lkv = 1,
                 skv = 1,
                 pad_id = 0,
                 bos_id = 1,
                 attn_drop = 0.0,
                 resid_drop = 0.0):
        super().__init__()
        assert d_model % n_heads_q == 0
        assert n_heads_q % gkv == 0

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads_q = n_heads_q
        self.gkv = gkv
        self.d_head = d_model // n_heads_q
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.lkv = lkv
        self.skv = skv

        # 嵌入层（共享）
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.out_norm = RMSNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

        # 上下文处理器：将输入特征映射为 (k_l, v_l) 列表
        self.ctx_proc = ContextProcessor(
            d_in=d_ctx_in,
            d_head=self.d_head,
            gkv=gkv,
            lkv=lkv,
            skv=skv,
        )

        # 构建 N 层解码块
        self.blocks = nn.ModuleList([
            LazyDecoderBlock(d_model, n_heads_q, gkv, d_ff,
                             attn_drop=attn_drop, resid_drop=resid_drop)
            for _ in range(n_layers)
        ])

    def _kv_index_for_layer(self, l):
        # l ∈ [0, Nlayer-1] → l_kv ∈ [0, Lkv-1]
        return (l * self.lkv) // self.n_layers

    def forward(self,
                target_ids,  # (B, T_gen)，例如 [BOS, s1, s2]
                user_static,  # (B, Ns, d_ctx_in)
                short_term,  # (B, Ts, d_ctx_in)
                long_term,   # (B, Tl, d_ctx_in)
                return_hidden = False):
        # 1) 生成 (k_l, v_l) 列表（按层共享）
        kv_list = self.ctx_proc(user_static, short_term, long_term)

        # 2) 目标 token 嵌入
        x = self.tok_emb(target_ids)  # (B, T, D)

        # 3) 逐层堆叠（Cross-Attn 使用共享 KV）
        for l, blk in enumerate(self.blocks):
            idx = self._kv_index_for_layer(l)
            k_ctx, v_ctx = kv_list[idx]
            x = blk(x, k_ctx, v_ctx, causal=True)

        # 4) 输出分类头
        h = self.out_norm(x)
        logits = self.out_proj(h)  # (B, T, vocab_size)
        out = {"logits": logits}
        if return_hidden:
            out["hidden"] = h
        return out

    @torch.no_grad()
    def step(self,
             prev_ids: torch.Tensor,     # (B, T_prev)
             user_static: Optional[torch.Tensor] = None,
             short_term: Optional[torch.Tensor] = None,
             long_term: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.forward(prev_ids, user_static, short_term, long_term, return_hidden=False)
        logits_last = out["logits"][:, -1, :]  # (B, vocab)
        return logits_last

class GBPOTrainer:
    """
    强化学习：
      - 支持两阶段训练：监督（CE） + RL（GBPO）
      - GBPO Loss 参考论文定义：
        L_GBPO = -E[ clip(ratio, 1 - eps, 1 + eps) * A ]
      - 其中 ratio = pi_new / pi_old，A 为奖励差分（advantage）
    """
    def __init__(self, model, lambda_rl=0.1, clip_ratio=0.2):
        self.model = model
        self.lambda_rl = lambda_rl
        self.clip_ratio = clip_ratio

    def compute_supervised_loss(self, logits, targets, pad_id=0):
        """
        标准交叉熵损失（stage 1）
        """
        B, T, V = logits.shape
        loss = F.cross_entropy(
            logits.view(-1, V),
            targets.view(-1),
            ignore_index=pad_id
        )
        return loss

    def compute_gbpo_loss(self,
                          new_logits,
                          old_logits,
                          rewards,
                          mask=None):
        """
        GBPO 策略优化损失（stage 2）
        参数：
          new_logits: 当前策略输出 (B, T, V)
          old_logits: 冻结的旧策略输出 (B, T, V)
          rewards: 奖励信号 A (B, T)
          mask: (B, T)，可选掩码
        """
        # softmax 概率
        logp_new = F.log_softmax(new_logits, dim=-1)
        logp_old = F.log_softmax(old_logits.detach(), dim=-1)

        # 获取 token 对应概率
        # 注意：这里假设 reward 只针对目标 token
        # 可改为对完整序列求和
        probs_new = F.softmax(new_logits, dim=-1).clamp_min(1e-9)
        probs_old = F.softmax(old_logits.detach(), dim=-1).clamp_min(1e-9)

        pi_bound = torch.where(
            rewards.unsqueeze(-1) >= 0,  # advantage为正
            torch.max(probs_old, probs_new.detach()),  # 防止正样本过激上升
            torch.max(probs_old, 1 - probs_new.detach())  # 防止负样本概率过低
        )

        ratio = (probs_new / pi_bound).clamp(1e-3, 10)

        adv = rewards.unsqueeze(-1)
        clipped_ratio = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)
        loss_unclipped = -ratio * adv
        loss_clipped = -clipped_ratio * adv
        loss = torch.max(loss_unclipped, loss_clipped)

        if mask is not None:
            loss = loss * mask.unsqueeze(-1)
        return loss.mean()

    def train_step(self,
                   batch,
                   optimizer,
                   use_rl=False,
                   old_logits=None,
                   rewards=None):
        """
        单步训练：
          - use_rl=False：监督训练阶段
          - use_rl=True：RL阶段
        """
        target_ids = batch["target_ids"]
        user_static = batch.get("user_static", None)
        short_term = batch.get("short_term", None)
        long_term = batch.get("long_term", None)

        out = self.model(target_ids, user_static, short_term, long_term)
        logits = out["logits"]

        if not use_rl:
            loss = self.compute_supervised_loss(logits, target_ids)
        else:
            assert old_logits is not None and rewards is not None
            loss_rl = self.compute_gbpo_loss(logits, old_logits, rewards)
            loss_ce = self.compute_supervised_loss(logits, target_ids)
            loss = loss_ce + self.lambda_rl * loss_rl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()



if __name__ == "__main__":
    torch.manual_seed(0)
    B = 4
    Ns, Ts, Tl = 8, 32, 128
    d_ctx_in = 256
    vocab = 5000
    T_gen = 4  # 生成序列长度（含 BOS）

    model = LazyDecoder(
        vocab_size=vocab,
        d_model=768,
        n_layers=4,
        n_heads_q=12,
        gkv=3,
        d_ff=2048,
        d_ctx_in=d_ctx_in,
        lkv=1,   # KV 层共享
        skv=1,   # 1 表示 K=V
        attn_drop=0.0,
        resid_drop=0.1,
    )

    # 假数据
    user_static = torch.randn(B, Ns, d_ctx_in)
    short_term  = torch.randn(B, Ts, d_ctx_in)
    long_term   = torch.randn(B, Tl, d_ctx_in)

    # 目标 token 序列（BOS 开头）
    target_ids = torch.randint(2, vocab, (B, T_gen))
    target_ids[:, 0] = model.bos_id

    # 前向测试
    out = model(target_ids, user_static, short_term, long_term)
    print("logits shape:", out["logits"].shape)

    # 简单训练示例 
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    trainer = GBPOTrainer(model, lambda_rl=0.2, clip_ratio=0.2)

    print("\\n[Stage 1] Supervised CE Training:")
    for step in range(3):
        batch = {
            "target_ids": target_ids,
            "user_static": user_static,
            "short_term": short_term,
            "long_term": long_term
        }
        loss = trainer.train_step(batch, optimizer, use_rl=False)
        print(f"step {step}: CE loss = {loss:.4f}")

    print("\\n[Stage 2] GBPO Reinforcement Fine-tune:")
    with torch.no_grad():
        old_logits = model(target_ids, user_static, short_term, long_term)["logits"]
    # 模拟用户奖励信号（随机 ±1）
    rewards = torch.randint(low=-1, high=2, size=(B, T_gen)).float()

    for step in range(3):
        batch = {
            "target_ids": target_ids,
            "user_static": user_static,
            "short_term": short_term,
            "long_term": long_term
        }
        loss = trainer.train_step(batch, optimizer, use_rl=True,
                                  old_logits=old_logits,
                                  rewards=rewards)
        print(f"step {step}: RL loss = {loss:.4f}")

logits shape: torch.Size([4, 4, 5000])
\n[Stage 1] Supervised CE Training:
step 0: CE loss = 8.6934
step 1: CE loss = 6.8178
step 2: CE loss = 5.0920
\n[Stage 2] GBPO Reinforcement Fine-tune:
step 0: RL loss = 3.6789
step 1: RL loss = 2.5219
step 2: RL loss = 1.5357
