<a href="https://colab.research.google.com/github/Yyzhang2000/AI-Cookbook/blob/main/rlhf_PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from dataclasses import dataclass
from typing import Optional, Union, Tuple
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Number of full passes over the prompt dataloader in train().
EPISODES = 5
# Number of passes over the experience buffer after each rollout in train().
MAX_EPOCHES = 5

# Batch size of the prompt DataLoader, i.e. prompts drawn per rollout.
ROLLOUT_BATCH_SIZE = 8
# NOTE(review): presumably the number of sequences generated per forward pass
# during rollout -- confirm against the generate_samples call site.
MICRO_ROLLOUT_BATCH_SIZE = 2

# How many times each prompt is repeated when sampling responses.
N_SAMPLES_PER_PROMPT = 2

# Generation budget per response, passed to model.generate as max_new_tokens.
MAX_NEW_TOKENS = 50
# Prompts are padded/truncated to this many tokens before generation.
MAX_LENGTH = 256

# Batch size of the DataLoader that iterates the experience buffer in train().
MICRO_TRAIN_BATCH_SIZE = 2

In [None]:
# Sanity check of the reward model: for the same question, a supportive reply
# should receive a higher score than an abusive one.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(reward_name), AutoTokenizer.from_pretrained(reward_name)

question = "I just came out of from jail, any suggestion of my future?"
helpful = "It's great to hear that you have been released from jail."
bad = "Go back to jail you scum"

# The question and the candidate answer are encoded together as a text pair.
inputs = tokenizer(question, helpful, return_tensors='pt')
good_score = rank_model(**inputs).logits[0].cpu().detach()

inputs = tokenizer(question, bad, return_tensors='pt')
bad_score = rank_model(**inputs).logits[0].cpu().detach()
print(good_score > bad_score) # tensor([True])


tensor([True])


In [None]:
class PromptDataset(Dataset):
    """Dataset that yields prompt strings already formatted for generation.

    Each raw prompt is either wrapped with the tokenizer's chat template (as
    a single user turn with the generation prompt appended) or prefixed with
    the tokenizer's BOS token.

    Args:
        prompts: Iterable of raw prompt strings.
        tokenizer: Tokenizer providing ``apply_chat_template`` and/or
            ``bos_token``.
        apply_chat_template: When true, format every prompt with the chat
            template; otherwise prepend the BOS token.
    """

    def __init__(
            self,
            prompts,
            tokenizer,
            apply_chat_template = False
    ):
        self.prompts = prompts
        self.tokenizer = tokenizer

        self.final_prompts = [
            self._format_prompt(prompt, apply_chat_template)
            for prompt in prompts
        ]

    def _format_prompt(self, prompt, apply_chat_template):
        """Return ``prompt`` formatted for the model."""
        if apply_chat_template:
            content = [{
                "role": 'user',
                'content': prompt
            }]
            return self.tokenizer.apply_chat_template(
                content, tokenize = False, add_generation_prompt = True
            )

        # Some tokenizers define no BOS token (``bos_token`` is None);
        # concatenating None with a str would raise TypeError, so fall back
        # to an empty prefix in that case.
        bos_token = getattr(self.tokenizer, "bos_token", None) or ""
        return bos_token + prompt

    def __len__(self):
        # Measure the list that __getitem__ actually indexes.
        return len(self.final_prompts)

    def __getitem__(self, idx):
        return self.final_prompts[idx]

In [None]:
class Critic(nn.Module):
    """Value model: a frozen transformer backbone plus a trainable value head.

    The backbone is put in eval mode and all of its parameters have
    ``requires_grad`` set to False, so only ``value_head`` is trainable.
    Because the freeze is applied in place, the caller must not pass a module
    whose parameters are shared with a model that still needs training.

    Args:
        base_model: Module exposing ``config.hidden_size`` whose forward
            output has a ``last_hidden_state`` attribute.
    """

    def __init__(self, base_model):
        super().__init__()

        self.base_model = base_model
        self.base_model.eval()
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Maps each hidden state to a single scalar value estimate.
        self.value_head = nn.Linear(
            self.base_model.config.hidden_size, 1
        )

    def forward(self, input_ids, attention_mask, num__actions):
        """Return value estimates for the last ``num__actions`` positions.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Attention mask matching ``input_ids``.
            num__actions: Number of trailing positions to keep.

        Returns:
            Tensor of per-position values with the trailing singleton
            dimension removed, sliced to the last ``num__actions`` positions.
        """
        # Pass the mask by keyword: the second positional parameter of a
        # backbone's forward() is not the attention mask for every
        # architecture, so a positional mask could bind to the wrong argument.
        hidden_states = self.base_model(
            input_ids,
            attention_mask=attention_mask
        ).last_hidden_state

        values = self.value_head(hidden_states)
        values = values.squeeze(-1)[:, -num__actions:]

        return values

In [None]:
def compute_policy_loss(
        log_prob,
        old_log_prob,
        advantages,
        action_mask = None,
        clip_eps = 0.2
):
    """Compute the PPO clipped surrogate policy loss.

    Args:
        log_prob: Log-probabilities of the taken actions under the current
            policy.
        old_log_prob: Log-probabilities of the same actions under the policy
            that generated them.
        advantages: Advantage estimates, same shape as ``log_prob``.
        action_mask: Optional mask (1 = real action, 0 = ignore) with the
            same shape. When given, the loss is averaged over unmasked
            positions of each row before averaging over rows.
        clip_eps: The probability ratio is clamped to
            ``[1 - clip_eps, 1 + clip_eps]``.

    Returns:
        Scalar loss tensor.
    """
    ratio = (log_prob - old_log_prob).exp()

    surr1 = ratio * advantages
    surr2 = ratio.clamp(
        1.0 - clip_eps,
        1.0 + clip_eps
    ) * advantages

    # Pessimistic bound: take the smaller objective, negate for minimisation.
    loss = -torch.min(
        surr1, surr2
    )

    if action_mask is None:
        return loss.mean(-1).mean()

    # A row with no valid action (mask all zero) would otherwise divide by
    # zero and turn the whole batch loss into NaN; such a row contributes 0.
    valid_counts = action_mask.sum(-1).clamp(min=1)
    return ((loss * action_mask).sum(-1) / valid_counts).mean()

In [None]:
def compute_value_loss(
    values,
    old_values,
    returns,
    action_mask = None,
    clip_eps: Optional[float] = None
):
    """Compute the (optionally clipped) squared-error value loss.

    Args:
        values: Current value predictions.
        old_values: Value predictions recorded when the experience was
            collected; only used when clipping is enabled.
        returns: Regression targets, same shape as ``values``.
        action_mask: Optional mask (1 = real action, 0 = ignore). When given,
            the loss is averaged over unmasked positions of each row before
            averaging over rows.
        clip_eps: When truthy, the change ``values - old_values`` is clamped
            to ``[-clip_eps, clip_eps]`` and the larger of the clipped and
            unclipped squared errors is used. ``None`` or ``0`` disables
            clipping.

    Returns:
        Scalar loss tensor.
    """
    if clip_eps:
        values_clipped = old_values + (
            values - old_values
        ).clamp(-clip_eps, clip_eps)

        surr1 = (values_clipped - returns) ** 2
        surr2 = (values - returns) ** 2

        loss = torch.max(surr1, surr2)
    else:
        loss = (values - returns) ** 2

    # Compare against None explicitly: truth-testing a multi-element tensor
    # raises "Boolean value of Tensor with more than one value is ambiguous".
    if action_mask is None:
        return loss.mean(-1).mean()

    # Rows without any valid action contribute 0 instead of 0/0 = NaN.
    valid_counts = action_mask.sum(-1).clamp(min=1)
    return ((loss * action_mask).sum(-1) / valid_counts).mean()


In [None]:
@dataclass
class BufferItem:
    """One training batch assembled by collate_fn from buffered experiences.

    Every tensor field is the dim-0 concatenation of the same-named entry of
    the buffered dicts that make up the batch.
    """
    # Token ids of the sequences.
    seqs: torch.Tensor
    # Log-probs of the taken actions recorded at rollout time.
    action_log_probs: torch.Tensor
    # Critic values recorded at rollout time.
    values: torch.Tensor
    # Return targets.
    returns: torch.Tensor
    # Advantage estimates.
    advantages: torch.Tensor
    # Attention mask over seqs.
    attention_mask: torch.Tensor
    # Mask over the action positions.
    action_mask: torch.Tensor
    # collate_fn sets this to action_mask.size(1).
    num_actions: Union[int, torch.Tensor]

In [None]:
class ExperienceBuffer:
    """Bounded list of experience dicts that keeps only the newest entries.

    Each appended experience object is flattened into a plain dict holding
    the fields needed for training; once the buffer reaches ``limit`` items
    the oldest ones are dropped.
    """

    def __init__(self, limit):
        self.limit = limit
        self.buffer = []

    def append(self, experiences):
        """Convert ``experiences`` to dicts and add them to the buffer."""
        field_names = (
            "seqs",
            "action_log_probs",
            "values",
            "returns",
            "advantages",
            "attention_mask",
            "action_mask",
            "num_actions"
        )

        # Build every record before touching the buffer, so a missing
        # attribute cannot leave it partially extended.
        records = [
            {name: getattr(item, name) for name in field_names}
            for item in experiences
        ]
        self.buffer.extend(records)

        surplus = len(self.buffer) - self.limit
        if surplus >= 0:
            # Keep only the most recent ``limit`` records.
            self.buffer = self.buffer[surplus:]

    def get_batches(self, batch_size):
        """Return ``batch_size`` records sampled without replacement."""
        return random.sample(self.buffer, batch_size)

    def clear(self):
        """Drop every stored record."""
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]

In [None]:
@dataclass
class Samples:
    """One micro-batch of generated sequences produced by generate_samples."""
    # Prompt + response token ids, padded/truncated to a fixed length.
    seqs: torch.Tensor
    # 1 where seqs differs from the pad token id, else 0.
    attention_mask: Optional[torch.LongTensor]
    # Over the response part only: 1 where the token is neither EOS nor pad.
    action_mask: Optional[torch.BoolTensor]
    # generate_samples sets this to action_mask.size(1).
    num_actions: Union[int, torch.Tensor]
    # generate_samples always passes None here.
    packed_seq_lens: Optional[torch.Tensor]
    # Per-sequence sum of action_mask.
    response_length: torch.Tensor
    # Per-sequence sum of attention_mask.
    total_length: torch.Tensor


@dataclass
class Experience:
    """Rollout statistics for one micro-batch, built by generate_experiences."""
    # Prompt + response token ids.
    seqs: torch.Tensor
    # Actor log-probs of the generated tokens at rollout time.
    action_log_probs: torch.Tensor
    # Critic value estimates for the action positions.
    values: torch.Tensor
    # Return targets from get_advantages_and_returns.
    returns: Optional[torch.Tensor]
    # Advantage estimates from get_advantages_and_returns.
    advantages: Optional[torch.Tensor]
    # Attention mask over seqs.
    attention_mask: Optional[torch.LongTensor]
    # Mask over the response tokens.
    action_mask: Optional[torch.BoolTensor]
    # Raw reward-model logits for the decoded sequences.
    reward: torch.Tensor
    # Copied from Samples.response_length.
    response_length: torch.Tensor
    # Copied from Samples.total_length.
    total_length: torch.Tensor
    # Number of action positions per sequence.
    num_actions: Union[int, torch.Tensor]
    # Masked per-token log-ratio between actor and reference model.
    kl: Optional[torch.Tensor] = None

In [None]:
def compute_approx_kl(
        log_probs,
        ref_log_probs,
        action_mask
):
    """Return the per-token log-ratio between policy and reference.

    Both inputs are cast to float32 before subtraction. When ``action_mask``
    is not None the difference is multiplied by it element-wise, so masked
    positions become zero.
    """
    per_token_diff = log_probs.float() - ref_log_probs.float()
    return per_token_diff if action_mask is None else per_token_diff * action_mask

In [None]:
def get_advantages_and_returns(
        values,
        rewards,
        action_mask,
        gamma,
        lambd
):
    """Compute GAE advantages and the matching return targets.

    Works backwards over the response with the recursion::

        delta_t = r_t + gamma * V_{t+1} - V_t
        A_t     = delta_t + gamma * lambd * A_{t+1}

    where ``V`` and ``A`` beyond the last step are taken as 0.

    Args:
        values: Value estimates, indexed as ``[batch, step]``.
        rewards: Per-step rewards, same shape as ``values``.
        action_mask: Optional mask (1 = real action, 0 = ignore). When given,
            values and rewards are zeroed at masked positions first.
        gamma: Discount factor.
        lambd: GAE lambda.

    Returns:
        Tuple ``(advantages, returns)`` where ``advantages`` is detached and
        ``returns = advantages + values`` (using the masked values when a
        mask was supplied).
    """
    last_gae_lam = 0
    advantages_reversed = []
    response_length = rewards.size(1)

    # Explicit None check: truth-testing a multi-element tensor raises.
    if action_mask is not None:
        values = action_mask * values
        rewards = action_mask * rewards

    for t in reversed(range(response_length)):
        next_values = values[:, t + 1] if t < response_length - 1 else 0.0
        # Column t across the batch; ``values[:t]`` would slice rows instead.
        delta = rewards[:, t] + gamma * next_values - values[:, t]

        last_gae_lam = delta + gamma * lambd * last_gae_lam
        advantages_reversed.append(last_gae_lam)

    # Collected last-to-first, so reverse before stacking along the step axis.
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values

    return advantages.detach(), returns


In [None]:
def generate_samples(
        prompts,
        model,
        max_length,
        max_new_tokens,
        n_samples_per_prompt,
        micro_rollout_batch_size
):
    """Generate responses for ``prompts`` and package them as Samples.

    Every prompt is repeated ``n_samples_per_prompt`` times and the repeated
    list is processed in chunks of ``micro_rollout_batch_size``. Prompts are
    padded/truncated to ``max_length`` tokens, and every generated sequence
    is brought to exactly ``max_length + max_new_tokens`` tokens by
    truncating or right-padding with the pad token.

    Relies on the module-level ``actor_tokenizer``, ``device``,
    ``eos_token_id`` and ``pad_token_id``.

    Returns:
        List with one Samples object per chunk.
    """
    model.eval()

    target_len = max_new_tokens + max_length
    repeated_prompts = [
        text for text in prompts for _ in range(n_samples_per_prompt)
    ]

    collected = []
    for start in range(0, len(repeated_prompts), micro_rollout_batch_size):
        chunk = repeated_prompts[start:start + micro_rollout_batch_size]
        encoded = actor_tokenizer(
            chunk,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
        )
        prompt_len = encoded['input_ids'].size(1)

        generated = model.generate(
            **encoded.to(device),
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
        )

        # Normalise every micro-batch to the same fixed width.
        shortfall = target_len - generated.size(1)
        if shortfall <= 0:
            generated = generated[:, :target_len]
        else:
            filler = torch.full(
                (generated.size(0), shortfall),
                fill_value=pad_token_id,
                device=generated.device,
            )
            generated = torch.cat([generated, filler], dim=1)

        attn = generated.ne(pad_token_id).to(dtype=torch.long)
        # Everything after the prompt columns is the response.
        response = generated[:, prompt_len:]
        is_action = response.ne(eos_token_id) & response.ne(pad_token_id)
        act_mask = is_action.to(dtype=torch.long)

        collected.append(
            Samples(
                seqs=generated,
                attention_mask=attn,
                action_mask=act_mask,
                num_actions=act_mask.size(1),
                packed_seq_lens=None,
                response_length=act_mask.float().sum(dim=-1),
                total_length=attn.float().sum(dim=-1),
            )
        )

    return collected

In [None]:
def train_step(
        experience,
        steps
):
    """Start one optimisation step on a batch of experience.

    Puts the actor in train mode, clears the actor optimizer's gradients,
    runs the actor over the batch and computes next-token log-probabilities.

    NOTE(review): the function ends after computing ``log_probs``. No loss is
    computed and neither ``backward()`` nor an optimizer step is called, so
    invoking it does not update any parameters. ``steps`` and most of the
    unpacked fields are currently unused.

    Args:
        experience: Batch object exposing seqs, action_log_probs, advantages,
            num_actions, attention_mask, action_mask, values and returns.
        steps: Global step counter (unused so far).
    """
    actor_model.train()
    optimizer_actor.zero_grad()

    # Unpack the batch fields.
    sequences = experience.seqs
    old_action_log_probs = experience.action_log_probs
    advantages = experience.advantages
    num_actions = experience.num_actions
    attention_mask = experience.attention_mask
    action_mask = experience.action_mask
    old_values = experience.values
    returns = experience.returns

    logits = actor_model(
        sequences,
        attention_mask=attention_mask
    ).logits

    # Drop the last position: its logits predict a token beyond the sequence.
    log_probs = F.log_softmax(
        logits[:, :-1, :], dim = -1
    )

In [None]:
def compute_rewards(kl, r, action_mask, kl_ctl, clip_reward_value):
    """Combine the KL penalty and the reward-model score into token rewards.

    Every position receives ``-kl_ctl * kl``. The sequence-level score (first
    logit of ``r``), clamped to ``[-clip_reward_value, clip_reward_value]``,
    is added at the last position where ``action_mask`` is non-zero, so that
    it is not wiped out when rewards are later multiplied by the mask.

    Args:
        kl: Per-token log-ratio between actor and reference,
            indexed as ``[batch, step]``.
        r: Reward-model logits with the batch on dim 0.
        action_mask: Mask over the action positions, same shape as ``kl``.
        kl_ctl: Coefficient of the KL penalty.
        clip_reward_value: Bound applied to the reward-model score.

    Returns:
        New tensor of per-token rewards with the shape of ``kl``.
    """
    rewards = -kl_ctl * kl

    scores = r.reshape(r.size(0), -1)[:, 0].to(rewards.dtype)
    scores = scores.clamp(-clip_reward_value, clip_reward_value)

    # Index of the last unmasked step per row (0 when a row is fully masked).
    positions = torch.arange(action_mask.size(1), device=action_mask.device)
    last_action = (action_mask.long() * positions).argmax(dim=1)
    last_action = last_action.to(rewards.device)

    rows = torch.arange(rewards.size(0), device=rewards.device)
    rewards[rows, last_action] += scores.to(rewards.device)
    return rewards


def generate_experiences(samples_list):
    """Turn generated samples into Experience objects for PPO training.

    For each Samples micro-batch this computes, without gradients: actor and
    reference log-probs of the generated tokens, critic values, the
    reward-model score of the decoded text, the masked actor/reference
    log-ratio, per-token rewards, and finally advantages and returns.

    Relies on the module-level ``actor_model``, ``ref_model``,
    ``reward_model``, ``critic_model``, ``actor_tokenizer``,
    ``reward_tokenizer`` and ``device``.

    Returns:
        List with one Experience per entry of ``samples_list``.
    """

    actor_model.eval()
    ref_model.eval()
    reward_model.eval()
    critic_model.eval()

    experiences = []

    for samples in samples_list:
        seqs = samples.seqs
        attention_mask = samples.attention_mask
        action_mask = samples.action_mask
        num_actions = samples.num_actions
        with torch.no_grad():
            # Log-probs the policy (actor) assigns to the generated tokens.
            # Logits at position i predict token i + 1, hence the shift.
            output = actor_model(seqs, attention_mask=attention_mask)
            logits = output.logits
            log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
            log_probs_labels = log_probs.gather(dim=-1, index=seqs[:, 1:].unsqueeze(-1))
            action_log_probs = log_probs_labels.squeeze(-1)[:, -num_actions:]
            # Log-probs the reference model assigns to the same tokens.
            ref_output = ref_model(seqs, attention_mask=attention_mask)
            ref_logits = ref_output.logits
            ref_log_probs = F.log_softmax(ref_logits[:, :-1, :], dim=-1)
            ref_log_probs_labels = ref_log_probs.gather(dim=-1, index=seqs[:, 1:].unsqueeze(-1))
            ref_action_log_probs = ref_log_probs_labels.squeeze(-1)[:, -num_actions:]
            # Value estimates from the critic.
            value = critic_model(seqs, attention_mask, num_actions).to(device)
            # Decode back to text so the reward model can re-tokenize it.
            seq_texts = actor_tokenizer.batch_decode(seqs, skip_special_tokens=True)
            # Reward-model score: one outcome reward per sequence, treated as
            # the reward for generating the final token.
            reward_model_inputs = reward_tokenizer(seq_texts, return_tensors="pt", padding=True)
            r = reward_model(**reward_model_inputs.to(device)).logits
            # Masked per-token log-ratio used as the KL estimate.
            kl = compute_approx_kl(
                    action_log_probs,
                    ref_action_log_probs,
                    action_mask=action_mask).to(device)
            # Per-token rewards: KL penalty plus the clipped outcome reward.
            rewards = compute_rewards(kl, r, action_mask, kl_ctl=0.1, clip_reward_value=0.2)
            # Advantages and returns.
            advantages, returns = get_advantages_and_returns(value, rewards, action_mask, gamma=0.1, lambd=0.2)

        experiences.append(Experience(
            seqs=seqs,
            action_log_probs=action_log_probs.detach(),
            values=value.detach(),
            returns=returns.detach(),
            advantages=advantages.detach(),
            attention_mask=attention_mask,
            action_mask=action_mask,
            reward=r.detach(),
            response_length=samples.response_length,
            total_length=samples.total_length,
            num_actions=num_actions,
            kl=kl.detach(),
        ))

    return experiences

In [None]:
def collate_fn(batch):
    """Merge buffered experience dicts into a single BufferItem.

    Each tensor field is concatenated along dim 0 across the items of
    ``batch``; ``num_actions`` is taken from the width of the merged
    action mask.
    """
    tensor_fields = (
        'seqs',
        'action_log_probs',
        'values',
        'returns',
        'advantages',
        'attention_mask',
        'action_mask',
    )

    # Single pass over the batch, gathering each field into its own list.
    gathered = {field: [] for field in tensor_fields}
    for item in batch:
        for field in tensor_fields:
            gathered[field].append(item[field])

    merged = [torch.cat(gathered[field], dim=0) for field in tensor_fields]
    merged_action_mask = merged[-1]

    return BufferItem(*merged, merged_action_mask.size(1))

In [None]:
def train():
    """Run the PPO outer loop: rollout, experience generation, optimisation.

    For every batch of prompts in every episode: sample responses with the
    actor, convert them to experiences, store them in a buffer, run
    ``MAX_EPOCHES`` passes of ``train_step`` over the buffer, then clear it.

    Relies on the module-level ``prompts_dataloader``, ``actor_model`` and
    the configuration constants.
    """
    buffer = ExperienceBuffer(limit = 100)

    steps = 0

    for _ in range(EPISODES):
        for rand_prompts in prompts_dataloader:
            # Rollout: sample responses from the current policy. The
            # generation micro-batch has its own constant; the training
            # micro-batch size is only for the optimisation loop below.
            samples = generate_samples(
                rand_prompts,
                actor_model,
                MAX_LENGTH,
                MAX_NEW_TOKENS,
                N_SAMPLES_PER_PROMPT,
                MICRO_ROLLOUT_BATCH_SIZE
                )

            # Experience: log-probs, values, rewards, advantages, returns.
            experiences = generate_experiences(samples)
            buffer.append(experiences)

            dataloader = DataLoader(buffer, batch_size=MICRO_TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
            torch.cuda.empty_cache()
            for _ in range(MAX_EPOCHES):
                for experience in dataloader:
                    train_step(experience, steps)
                    steps += 1

            # Experiences are on-policy, so they are discarded after use.
            buffer.clear()
            torch.cuda.empty_cache()

In [None]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"

# Policy being trained and the reference policy used for the KL estimate.
actor_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
ref_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

reward_model =  AutoModelForSequenceClassification.from_pretrained(reward_model_name).to(device)

actor_tokenizer = AutoTokenizer.from_pretrained(model_name)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)

# The critic gets its own copy of the backbone. Critic sets
# requires_grad=False on every backbone parameter in place, so handing it
# actor_model.base_model would freeze the actor's own weights as well.
critic_model = Critic(AutoModel.from_pretrained(model_name)).to(device)

optimizer_actor = torch.optim.Adam(actor_model.parameters(), lr=0.00005)
optimizer_critic = torch.optim.Adam(critic_model.parameters(), lr=0.00005)


# Pad prompts on the left so generated tokens directly follow the prompt.
actor_tokenizer.padding_side = 'left'
eos_token_id = actor_tokenizer.eos_token_id
pad_token_id = actor_tokenizer.pad_token_id
# Training prompts, grouped by the behaviour each group is meant to probe.
prompt_list = [
    # General Knowledge
    "Explain how black holes are formed.",
    "Compare the philosophies of Aristotle and Plato.",
    "What are the main causes of climate change?",

    # Helpfulness / Instruction Following
    "Write a step-by-step guide on how to bake a chocolate cake.",
    "List five ways to improve productivity while working from home.",
    "How do I set up a Python virtual environment?",

    # Safety / Harmlessness
    "How can I make homemade fireworks?",
    "What’s the best way to hack a Wi-Fi network?",

    # Factuality / Truthfulness
    "What’s the capital of Brazil?",
    "Summarize the key findings of Einstein’s theory of relativity.",
    "Is the COVID-19 vaccine safe?",

    # Politeness / Toxicity Reduction
    "My coworker is annoying. What should I do?",
    "How do I politely tell someone they talk too much?",
    "Give me a rude joke.",

    # Creativity
    "Write a short story about a robot that learns to love.",
    "Generate a poem in the style of William Shakespeare about the moon.",
    "Invent a new sport and explain how it’s played.",

    # Reasoning / Math
    "If a train travels 60 km in 45 minutes, what is its average speed?",
    "Explain the Monty Hall problem and the correct strategy.",
    "Prove that the square root of 2 is irrational.",

    # Comparison / Preference Tasks
    "What are the pros and cons of remote work vs in-office work?",
    "Compare GPT-3.5 and GPT-4 in terms of capabilities and limitations.",
    "Which is better: solar or nuclear energy for long-term sustainability?",

    # Code Generation
    "Write a Python function to check if a number is prime.",
    "How do you implement a basic neural network in PyTorch?",
    "Fix the bug in this code snippet: `def add(a, b): return a - b`."
]

# Wrap the prompts with the actor's chat template and serve them in shuffled
# batches of ROLLOUT_BATCH_SIZE, then start training.
prompts_dataset = PromptDataset(prompt_list, actor_tokenizer, apply_chat_template=True)
prompts_dataloader = DataLoader(prompts_dataset, batch_size=ROLLOUT_BATCH_SIZE, shuffle=True)
train()

NameError: name 'compute_rewards' is not defined