<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 model and tokenizer
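tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS
tokenizer.padding_side = "left"  # Decoder-only models need left padding for generation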
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Set model to train on CPU
device = torch.device("cpu")
model.to(device)

# PPO Parameters
ppo_epochs = 4
clip_epsilon = 0.2
lr = 5e-5
gamma = 0.99  # Discount factor (unused below: each response gets a single scalar reward)

# Dataset for prompts and responses
class PromptDataset(Dataset):
    def __init__(self, prompts):
        self.prompts = prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

# Example prompts
prompts = [
    "What is the capital of France?",
    "Explain the theory of relativity.",
    "Why is the sky blue?",
    "Tell me a joke about computers."
]
dataset = PromptDataset(prompts)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Reward simulation
def get_reward(response):
    """Simulated reward function."""
    if "Paris" in response:
        return 1.0  # Example: reward for correct answer
    elif "joke" in response:
        return 0.5  # Reward for mentioning a joke
    else:
        return 0.1  # Default reward

# PPO Training Loop
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(3):  # Outer training loop
    for batch in dataloader:
        # Generate responses
        batch = [prompt for prompt in batch]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
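            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],  # Ignore padded prompt positions
                pad_token_id=tokenizer.eos_token_id,  # Set explicitly to avoid the warning
                max_length=50,
                num_return_sequences=1,
            )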

        responses = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

        # Calculate rewards
        rewards = torch.tensor([get_reward(response) for response in responses], dtype=torch.float32).to(device)

        # Prepare for PPO: score the generated sequences (prompt + response)
        sequences = outputs  # Shape: [batch_size, seq_len]
        # pad_token is eos_token, so this masks both padding and EOS positions
        seq_mask = (sequences != tokenizer.pad_token_id).long()
        # Logits at position t predict token t + 1, so targets are shifted by one
        targets = sequences[:, 1:]
        token_mask = seq_mask[:, 1:].float()
        prompt_len = inputs["input_ids"].shape[1]
        token_mask[:, : prompt_len - 1] = 0  # Only response tokens enter the loss

        # Advantages: reward minus the batch-mean baseline, broadcast over tokens
        advantages = (rewards - rewards.mean()).unsqueeze(-1)  # Shape: [batch_size, 1]

        # Compute old log probs once (fixed policy before PPO updates)
        with torch.no_grad():
            old_logits = model(sequences, attention_mask=seq_mask).logits[:, :-1]
            old_log_probs = torch.log_softmax(old_logits, dim=-1)
            old_log_probs = old_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

        # Train with PPO
        for _ in range(ppo_epochs):
            # Compute new logits and new log probs (updated policy)
            new_logits = model(sequences, attention_mask=seq_mask).logits[:, :-1]
            new_log_probs = torch.log_softmax(new_logits, dim=-1)
            new_log_probs = new_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

            # Compute ratios
            ratios = torch.exp(new_log_probs - old_log_probs)

            # PPO clipped surrogate loss, averaged over response tokens only
            surrogate1 = ratios * advantages
            surrogate2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
            per_token_loss = -torch.min(surrogate1, surrogate2)
            loss = (per_token_loss * token_mask).sum() / token_mask.sum().clamp(min=1)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            for name, param in model.named_parameters():
                if param.requires_grad:  # Only track trainable parameters
                    print(f"Epoch {epoch}, Before Update, {name}: {param.view(-1)[0].item()}")  # Flatten and get first value
                    break

            optimizer.step()

            for name, param in model.named_parameters():
                if param.requires_grad:  # Only track trainable parameters
                    print(f"Epoch {epoch}, After Update, {name}: {param.view(-1)[0].item()}")  # Flatten and get first value
                    break
        print(f"Epoch {epoch}, Loss: {loss.item()}")

print("Training complete!")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Epoch 0, Before Update, transformer.wte.weight: -0.11010301113128662
Epoch 0, After Update, transformer.wte.weight: -0.11005301773548126
Epoch 0, Before Update, transformer.wte.weight: -0.11005301773548126
Epoch 0, After Update, transformer.wte.weight: -0.11001952737569809
Epoch 0, Before Update, transformer.wte.weight: -0.11001952737569809
Epoch 0, After Update, transformer.wte.weight: -0.10999363660812378
Epoch 0, Before Update, transformer.wte.weight: -0.10999363660812378
Epoch 0, After Update, transformer.wte.weight: -0.10997243225574493
Epoch 0, Before Update, transformer.wte.weight: -0.10997243225574493
Epoch 0, After Update, transformer.wte.weight: -0.10994093120098114
Epoch 0, Before Update, transformer.wte.weight: -0.10994093120098114
Epoch 0, After Update, transformer.wte.weight: -0.10991378128528595
Epoch 0, Before Update, transformer.wte.weight: -0.10991378128528595
Epoch 0, After Update, transformer.wte.weight: -0.10989007353782654
Epoch 0, Before Update, transformer.wte.w

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Epoch 0, After Update, transformer.wte.weight: -0.109721340239048
Epoch 0, Loss: -0.039974309504032135
Epoch 0, Before Update, transformer.wte.weight: -0.109721340239048
Epoch 0, After Update, transformer.wte.weight: -0.1097053587436676
Epoch 0, Before Update, transformer.wte.weight: -0.1097053587436676
Epoch 0, After Update, transformer.wte.weight: -0.1096908450126648
Epoch 0, Before Update, transformer.wte.weight: -0.1096908450126648
Epoch 0, After Update, transformer.wte.weight: -0.10967765748500824
Epoch 0, Before Update, transformer.wte.weight: -0.10967765748500824
Epoch 0, After Update, transformer.wte.weight: -0.10966566205024719
Epoch 0, Before Update, transformer.wte.weight: -0.10966566205024719
Epoch 0, After Update, transformer.wte.weight: -0.1096547469496727
Epoch 0, Before Update, transformer.wte.weight: -0.1096547469496727
Epoch 0, After Update, transformer.wte.weight: -0.109644815325737
Epoch 0, Before Update, transformer.wte.weight: -0.109644815325737
Epoch 0, After Upd


Epoch 0, After Update, transformer.wte.weight: -0.10958320647478104
Epoch 0, Loss: -0.0
Epoch 1, Before Update, transformer.wte.weight: -0.10958320647478104
Epoch 1, After Update, transformer.wte.weight: -0.1095796450972557
Epoch 1, Before Update, transformer.wte.weight: -0.1095796450972557
Epoch 1, After Update, transformer.wte.weight: -0.10957640409469604
Epoch 1, Before Update, transformer.wte.weight: -0.10957640409469604
Epoch 1, After Update, transformer.wte.weight: -0.10957345366477966
Epoch 1, Before Update, transformer.wte.weight: -0.10957345366477966
Epoch 1, After Update, transformer.wte.weight: -0.10957076400518417
Epoch 1, Before Update, transformer.wte.weight: -0.10957076400518417
Epoch 1, After Update, transformer.wte.weight: -0.10956831276416779
Epoch 1, Before Update, transformer.wte.weight: -0.10956831276416779
Epoch 1, After Update, transformer.wte.weight: -0.1095660850405693
Epoch 1, Before Update, transformer.wte.weight: -0.1095660850405693
Epoch 1, After Update, tr


Epoch 1, After Update, transformer.wte.weight: -0.10955231636762619
Epoch 1, Loss: -0.0
Epoch 1, Before Update, transformer.wte.weight: -0.10955231636762619
Epoch 1, After Update, transformer.wte.weight: -0.10956891626119614
Epoch 1, Before Update, transformer.wte.weight: -0.10956891626119614
Epoch 1, After Update, transformer.wte.weight: -0.10957801342010498
Epoch 1, Before Update, transformer.wte.weight: -0.10957801342010498
Epoch 1, After Update, transformer.wte.weight: -0.10958628356456757
Epoch 1, Before Update, transformer.wte.weight: -0.10958628356456757
Epoch 1, After Update, transformer.wte.weight: -0.10959379374980927
Epoch 1, Before Update, transformer.wte.weight: -0.10959379374980927
Epoch 1, After Update, transformer.wte.weight: -0.10959772020578384
Epoch 1, Before Update, transformer.wte.weight: -0.10959772020578384
Epoch 1, After Update, transformer.wte.weight: -0.1095653772354126
Epoch 1, Before Update, transformer.wte.weight: -0.1095653772354126
Epoch 1, After Update, 


Epoch 1, After Update, transformer.wte.weight: -0.10936454683542252
Epoch 1, Loss: 0.06367666274309158
Epoch 2, Before Update, transformer.wte.weight: -0.10936454683542252
Epoch 2, After Update, transformer.wte.weight: -0.10935021936893463
Epoch 2, Before Update, transformer.wte.weight: -0.10935021936893463
Epoch 2, After Update, transformer.wte.weight: -0.109337218105793
Epoch 2, Before Update, transformer.wte.weight: -0.109337218105793
Epoch 2, After Update, transformer.wte.weight: -0.10932543128728867
Epoch 2, Before Update, transformer.wte.weight: -0.10932543128728867
Epoch 2, After Update, transformer.wte.weight: -0.10931473970413208
Epoch 2, Before Update, transformer.wte.weight: -0.10931473970413208
Epoch 2, After Update, transformer.wte.weight: -0.10930504649877548
Epoch 2, Before Update, transformer.wte.weight: -0.10930504649877548
Epoch 2, After Update, transformer.wte.weight: -0.10929625481367111
Epoch 2, Before Update, transformer.wte.weight: -0.10929625481367111
Epoch 2, A


Epoch 2, After Update, transformer.wte.weight: -0.1092430129647255
Epoch 2, Loss: -0.0
Epoch 2, Before Update, transformer.wte.weight: -0.1092430129647255
Epoch 2, After Update, transformer.wte.weight: -0.10923996567726135
Epoch 2, Before Update, transformer.wte.weight: -0.10923996567726135
Epoch 2, After Update, transformer.wte.weight: -0.10923705995082855
Epoch 2, Before Update, transformer.wte.weight: -0.10923705995082855
Epoch 2, After Update, transformer.wte.weight: -0.10923418402671814
Epoch 2, Before Update, transformer.wte.weight: -0.10923418402671814
Epoch 2, After Update, transformer.wte.weight: -0.1092311292886734
Epoch 2, Before Update, transformer.wte.weight: -0.1092311292886734
Epoch 2, After Update, transformer.wte.weight: -0.10922833532094955
Epoch 2, Before Update, transformer.wte.weight: -0.10922833532094955
Epoch 2, After Update, transformer.wte.weight: -0.10922563821077347
Epoch 2, Before Update, transformer.wte.weight: -0.10922563821077347
Epoch 2, After Update, tr
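
The clipped surrogate used in the training loop can be checked in isolation on hand-picked numbers. The following is a minimal sketch and not part of the original run: `clipped_surrogate_loss`, the `mask` argument, and the toy tensors are hypothetical names introduced only for illustration. It assumes per-token log probs of shape `[batch_size, seq_len]` and one advantage per sequence, as in the cell above. It also shows why the log can print `Loss: -0.0`: when both responses in a batch get the same reward, every advantage is zero and so is the loss.

```python
import torch


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, mask, clip_epsilon=0.2):
    """PPO clipped surrogate loss, averaged over unmasked tokens (illustrative helper)."""
    ratios = torch.exp(new_log_probs - old_log_probs)
    surrogate1 = ratios * advantages
    surrogate2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    per_token = -torch.min(surrogate1, surrogate2)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)


# Toy data (assumed values): 2 sequences of 3 tokens; the last token of row 1 is padding.
old_log_probs = torch.log(torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]))
new_log_probs = torch.log(torch.tensor([[0.5, 0.75, 0.25], [0.5, 0.75, 0.25]]))
mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])

# Rewards 1.0 and 0.1 give centered advantages of +0.45 and -0.45; use +1 and -1 for easy arithmetic.
advantages = torch.tensor([[1.0], [-1.0]])
loss = clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, mask)
print(loss.item())  # approximately -0.04

# Equal rewards in a batch give zero advantages, so the loss is zero.
rewards = torch.tensor([0.1, 0.1])
zero_advantages = (rewards - rewards.mean()).unsqueeze(-1)
zero_loss = clipped_surrogate_loss(new_log_probs, old_log_probs, zero_advantages, mask)
print(zero_loss.item())
```

In the first case the ratios are 1.0, 1.5, and 0.5. With a positive advantage the ratio 1.5 is clipped to 1.2, and with a negative advantage the ratio 0.5 would be clipped to 0.8, which is how clipping limits the size of each policy update.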