<a href="https://colab.research.google.com/github/Yyzhang2000/AI-Cookbook/blob/main/rlhf_PPO_%F0%9F%9A%A7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from dataclasses import dataclass
from typing import Optional, Union, Tuple
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class PromptDataset(Dataset):
    """Dataset of prompt strings pre-formatted for generation.

    Each raw prompt is either rendered through the tokenizer's chat template
    (as a single ``user`` turn with the generation prompt appended) or
    prefixed with the tokenizer's BOS token. Items are returned as plain
    strings; tokenization is left to the caller.

    Args:
        prompts: Iterable of raw prompt strings.
        tokenizer: Object exposing ``bos_token`` and, when
            ``apply_chat_template`` is true, an ``apply_chat_template``
            method.
        apply_chat_template: Format prompts with the chat template instead
            of prepending the BOS token.
    """

    def __init__(
            self,
            prompts,
            tokenizer,
            apply_chat_template = False
    ):
        self.prompts = prompts
        self.tokenizer = tokenizer

        self.final_prompts = []

        for prompt in prompts:
            if apply_chat_template:
                content = [{
                    "role": 'user',
                    'content': prompt
                }]
                prompt = self.tokenizer.apply_chat_template(
                    content, tokenize = False, add_generation_prompt = True
                )
            else:
                # Some tokenizers define no BOS token (``bos_token`` is
                # None); concatenating None with a str would raise
                # TypeError, so fall back to an empty prefix.
                bos_token = self.tokenizer.bos_token or ""
                prompt = bos_token + prompt

            # Echo the formatted prompt so it can be inspected in the notebook.
            print(prompt)
            self.final_prompts.append(prompt)

    def __len__(self):
        """Return the number of formatted prompts."""
        # Measure what __getitem__ actually serves; this also works when
        # ``prompts`` was an iterable without ``len``.
        return len(self.final_prompts)

    def __getitem__(self, idx):
        """Return the formatted prompt string at position ``idx``."""
        return self.final_prompts[idx]

In [3]:
class Critic(nn.Module):
    """Value model: a frozen backbone followed by a trainable scalar head.

    The backbone is put in eval mode and all of its parameters have
    ``requires_grad`` disabled, so only ``value_head`` is trainable.

    Args:
        base_model: Backbone module exposing ``config.hidden_size`` and
            returning an object with ``last_hidden_state`` when called.
    """

    def __init__(self, base_model):
        super().__init__()

        self.base_model = base_model
        self.base_model.eval()

        # Freeze the backbone; only the value head is optimized.
        for p in self.base_model.parameters():
            p.requires_grad = False

        self.value_head = nn.Linear(
            base_model.config.hidden_size, 1
        )

    def forward(self, input_ids, attention_mask, num_actions):
        """Return per-token value estimates for the trailing positions.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Attention mask fed to the backbone.
            num_actions: Number of trailing positions to keep.

        Returns:
            The value-head output with its last dimension squeezed, sliced
            to the last ``num_actions`` positions along dimension 1.
        """
        # Pass the mask by keyword: the second positional parameter of a
        # backbone's forward() is not the attention mask on every
        # architecture, so a positional mask can bind to the wrong argument.
        hidden_states = self.base_model(
            input_ids,
            attention_mask=attention_mask
        ).last_hidden_state

        value_model_output = self.value_head(hidden_states)
        values = value_model_output.squeeze(-1)[:, -num_actions:]

        return values


In [4]:
def compute_policy_loss(
        log_prob,
        old_log_prob,
        advantages,
        action_mask = None,
        clip_eps = 0.2
):
    """Compute the clipped-surrogate PPO policy loss.

    Args:
        log_prob: Log-probabilities of the taken actions under the current
            policy.
        old_log_prob: Log-probabilities of the same actions under the policy
            that produced the rollout.
        advantages: Advantage estimates, broadcastable with ``log_prob``.
        action_mask: Optional mask; when given, the per-row average is taken
            over masked-in positions only.
        clip_eps: Half-width of the clipping interval around a ratio of 1.

    Returns:
        Scalar tensor: the mean over rows of the per-row (masked) mean loss.
    """
    prob_ratio = torch.exp(log_prob - old_log_prob)
    clipped_ratio = torch.clamp(prob_ratio, 1.0 - clip_eps, 1.0 + clip_eps)

    unclipped_objective = prob_ratio * advantages
    clipped_objective = clipped_ratio * advantages

    # Take the pessimistic (smaller) objective, negated to form a loss.
    per_token_loss = -torch.min(unclipped_objective, clipped_objective)

    if action_mask is not None:
        masked_total = (per_token_loss * action_mask).sum(-1)
        return (masked_total / action_mask.sum(-1)).mean()

    return per_token_loss.mean(-1).mean()

In [5]:
def compute_value_loss(
    values,
    old_values,
    returns,
    action_mask = None,
    clip_eps: Optional[float] = None
):
    """Compute the (optionally clipped) squared-error value loss.

    Args:
        values: Current value estimates.
        old_values: Value estimates recorded when the rollout was collected.
        returns: Regression targets for ``values``.
        action_mask: Optional mask; when given, the per-row average is taken
            over masked-in positions only.
        clip_eps: When set to a non-zero value, ``values`` may move at most
            ``clip_eps`` away from ``old_values`` and the larger of the
            clipped and unclipped squared errors is used. ``None`` (or 0)
            disables clipping.

    Returns:
        Scalar tensor: the mean over rows of the per-row (masked) mean loss.
    """
    if clip_eps:
        values_clipped = old_values + (
            values - old_values
        ).clamp(-clip_eps, clip_eps)

        surr1 = (values_clipped - returns) ** 2
        surr2 = (values - returns) ** 2

        loss = torch.max(surr1, surr2)
    else:
        loss = (values - returns) ** 2

    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises RuntimeError.
    if action_mask is None:
        return loss.mean(-1).mean()

    return ((loss * action_mask).sum(-1) / action_mask.sum(-1)).mean()


In [6]:
class ExperienceBuffer:
    """Bounded FIFO store of rollout experiences kept as plain dicts.

    Args:
        limit: Maximum number of entries retained; once exceeded, the oldest
            entries are dropped.
    """

    # Attributes copied from each experience into its buffer entry.
    _KEYS = (
        "seqs",
        "action_log_probs",
        "values",
        "returns",
        "advantages",
        "attention_mask",
        "action_mask",
        "num_actions"
    )

    def __init__(self, limit):
        self.limit = limit
        self.buffer = []

    def append(self, experiences):
        """Add one entry per experience, then trim to the newest ``limit``.

        Args:
            experiences: Iterable of objects exposing every attribute named
                in ``_KEYS``.
        """
        batch = [
            {key: getattr(experience, key) for key in self._KEYS}
            for experience in experiences
        ]

        # Extend exactly once; extending twice stored every experience in
        # duplicate and halved the buffer's effective capacity.
        self.buffer.extend(batch)
        if len(self.buffer) >= self.limit:
            self.buffer = self.buffer[len(self.buffer)-self.limit:]

    def get_batches(self, batch_size):
        """Return ``batch_size`` distinct entries sampled at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                entries (raised by ``random.sample``).
        """
        return random.sample(self.buffer, batch_size)

    def clear(self):
        """Drop all stored entries."""
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]


In [None]:
@dataclass
class Samples:
    """One micro-batch of generated rollouts, as built by ``generate_samples``."""
    # Prompt + generated token ids, truncated or right-padded by
    # generate_samples to a width of max_new_tokens + max_length.
    seqs: torch.Tensor
    # 1 where ``seqs`` differs from the pad token id, else 0.
    attention_mask: Optional[torch.LongTensor]
    # Over the columns after the tokenized prompt width: 1 where the token is
    # neither EOS nor pad.
    # NOTE(review): generate_samples stores this as torch.long, not bool.
    action_mask: Optional[torch.BoolTensor]
    # generate_samples sets this to ``action_mask.size(1)``.
    num_actions: Union[int, torch.Tensor]
    # Always None in generate_samples; presumably reserved for packed
    # sequences -- TODO confirm.
    packed_seq_lens: Optional[torch.Tensor]
    # Per-row sum of ``action_mask`` (as float).
    response_length: torch.Tensor
    # Per-row sum of ``attention_mask`` (as float).
    total_length: torch.Tensor


@dataclass
class Experience:
    """Rollout data for one PPO update.

    ``ExperienceBuffer.append`` copies ``seqs``, ``action_log_probs``,
    ``values``, ``returns``, ``advantages``, ``attention_mask``,
    ``action_mask`` and ``num_actions`` from instances of this shape, and
    ``train_step`` reads the same eight attributes.
    """
    # Token id sequences fed to the actor in train_step.
    seqs: torch.Tensor
    # Log-probabilities recorded at rollout time; train_step treats them as
    # the "old" log-probs.
    action_log_probs: torch.Tensor
    # Value estimates recorded at rollout time; train_step treats them as
    # the "old" values.
    values: torch.Tensor
    # Presumably the output of get_advantages_and_returns -- TODO confirm.
    returns: Optional[torch.Tensor]
    advantages: Optional[torch.Tensor]
    # Attention mask passed to the actor together with ``seqs``.
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    # Not read by any code in this section; presumably the reward model
    # score -- TODO confirm.
    reward: torch.Tensor
    response_length: torch.Tensor
    total_length: torch.Tensor
    num_actions: Union[int, torch.Tensor]
    # Not read by any code in this section; presumably the output of
    # compute_approx_kl -- TODO confirm.
    kl: Optional[torch.Tensor] = None

In [7]:
def compute_approx_kl(
        log_probs,
        ref_log_probs,
        action_mask
):
    """Return the per-token log-ratio between the policy and the reference.

    Args:
        log_probs: Log-probabilities under the current policy.
        ref_log_probs: Log-probabilities under the reference policy.
        action_mask: Mask multiplied into the result, or ``None`` to skip
            masking.

    Returns:
        ``log_probs - ref_log_probs`` computed in float32, with masked-out
        positions set to zero when a mask is given.
    """
    log_ratio = log_probs.float() - ref_log_probs.float()
    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises RuntimeError.
    if action_mask is not None:
        log_ratio = log_ratio * action_mask

    return log_ratio

In [9]:
def get_advantages_and_returns(
        values,
        rewards,
        action_mask,
        gamma,
        lambd
):
    """Compute GAE advantages and the matching returns.

    Walks the response backwards, accumulating
    ``delta_t = r_t + gamma * V_{t+1} - V_t`` with decay ``gamma * lambd``.
    The value after the last step is taken to be 0.

    Args:
        values: Value estimates, indexed as ``[batch, t]``.
        rewards: Per-token rewards, indexed as ``[batch, t]``.
        action_mask: Mask multiplied into ``values`` and ``rewards`` before
            the recursion, or ``None`` to skip masking.
        gamma: Discount factor.
        lambd: GAE lambda.

    Returns:
        ``(advantages, returns)`` where ``advantages`` is detached and
        ``returns = advantages + values`` (using the masked values when a
        mask is given).
    """
    last_gae_lam = 0
    advantages_reversed = []
    response_length = rewards.size(1)

    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises RuntimeError.
    if action_mask is not None:
        values = action_mask * values
        rewards = action_mask * rewards

    for t in reversed(range(response_length)):
        next_values = values[:, t + 1] if t < response_length - 1 else 0.0

        # Select timestep t for every batch row; ``values[:t]`` would slice
        # the batch axis instead.
        delta = rewards[:, t] + gamma * next_values - values[:, t]
        last_gae_lam = delta + gamma * lambd * last_gae_lam
        advantages_reversed.append(last_gae_lam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values

    return advantages.detach(), returns


In [None]:
def generate_samples(
        prompts,
        model,
        max_length,
        max_new_tokens,
        n_samples_per_prompt,
        micro_rollout_batch_size
):
    """Generate responses for every prompt and package them as ``Samples``.

    Each prompt is repeated ``n_samples_per_prompt`` times; the repeated
    list is processed in chunks of ``micro_rollout_batch_size``. Every
    generated batch is truncated or right-padded with the pad token to a
    width of ``max_new_tokens + max_length``.

    Relies on the module-level names ``actor_tokenizer``, ``device``,
    ``eos_token_id`` and ``pad_token_id``, which are defined outside this
    function.

    Args:
        prompts: Iterable of prompt strings.
        model: Model exposing ``eval()`` and ``generate(...)``.
        max_length: Width the prompts are padded/truncated to by the
            tokenizer.
        max_new_tokens: Maximum number of tokens to generate.
        n_samples_per_prompt: Number of copies of each prompt to sample.
        micro_rollout_batch_size: Number of prompts per ``generate`` call.

    Returns:
        A list with one ``Samples`` instance per chunk.
    """
    model.eval()
    target_width = max_new_tokens + max_length

    # Repeat each prompt consecutively, preserving the original order.
    repeated_prompts = [
        prompt
        for prompt in prompts
        for _ in range(n_samples_per_prompt)
    ]

    rollouts = []
    for start in range(0, len(repeated_prompts), micro_rollout_batch_size):
        chunk = repeated_prompts[start:start + micro_rollout_batch_size]
        encoded = actor_tokenizer(
            chunk,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
        )
        # Only the prompt width is needed later, so record it up front.
        prompt_width = encoded['input_ids'].size(1)

        generated = model.generate(
            **encoded.to(device),
            max_new_tokens = max_new_tokens,
            eos_token_id = eos_token_id,
            pad_token_id = pad_token_id,
        )

        # Bring every batch to exactly ``target_width`` columns.
        shortfall = target_width - generated.size(1)
        if shortfall > 0:
            filler = torch.full(
                (generated.size(0), shortfall),
                fill_value=pad_token_id,
                device=generated.device,
            )
            generated = torch.cat([generated, filler], dim=1)
        else:
            generated = generated[:, :target_width]

        attn_mask = generated.ne(pad_token_id).to(dtype=torch.long)

        # Response tokens are the columns after the tokenized prompt; EOS
        # and pad positions are excluded from the action mask.
        response = generated[:, prompt_width:]
        is_action = response.ne(eos_token_id) & response.ne(pad_token_id)
        act_mask = is_action.to(dtype=torch.long)

        rollouts.append(
            Samples(
                seqs=generated,
                attention_mask=attn_mask,
                action_mask=act_mask,
                num_actions=act_mask.size(1),
                packed_seq_lens=None,
                response_length=act_mask.float().sum(dim=-1),
                total_length=attn_mask.float().sum(dim=-1),
            )
        )

    return rollouts

In [10]:
@dataclass
class BufferItem:
    """Typed record with the same eight fields that ``ExperienceBuffer.append`` copies.

    NOTE(review): ``ExperienceBuffer`` stores plain dicts in the code visible
    here; whether entries are converted to ``BufferItem`` elsewhere is not
    shown -- confirm against the caller.
    """
    seqs: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    returns: torch.Tensor
    advantages: torch.Tensor
    attention_mask: torch.Tensor
    action_mask: torch.Tensor
    num_actions: Union[int, torch.Tensor]

In [None]:
def train_step(
        experience,
        steps
):
    actor_model.train()
    optimizer_actor.zero_grad()

    sequences = experience.seqs
    old_action_log_probs = experience.action_log_probs
    advantages = experience.advantages
    num_actions = experience.num_actions
    attention_mask = experience.attention_mask
    action_mask = experience.action_mask
    old_values = experience.values
    returns = experience.returns

    logits = actor_model(
        sequences,
        attention_mask=attention_mask
    ).logits

    log_probs = F.log_softmax(
        logits[:, :-1, :], dim = -1
    )