# PPO 教学（手写公式 + 最小实现，手写 LoRA）

这一版与 `PPO_tutorial_no_lora.ipynb` 的区别：

- 冻结基座模型参数
- 在若干 Linear 层上手写注入 LoRA（不使用 peft/trl）
- 只训练 LoRA 参数 + value head

优点：显存更省、训练更容易跑通；缺点：表达能力受限于低秩增量。


## 1. PPO（token 级）核心公式（与无 LoRA 版相同）

符号与公式同 `PPO_tutorial_no_lora.ipynb`，这里不再重复推导，只强调：

- 我们仍然在 token 级计算 `logπ_old/logπ_ref/KL/GAE/ratio/clip objective`
- LoRA 只改变“哪些参数可训练”，不改变 PPO 的数学形式

总 loss：

$$
\mathcal{L}= -L^{clip} + c_V L^V - c_{ent}\,\mathbb{E}_t[H]
$$


## 2. 环境与模型加载（建议离线）

- 建议 `conda activate llm`
- 默认优先用 `MODELSCOPE_CACHE` 下的本地模型目录（避免联网）


In [None]:
import os
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default the ModelScope cache location only when the variable is not already set.
os.environ.setdefault("MODELSCOPE_CACHE", r"D:/myProject/modelscope_hub")
print("python:", sys.executable)
print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
print("MODELSCOPE_CACHE:", os.environ["MODELSCOPE_CACHE"])

# Seed Python's RNG and torch's RNG (plus every CUDA device when CUDA is available).
seed = 42
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Device string used for the trainable model; the bare expression below is
# the notebook cell's displayed value.
device = "cuda" if torch.cuda.is_available() else "cpu"
device


In [None]:
# Choose the model: prefer the local cache directory, otherwise fall back to the hub id.
local_dir = Path(os.environ["MODELSCOPE_CACHE"]) / "models" / "qwen" / "Qwen2-0___5B-Instruct"
model_name_or_path = str(local_dir) if local_dir.exists() else "qwen/Qwen2-0.5B-Instruct"
print("model_name_or_path:", model_name_or_path)

# On CUDA use bfloat16 when supported, else float16; on CPU use float32.
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32
print("dtype:", dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Fall back to "<|endoftext|>" as pad token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = "<|endoftext|>"
tokenizer.padding_side = "left"

# System message placed first in every chat prompt built by build_prompt_input_ids.
SYSTEM_PROMPT = "You are a helpful assistant."

class ActorCritic(torch.nn.Module):
    """Causal LM (actor) plus a scalar value head (critic) on its last hidden state.

    Args:
        base: Model whose ``config.hidden_size`` gives the hidden width and whose
            forward accepts ``input_ids``, ``attention_mask``,
            ``output_hidden_states``, ``use_cache`` and ``return_dict`` and
            returns an object exposing ``hidden_states`` and ``logits``.
    """

    def __init__(self, base: torch.nn.Module):
        super().__init__()
        self.base = base
        self.value_head = torch.nn.Linear(base.config.hidden_size, 1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, values)`` for a batch of token ids.

        ``values`` is the value head applied to the last entry of
        ``hidden_states`` with the trailing size-1 dimension squeezed away,
        i.e. one scalar per token position.
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        hidden = out.hidden_states[-1]
        # The value head may live in a different dtype than the base model
        # (e.g. float32 head on a bf16/fp16 base). Cast explicitly so the
        # linear layer does not fail on mismatched dtypes outside autocast;
        # this is a no-op when the dtypes already agree.
        hidden = hidden.to(self.value_head.weight.dtype)
        values = self.value_head(hidden).squeeze(-1)
        return out.logits, values

# Trainable policy backbone, loaded in the working dtype on the training device.
actor_base = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype).to(device)
actor_base.config.use_cache = False
actor_base.config.pad_token_id = tokenizer.pad_token_id

if hasattr(actor_base, "gradient_checkpointing_enable") and device == "cuda":
    actor_base.gradient_checkpointing_enable()
    # The whole base model (embeddings included) is frozen later and only
    # adapter weights are trained. With checkpointing, a segment whose inputs
    # do not require grad can produce outputs without grad, which leaves the
    # adapters inside it without gradients. Making the embedding output
    # require grad keeps the graph connected.
    if hasattr(actor_base, "enable_input_require_grads"):
        actor_base.enable_input_require_grads()

actor_critic = ActorCritic(actor_base).to(device)
# Keep the trainable value head in float32: fp16 parameters cannot be
# unscaled by GradScaler, and half-precision weights lose small optimizer
# updates. The pre-hook casts the incoming hidden state to the head's dtype
# so the head also works outside autocast.
actor_critic.value_head.to(device=device, dtype=torch.float32)
actor_critic.value_head.register_forward_pre_hook(
    lambda module, args: (args[0].to(module.weight.dtype),)
)

# Reference policy pi_ref: frozen, kept on CPU in float32 to save GPU memory.
ref_device = "cpu"
ref_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32).to(ref_device)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)

print("ready")


## 3. 手写 LoRA 注入（不使用 peft）

对一个 Linear 权重 $W$，LoRA 训练的是低秩增量：

$$
W' = W + \Delta W, \quad \Delta W = \alpha/r \cdot BA
$$

其中 $W\in\mathbb{R}^{k\times d}$ 冻结；$A\in\mathbb{R}^{r\times d}$、$B\in\mathbb{R}^{k\times r}$ 是可训练参数（秩 $r\ll\min(d,k)$）。$B$ 初始化为 0，因此训练开始时 $\Delta W=0$，模型输出与基座完全一致。


In [None]:
class LoRALinear(torch.nn.Module):
    """Wrap a frozen ``torch.nn.Linear`` with a trainable low-rank update.

    The output is ``base(x) + (dropout(x) @ A^T @ B^T) * (alpha / r)``.
    ``A`` has shape ``(r, in_features)`` and ``B`` has shape
    ``(out_features, r)``. ``B`` starts at zero, so the wrapped layer initially
    reproduces ``base`` exactly.

    Args:
        base: Linear layer to wrap; its weight and bias are frozen here.
        r: Rank of the update.
        alpha: Numerator of the scaling factor ``alpha / r``.
        dropout: Dropout probability applied to the adapter input only.

    Raises:
        TypeError: If ``base`` is not a ``torch.nn.Linear``.
    """

    def __init__(self, base: torch.nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.05):
        super().__init__()
        if not isinstance(base, torch.nn.Linear):
            raise TypeError("LoRALinear only supports torch.nn.Linear")

        self.base = base
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        self.dropout = torch.nn.Dropout(dropout)

        in_features = base.in_features
        out_features = base.out_features

        # Trainable adapter weights are kept in at least float32 even when the
        # base layer is fp16/bf16: GradScaler refuses to unscale fp16
        # parameters, and half-precision weights drop small optimizer updates.
        base_dtype = base.weight.dtype
        param_dtype = base_dtype if base_dtype in (torch.float32, torch.float64) else torch.float32
        param_device = base.weight.device

        self.lora_A = torch.nn.Parameter(torch.empty((r, in_features), dtype=param_dtype, device=param_device))
        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, r), dtype=param_dtype, device=param_device))

        torch.nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)

        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen projection plus the scaled low-rank correction."""
        base_out = self.base(x)
        # Cast the adapter input to the adapter dtype so the two matmuls work
        # outside autocast as well; the result is cast back so the layer's
        # output dtype is that of the wrapped layer.
        lora_in = self.dropout(x).to(self.lora_A.dtype)
        lora_out = torch.nn.functional.linear(
            torch.nn.functional.linear(lora_in, self.lora_A), self.lora_B
        )
        return base_out + (lora_out * self.scaling).to(base_out.dtype)

def _get_parent_module(root: torch.nn.Module, module_name: str) -> Tuple[torch.nn.Module, str]:
    parts = module_name.split(".")
    parent = root
    for p in parts[:-1]:
        parent = getattr(parent, p)
    return parent, parts[-1]

def inject_lora(
    m: torch.nn.Module,
    target_suffixes: List[str],
    r: int = 8,
    alpha: int = 16,
    dropout: float = 0.05,
) -> List[str]:
    """Swap matching ``torch.nn.Linear`` submodules of ``m`` for ``LoRALinear``.

    A submodule matches when its qualified name ends with any entry of
    ``target_suffixes``. Matches are collected before any replacement so the
    module tree is not mutated while it is being traversed.

    Returns:
        Qualified names of the replaced submodules, in traversal order.
    """
    suffixes = tuple(target_suffixes)
    targets = [
        (qualified_name, layer)
        for qualified_name, layer in m.named_modules()
        if isinstance(layer, torch.nn.Linear) and qualified_name.endswith(suffixes)
    ]

    injected: List[str] = []
    for qualified_name, layer in targets:
        owner, attr = _get_parent_module(m, qualified_name)
        setattr(owner, attr, LoRALinear(layer, r=r, alpha=alpha, dropout=dropout))
        injected.append(qualified_name)
    return injected

# Freeze every parameter of the base model.
for p in actor_critic.base.parameters():
    p.requires_grad_(False)

# Inject LoRA into the common projection layers (Qwen2 naming).
target_suffixes = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
replaced = inject_lora(actor_critic.base, target_suffixes, r=8, alpha=16, dropout=0.05)
print("lora_injected_modules:", len(replaced))

# The value head stays trainable.
for p in actor_critic.value_head.parameters():
    p.requires_grad_(True)

# Report how many parameters remain trainable out of the total.
trainable = sum(p.numel() for p in actor_critic.parameters() if p.requires_grad)
total = sum(p.numel() for p in actor_critic.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")


## 4. Toy 任务与规则奖励 $r_{rm}(x,y)$

与无 LoRA 版相同。


In [None]:
# Toy tasks: each entry pairs a user prompt with the string that rule_reward
# looks for in the model's response.
train_tasks = [
    {"prompt": "请只回答一个词：小鱼儿", "expected": "小鱼儿"},
    {"prompt": "请只输出数字 4，不要额外文字。2+2等于几？", "expected": "4"},
    {"prompt": "把“我喜欢机器学习”翻译成英文，只输出翻译。", "expected": "I like machine learning"},
    {"prompt": "请只回答：通义千问", "expected": "通义千问"},
]

def normalize_text(s: str) -> str:
    """Canonicalise text for matching.

    Strips surrounding whitespace, lower-cases, then deletes every space,
    newline, tab, quote character and the listed ASCII / full-width
    punctuation marks.
    """
    ignored_chars = " \n\t。，,.!?：:\"'"
    deletion_table = {ord(ch): None for ch in ignored_chars}
    return s.strip().lower().translate(deletion_table)

def rule_reward(prompt: str, response: str, expected: str) -> float:
    """Score a response with a substring rule plus a length penalty.

    Both strings are normalised first. The score is ``+1.0`` when the
    normalised expected text occurs inside the normalised response and
    ``-1.0`` otherwise, minus ``0.002`` per character of the normalised
    response. ``prompt`` is accepted but not used.
    """
    answer = normalize_text(response)
    target = normalize_text(expected)
    score = 1.0 if target in answer else -1.0
    return score - 0.002 * len(answer)

# Notebook display: show the first task as a quick sanity check.
train_tasks[0]


## 5. PPO 关键实现（与无 LoRA 版相同）

为了避免重复，这里直接复用同样的函数定义（logprob/KL/GAE/PPO update）。


In [None]:
@torch.inference_mode()
def build_prompt_input_ids(tok: Any, user_prompt: str) -> torch.Tensor:
    """Render a system + user chat through the tokenizer's chat template.

    The template is rendered to text with the generation prompt appended,
    tokenised, and the first row of the resulting ``input_ids`` is returned.
    """
    chat_turns = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    rendered = tok.apply_chat_template(chat_turns, tokenize=False, add_generation_prompt=True)
    encoded = tok(rendered, return_tensors="pt")
    return encoded["input_ids"][0]

@torch.inference_mode()
def sample_response(
    m: torch.nn.Module,
    tok: Any,
    prompt_input_ids: torch.Tensor,
    max_new_tokens: int = 48,
    temperature: float = 1.0,
    top_p: float = 0.9,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one continuation of a single prompt.

    Args:
        m: Model exposing ``generate`` and ``config``. It is switched to eval
            mode and left there; callers that train afterwards must call
            ``.train()`` themselves.
        tok: Tokenizer providing ``eos_token_id`` and ``pad_token_id``.
        prompt_input_ids: 1-D tensor of prompt token ids.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        ``(full_ids, response_ids)`` on CPU: the whole generated sequence and
        the part after the prompt.
    """
    m.eval()
    input_ids = prompt_input_ids.unsqueeze(0).to(device)
    # Single unpadded sequence, so every position is a real token. Passing the
    # mask explicitly avoids relying on generate() to infer it from pad ids.
    attention_mask = torch.ones_like(input_ids)

    old_use_cache = getattr(m.config, "use_cache", True)
    m.config.use_cache = True
    try:
        out = m.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
        )
    finally:
        # Restore the training-time cache setting even if generation fails.
        m.config.use_cache = old_use_cache

    full_ids = out[0].detach().cpu()
    response_ids = full_ids[prompt_input_ids.numel() :]
    return full_ids, response_ids

def action_logprobs_from_logits(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    prompt_len: int,
    response_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token log-probability and entropy for the response tokens.

    The token at sequence position ``t`` is predicted by the logits at
    position ``t - 1``. For the response positions
    ``prompt_len .. prompt_len + response_len - 1`` the logits one step
    earlier are therefore used. Only batch row 0 is read.

    Args:
        logits: Model logits indexed as ``[batch, position, vocab]``.
        input_ids: Token ids indexed as ``[batch, position]``.
        prompt_len: Number of prompt tokens; must be at least 1.
        response_len: Number of response tokens.

    Returns:
        ``(action_logp, entropy)``, both float32 with one entry per response
        token.

    Raises:
        ValueError: If ``prompt_len < 1``. Position ``-1`` would otherwise
            silently wrap around to the last token.
    """
    if prompt_len < 1:
        raise ValueError("prompt_len must be >= 1 so every response token has a preceding position")

    positions = torch.arange(prompt_len, prompt_len + response_len, device=logits.device)
    # Slice first so log_softmax runs only over the needed rows, then upcast:
    # half-precision log-probs are too coarse for PPO ratios, and this keeps
    # rollout and update log-probs in the same precision with or without
    # autocast.
    step_logits = logits[0, positions - 1, :].float()
    step_log_probs = F.log_softmax(step_logits, dim=-1)

    action_ids = input_ids[0, positions]
    action_logp = step_log_probs.gather(dim=-1, index=action_ids.unsqueeze(-1)).squeeze(-1)
    entropy = -(step_log_probs.exp() * step_log_probs).sum(dim=-1)
    return action_logp, entropy

def get_policy_logp_value_entropy(
    ac: ActorCritic,
    full_ids: torch.Tensor,
    prompt_len: int,
    response_len: int,
    use_grad: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the actor-critic on one sequence and extract per-response-token stats.

    Returns ``(action_logp, values, entropy)``. Values are read one position
    before each response token, the same positions whose logits score that
    token. The forward pass runs under ``torch.enable_grad()`` when
    ``use_grad`` is true and under ``torch.inference_mode()`` otherwise.
    """
    grad_ctx = torch.enable_grad if use_grad else torch.inference_mode
    with grad_ctx():
        batch = full_ids.unsqueeze(0).to(device)
        mask = torch.ones_like(batch, dtype=torch.long)
        logits, all_values = ac(input_ids=batch, attention_mask=mask)
        action_logp, entropy = action_logprobs_from_logits(logits, batch, prompt_len, response_len)
        # Indices prompt_len-1 .. prompt_len+response_len-2: the state before each action.
        state_idx = torch.arange(prompt_len - 1, prompt_len + response_len - 1, device=device)
        return action_logp, all_values[0, state_idx], entropy

@torch.inference_mode()
def get_ref_logp(
    ref: torch.nn.Module,
    full_ids: torch.Tensor,
    prompt_len: int,
    response_len: int,
) -> torch.Tensor:
    """Log-probability of each response token under the reference model.

    The forward pass runs on ``ref_device``; the resulting per-token
    log-probabilities are moved to ``device`` before being returned.
    """
    batch = full_ids.unsqueeze(0).to(ref_device)
    mask = torch.ones_like(batch, dtype=torch.long)
    outputs = ref(input_ids=batch, attention_mask=mask, use_cache=False, return_dict=True)
    all_log_probs = F.log_softmax(outputs.logits, dim=-1)

    token_idx = torch.arange(prompt_len, prompt_len + response_len, device=all_log_probs.device)
    chosen = batch[0, token_idx].unsqueeze(-1)
    # Logits at position t-1 score the token at position t.
    per_token = all_log_probs[0, token_idx - 1, :].gather(dim=-1, index=chosen)
    return per_token.squeeze(-1).to(device)

def compute_gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float, gae_lambda: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generalised Advantage Estimation over one finite trajectory.

    ``delta_t = r_t + gamma * V_{t+1} - V_t`` with ``V`` taken as zero after
    the final step, and ``A_t = delta_t + gamma * lambda * A_{t+1}``.

    Returns:
        ``(advantages, returns)`` where ``returns = advantages + values``.
    """
    horizon = rewards.shape[0]

    # V_{t+1} for every step; the entry after the last step stays zero.
    next_values = torch.zeros_like(values)
    next_values[:-1] = values[1:]
    deltas = rewards + gamma * next_values - values

    advantages = torch.zeros_like(rewards)
    running = torch.zeros((), device=rewards.device, dtype=rewards.dtype)
    decay = gamma * gae_lambda
    for t in range(horizon - 1, -1, -1):
        running = deltas[t] + decay * running
        advantages[t] = running

    return advantages, advantages + values


## 6. PPO 训练循环（LoRA 版）

与无 LoRA 版相同，但：

- 学习率可以相对大一些
- 优化器只包含 `requires_grad=True` 的参数（LoRA + value head）


In [None]:
# Number of rollout + update iterations, and the generation cap per rollout.
train_steps = 40
max_new_tokens = 48

# kl_coef weights the per-token KL penalty in rollout_one; gamma and
# gae_lambda are passed to compute_gae.
kl_coef = 0.05
gamma = 1.0
gae_lambda = 0.95

# Settings read by ppo_update_one_trajectory: policy and value clipping
# ranges, loss weights, gradient steps per trajectory, and the grad-norm cap.
clip_eps = 0.2
value_clip_eps = 0.2
vf_coef = 0.5
ent_coef = 0.0
ppo_epochs = 2
max_grad_norm = 1.0

# The optimizer only receives parameters that currently have requires_grad=True.
lr = 1e-4
optimizer = torch.optim.AdamW([p for p in actor_critic.parameters() if p.requires_grad], lr=lr)

# Mixed precision: autocast is enabled only on CUDA. Gradient scaling is
# enabled only when autocast is on and bf16 is not supported (the fp16 path).
# NOTE(review): GradScaler.unscale_ is understood to reject float16
# parameters, so trainable weights should be float32 on the fp16 path - confirm.
# NOTE(review): newer PyTorch releases prefer torch.amp.GradScaler("cuda", ...)
# over torch.cuda.amp.GradScaler - confirm against the installed version.
use_amp = device == "cuda"
use_bf16 = use_amp and torch.cuda.is_bf16_supported()
autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16
autocast_device_type = "cuda" if use_amp else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and not use_bf16)

def ppo_update_one_trajectory(
    full_ids: torch.Tensor,
    prompt_len: int,
    response_len: int,
    logpi_old: torch.Tensor,
    values_old: torch.Tensor,
    advantages: torch.Tensor,
    returns: torch.Tensor,
) -> Dict[str, float]:
    """Run ``ppo_epochs`` clipped-PPO gradient steps on one trajectory.

    Reads the module-level ``actor_critic``, ``optimizer``, ``scaler``, the
    autocast settings and the PPO hyperparameters.

    Args:
        full_ids: Prompt + response token ids of the trajectory.
        prompt_len: Number of prompt tokens.
        response_len: Number of response tokens.
        logpi_old: Rollout-time log-probabilities of the response tokens.
        values_old: Rollout-time value estimates.
        advantages: Advantage estimates; normalised here when the response
            has more than one token.
        returns: Value-function regression targets.

    Returns:
        Metrics of the last epoch: ``loss``, ``policy_loss``, ``value_loss``,
        ``entropy``, ``approx_kl`` and ``clip_frac``. Empty when
        ``ppo_epochs`` is zero.
    """
    actor_critic.train()

    # Rollout tensors may have been created under torch.inference_mode(); such
    # tensors cannot be saved for backward. Copying them here yields ordinary
    # tensors, and float32 keeps the ratio and normalisation arithmetic out of
    # half precision (where the 1e-8 clamp below would underflow in fp16).
    logpi_old = logpi_old.to(torch.float32, copy=True).detach()
    values_old = values_old.to(torch.float32, copy=True).detach()
    returns = returns.to(torch.float32, copy=True).detach()
    advantages = advantages.to(torch.float32, copy=True).detach()

    # The sample std of a single element is NaN, which would turn the loss
    # (and then the weights) into NaN. One-token responses keep their raw
    # advantage.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std().clamp_min(1e-8))

    trainable_params = [p for p in actor_critic.parameters() if p.requires_grad]

    metrics: Dict[str, float] = {}
    for _ in range(ppo_epochs):
        with torch.autocast(device_type=autocast_device_type, dtype=autocast_dtype, enabled=use_amp):
            logpi, values, entropy = get_policy_logp_value_entropy(
                actor_critic, full_ids, prompt_len, response_len, use_grad=True
            )

            # Clipped surrogate objective.
            ratio = torch.exp(logpi - logpi_old)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Clipped value loss: take the worse of the clipped / unclipped error.
            values_clipped = values_old + (values - values_old).clamp(-value_clip_eps, value_clip_eps)
            v_loss1 = (values - returns) ** 2
            v_loss2 = (values_clipped - returns) ** 2
            value_loss = 0.5 * torch.max(v_loss1, v_loss2).mean()

            entropy_loss = -entropy.mean()
            loss = policy_loss + vf_coef * value_loss + ent_coef * entropy_loss

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            clip_grad_norm_(trainable_params, max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            clip_grad_norm_(trainable_params, max_grad_norm)
            optimizer.step()

        approx_kl = (logpi_old - logpi).mean().detach()
        clip_frac = ((ratio - 1.0).abs() > clip_eps).float().mean().detach()

        metrics = {
            "loss": float(loss.detach().cpu().item()),
            "policy_loss": float(policy_loss.detach().cpu().item()),
            "value_loss": float(value_loss.detach().cpu().item()),
            "entropy": float(entropy.mean().detach().cpu().item()),
            "approx_kl": float(approx_kl.cpu().item()),
            "clip_frac": float(clip_frac.cpu().item()),
        }

    return metrics

@torch.inference_mode()
def rollout_one(prompt: str, expected: str) -> Dict[str, Any]:
    """Sample one response and compute everything the PPO update needs.

    Per-token reward is ``-kl_coef * (logpi_old - logpi_ref)``; the rule-based
    reward is added to the last token only. Advantages and returns come from
    ``compute_gae`` on these rewards and the rollout-time values.

    Returns:
        A dict with ``skip=True`` (plus prompt and response text) when
        nothing was generated, otherwise ``skip=False`` with the trajectory
        tensors and summary scalars.
    """
    prompt_ids = build_prompt_input_ids(tokenizer, prompt)
    full_ids, response_ids = sample_response(
        actor_critic.base, tokenizer, prompt_ids, max_new_tokens=max_new_tokens
    )
    response_text = tokenizer.decode(response_ids, skip_special_tokens=True)

    n_prompt = prompt_ids.numel()
    n_response = int(response_ids.numel())
    if n_response == 0:
        return {"skip": True, "prompt": prompt, "response": response_text}

    rm_reward = rule_reward(prompt, response_text, expected)

    logpi_old, values_old, _ = get_policy_logp_value_entropy(
        actor_critic, full_ids, n_prompt, n_response, use_grad=False
    )
    logpi_ref = get_ref_logp(ref_model, full_ids, n_prompt, n_response)

    # Token-level KL estimate against the reference policy, used as a penalty.
    per_token_kl = logpi_old - logpi_ref
    rewards = -kl_coef * per_token_kl
    terminal_bonus = torch.tensor(rm_reward, device=device, dtype=rewards.dtype)
    rewards[-1] = rewards[-1] + terminal_bonus

    advantages, returns = compute_gae(rewards, values_old, gamma=gamma, gae_lambda=gae_lambda)

    result: Dict[str, Any] = {
        "skip": False,
        "prompt": prompt,
        "expected": expected,
        "response": response_text,
    }
    result.update(
        full_ids=full_ids,
        prompt_len=n_prompt,
        response_len=n_response,
        rm_reward=float(rm_reward),
        kl_mean=float(per_token_kl.mean().detach().cpu().item()),
        logpi_old=logpi_old,
        values_old=values_old,
        advantages=advantages.detach(),
        returns=returns.detach(),
    )
    return result

# Marker that the training helpers in this cell are defined.
print("ready for training")


In [None]:
# Keys of the rollout dict that are forwarded verbatim to the PPO update.
_update_keys = (
    "full_ids",
    "prompt_len",
    "response_len",
    "logpi_old",
    "values_old",
    "advantages",
    "returns",
)

# One iteration = sample a random task, roll out once, update on that trajectory.
for step in range(train_steps):
    task = random.choice(train_tasks)
    rollout = rollout_one(task["prompt"], task["expected"])
    if not rollout["skip"]:
        metrics = ppo_update_one_trajectory(**{key: rollout[key] for key in _update_keys})

        summary = (
            f"step={step} rm_reward={rollout['rm_reward']:.3f} kl_mean={rollout['kl_mean']:.3f} "
            f"loss={metrics['loss']:.4f} policy={metrics['policy_loss']:.4f} value={metrics['value_loss']:.4f} "
            f"approx_kl={metrics['approx_kl']:.4f} clip_frac={metrics['clip_frac']:.3f}"
        )
        print(summary)
        print("prompt:", rollout["prompt"])
        print("response:", rollout["response"])
        print("-" * 60)
    else:
        # Nothing was generated, so there is no trajectory to learn from.
        print(f"step={step} skip(empty response)")


## 7. 简单验证

训练后再看一眼输出。


In [None]:
@torch.inference_mode()
def chat(prompt: str, max_new_tokens: int = 64) -> str:
    """Sample a reply from the trained policy and return it as text.

    Uses temperature 0.7 and top-p 0.9; special tokens are dropped when
    decoding.
    """
    prompt_ids = build_prompt_input_ids(tokenizer, prompt)
    _, reply_ids = sample_response(
        actor_critic.base,
        tokenizer,
        prompt_ids,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
    )
    return tokenizer.decode(reply_ids, skip_special_tokens=True)

# Print one sampled answer per training prompt for a quick visual check.
for t in train_tasks:
    print("Q:", t["prompt"])
    print("A:", chat(t["prompt"]))
    print()
