![SAPO](./image.png)

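* 本示例实现 **SAPO（Soft Adaptive Policy Optimization）算法**的核心思路：
    1. 用简单的 MLP 策略网络模拟 LLM 的 token 生成逻辑
    2. 实现 SAPO 的温度可控软门控（替代 PPO 的硬裁剪）：代理目标为 $f(r)=\frac{4}{\tau}\,\sigma\big(\tau(r-1)\big)$，对应的梯度权重为 $w(r)=4\sigma(1-\sigma)=\operatorname{sech}^2\!\big(\tfrac{\tau}{2}(r-1)\big)$，在 $r=1$（on-policy）处等于 1，偏离时平滑衰减
    3. 实现正/负 advantage token 的非对称温度（$\tau_{neg} > \tau_{pos}$，负 advantage token 的权重衰减更快）
    4. 用模拟的文本生成任务（固定目标序列）验证算法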

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

In [2]:
# ---- Hyper-parameters -------------------------------------------------------
# Toy "LLM" setup.
VOCAB_SIZE = 10  # number of distinct token ids the policy can emit
SEQ_LEN = 5  # tokens per generated sequence (also the state feature size)
BATCH_SIZE = 8  # sequences sampled per training step

# Optimisation.
GRADIENT_STEPS = 100  # number of training iterations
LR = 1e-3  # Adam learning rate

# SAPO asymmetric temperatures: the negative-advantage temperature is kept
# larger than the positive one.
TAU_POS = 1.0  # temperature for tokens with positive advantage
TAU_NEG = 1.05  # temperature for tokens with non-positive advantage

# NOTE(review): GAMMA is not referenced by the code in this notebook; it is
# kept so any external code that imports it keeps working.
GAMMA = 0.9  # discount factor (intended for advantage computation)

In [15]:
class SimplePolicy(nn.Module):
    """Toy policy network that emits per-position token logits.

    A two-layer MLP maps a ``seq_len``-dimensional state vector to a
    ``(seq_len, vocab_size)`` grid of logits, so each output position has its
    own categorical distribution over the vocabulary.

    Args:
        vocab_size: Number of token ids in the vocabulary.
        seq_len: Generated sequence length; also the input feature size.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, vocab_size, seq_len, hidden_dim=32):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # fc1 is created before fc2; both names are part of the state_dict.
        self.fc1 = nn.Linear(seq_len, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, seq_len * vocab_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map states ``x`` to logits of shape ``(-1, seq_len, vocab_size)``."""
        hidden = self.relu(self.fc1(x))
        flat_logits = self.fc2(hidden)
        grid_shape = (-1, self.seq_len, self.vocab_size)
        return flat_logits.reshape(grid_shape)

In [4]:
def compute_advantage(rewards, batch_rewards, normalize_std=False, eps=1e-8):
    """Compute a group-relative advantage for each sequence.

    The baseline is the mean reward of the group ``batch_rewards``. By default
    the advantage is only mean-centred (the simplified form this notebook
    uses). Pass ``normalize_std=True`` to also divide by the group's standard
    deviation, giving the full group-normalized advantage.

    Args:
        rewards: Per-sequence rewards (e.g. shape ``[batch]``).
        batch_rewards: Rewards of the group used for the mean/std statistics.
        normalize_std: If True, divide the centred rewards by the group std.
        eps: Added to the std so a group of identical rewards does not divide
            by zero.

    Returns:
        Advantages with the same shape as ``rewards``.
    """
    adv = rewards - batch_rewards.mean()
    if normalize_std:
        # A one-element group has an undefined (NaN) sample std; treat its
        # spread as zero so the result stays finite.
        if batch_rewards.numel() > 1:
            std = batch_rewards.std()
        else:
            std = batch_rewards.new_zeros(())
        adv = adv / (std + eps)
    return adv

In [5]:
def sigmoid_gate(r_t, tau):
    """Return SAPO's temperature-controlled soft-gate weight for each token.

    SAPO replaces PPO's hard clip with the smooth surrogate
    ``f(r) = (4 / tau) * sigmoid(tau * (r - 1))``. Its derivative w.r.t. ``r``
    is the gradient weight::

        w(r) = 4 * s * (1 - s) = sech^2(tau * (r - 1) / 2),
        s    = sigmoid(tau * (r - 1)),

    using the identity ``s(1 - s) = sech^2(x / 2) / 4``. The ``4 / tau``
    factor of ``f`` cancels the ``tau`` from the chain rule. So ``w(1) == 1``
    for every temperature, which means on-policy tokens get the full gradient.
    The weight decays smoothly toward 0 as ``r`` moves away from 1, and a
    larger ``tau`` decays faster.

    Args:
        r_t: Importance ratios ``pi_new / pi_old`` (tensor, any shape).
        tau: Temperature; a scalar or a tensor broadcastable to ``r_t``.

    Returns:
        Tensor of weights in ``[0, 1]`` with the broadcast shape of the inputs.
    """
    x = (r_t - 1) * tau  # scaled deviation from on-policy (r_t == 1)
    sig = torch.sigmoid(x)
    gate_weight = 4 * sig * (1 - sig)  # == sech^2(x / 2)
    # Mathematically bounded by 1; the clamp only absorbs float rounding.
    return gate_weight.clamp(0, 1)

In [6]:
def sapo_loss(old_policy, new_policy, states, actions, rewards, tau_pos, tau_neg):
    """Compute the SAPO loss for a batch of sampled token sequences.

    The objective is a soft-gated, importance-weighted policy gradient.
    Each token's gradient is scaled by ``sigmoid_gate(r_t, tau)`` instead of
    being hard-clipped. ``tau`` is ``tau_pos`` for tokens whose sequence has a
    positive advantage and ``tau_neg`` otherwise.

    :param old_policy: old (behaviour) policy; the ratio is taken against it
    :param new_policy: policy being optimised (receives gradients)
    :param states: input states fed to both policies
    :param actions: sampled token ids, shape [batch, seq_len]
    :param rewards: per-sequence rewards, shape [batch]
    :param tau_pos: gate temperature for positive-advantage tokens
    :param tau_neg: gate temperature for non-positive-advantage tokens
    :return: scalar loss (the negated objective, to be minimised)
    """
    # 1. Token-level importance ratio r_t(theta) = pi_new / pi_old.
    # The old policy is a fixed reference: evaluate it without building a
    # graph so no gradients flow into (or accumulate on) its parameters.
    with torch.no_grad():
        old_log_probs = Categorical(logits=old_policy(states)).log_prob(actions)
    new_log_probs = Categorical(logits=new_policy(states)).log_prob(actions)
    r_t = torch.exp(new_log_probs - old_log_probs)  # [batch, seq_len]

    # 2. Sequence-level group-relative advantage, broadcast to every token.
    # expand_as uses the real sequence length rather than a global constant.
    batch_adv = compute_advantage(rewards, rewards)  # [batch]
    token_adv = batch_adv.unsqueeze(1).expand_as(r_t)  # [batch, seq_len]

    # 3. Asymmetric temperature chosen by the sign of the advantage.
    tau = torch.where(token_adv > 0, tau_pos, tau_neg)  # [batch, seq_len]

    # 4. Soft gate weights. They are detached so they only rescale the policy
    # gradient (grad = w * grad(r_t) * A). Back-propagating through the gate
    # would add an r_t * dw/dr term that can flip the update direction when
    # r_t drifts far from 1.
    gate_weights = sigmoid_gate(r_t, tau).detach()  # [batch, seq_len]

    # 5. Weighted surrogate; negate because the optimiser minimises.
    sapo_obj = gate_weights * r_t * token_adv
    return -sapo_obj.mean()

In [None]:
def simulate_reward(actions, target_seq):
    """Toy reward: fraction of positions where the sampled tokens hit the target.

    Args:
        actions: Sampled token ids, shape ``[batch, seq_len]``.
        target_seq: Target token ids, either shape ``[seq_len]`` (shared by
            the whole batch) or ``[batch, seq_len]`` (one target per row).

    Returns:
        Float tensor of shape ``[batch]`` with values in ``[0, 1]``; ``1.0``
        means the sequence matches the target exactly.
    """
    # Broadcasting handles both a shared 1-D target and per-row 2-D targets.
    match = (actions == target_seq).float()
    # Average over the actual sequence length instead of a global constant.
    return match.mean(dim=1)

In [None]:
# Run the demo.
# 1. Build the trainable policy and a frozen copy that serves as the
#    behaviour ("old") policy; both start with identical weights.
policy = SimplePolicy(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
old_policy = SimplePolicy(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
old_policy.load_state_dict(policy.state_dict())
old_policy.requires_grad_(False)  # never trained directly; only synced below
optimizer = optim.Adam(policy.parameters(), lr=LR)

# 2. Target sequence the policy should learn to generate.
target_seq = torch.tensor([2, 5, 7, 3, 1])

# 3. Training loop.
for step in range(GRADIENT_STEPS):
    # 3.1 Rollout: random states stand in for query features.
    states = torch.randn(BATCH_SIZE, SEQ_LEN)

    # Sample from the OLD (behaviour) policy, so the importance ratio
    # pi_new / pi_old in sapo_loss is taken against the distribution that
    # actually generated the actions.
    with torch.no_grad():
        old_logits = old_policy(states)
        actions = Categorical(logits=old_logits).sample()  # [batch, seq_len]

    # Per-sequence reward.
    rewards = simulate_reward(actions, target_seq)

    # 3.2 SAPO loss.
    loss = sapo_loss(
        old_policy=old_policy,
        new_policy=policy,
        states=states,
        actions=actions,
        rewards=rewards,
        tau_pos=TAU_POS,
        tau_neg=TAU_NEG,
    )

    # 3.3 Policy update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % 10 == 0:
        # 3.4 Sync the behaviour policy with the current policy.
        print("\n===== 更新旧策略 =====")
        old_policy.load_state_dict(policy.state_dict())

        # 3.5 Log metrics of this step's batch.
        avg_reward = rewards.mean().item()
        print(f"Step [{step+1}/{GRADIENT_STEPS}] | Loss: {loss.item():.4f} | Avg Reward: {avg_reward:.4f}")

# Validate after training. Greedy decoding (argmax per position) gives a
# deterministic view of what the policy has learned.
print("\n===== 训练完成，验证策略 =======")
test_state = torch.randn(1, SEQ_LEN)
with torch.no_grad():
    logits = policy(test_state)
    gen_seq = logits.argmax(dim=-1).squeeze(0)
print(f"目标序列: {target_seq.numpy()}")
print(f"生成序列: {gen_seq.numpy()}")