In [None]:
import wandb
import torch
import os
import torch.nn as nn
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm import tqdm
from datasets import load_dataset

In [1]:
# Training takes a long time, so it is better to save the model and optimizer weights.

In [None]:
# Mount Google Drive at /content/drive (Colab only); the checkpoints loaded
# later in this notebook are read from paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def getwandbrun(cfgs):
  """Log in to Weights & Biases and start the "grpo-training-loop" run.

  Args:
    cfgs: configuration object. It must expose ``WANDBAPI_KEY``; its public,
      non-callable attributes (defaults declared on its class plus values set
      on the instance) are recorded as the run config.

  Returns:
    The run object returned by ``wandb.init``.
  """
  # An empty key is forwarded as None so wandb uses its own credential lookup
  # instead of being handed an invalid "" key.
  wandb.login(key=cfgs.WANDBAPI_KEY or None)

  # vars(cfgs) alone only contains instance attributes, so settings declared as
  # plain class attributes would be logged as an empty config. Merge the
  # class-level defaults with the instance values (instance wins).
  merged = {**vars(type(cfgs)), **vars(cfgs)}
  run_config = {
      name: value
      for name, value in merged.items()
      # Skip dunder/private entries and methods, and never upload the API key.
      if not name.startswith("_") and not callable(value) and name != "WANDBAPI_KEY"
  }

  run = wandb.init(
      entity="ajheshbasnet-kpriet",
      project="RLVR",
      name = "grpo-training-loop",
      config=run_config,
  )
  return run

In [None]:

@dataclass
class configs:
  """Hyperparameters and runtime settings for the SFT stage.

  Every setting is type-annotated so that it is a real dataclass field: it
  shows up in ``vars(cfg)`` / ``repr(cfg)`` and can be overridden per instance,
  e.g. ``configs(SFT_EPOCHS=1)``. Without annotations ``@dataclass`` creates no
  fields and ``vars(cfg)`` is empty.
  """
  MAX_SEQ_LEN: int = 512
  SFT_LEARNING_RATE: float = 2.5e-5
  SFT_EPOCHS: int = 6
  EVAL_EVERY_STEP: int = 12
  GRADIENT_ACCUM_STEPS: int = 32 # Increased to compensate for smaller batch size
  MODEL_NAME: str = "gpt2"
  # Deliberately left un-annotated: it stays a plain class attribute, so the
  # secret never appears in vars(cfg) / repr(cfg). Do not commit a real key here.
  WANDBAPI_KEY = ""
  SFT_TRAIN_BATCH_SIZE: int = 8 # Reduced to save VRAM
  SFT_VALID_BATCH_SIZE: int = 4
  # Evaluated once, when the class is defined.
  DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = configs()

In [None]:
# Policy network: pretrained GPT-2 with its language-modelling head,
# loaded first and then moved onto the configured device.
model = GPT2LMHeadModel.from_pretrained(cfg.MODEL_NAME)
model = model.to(cfg.DEVICE)

The EOS token is reused as the PAD token; the padded positions are handled by masking in the training loop.

In [None]:
# Tokenizer matching the policy model. Padding is placed on the left so that,
# in a padded batch, every prompt ends at the last position; the training loop
# relies on this when it slices the generated response off after the prompt.
tokenizer = GPT2Tokenizer.from_pretrained(cfg.MODEL_NAME)
tokenizer.padding_side = "left"
# Disabled alternative: add a dedicated <PAD> token and resize the embeddings.
# tokenizer.add_special_tokens({'pad_token': '<PAD>'})
# model.resize_token_embeddings(len(tokenizer))

In [None]:
# Reuse the EOS id as the PAD id (no dedicated <PAD> token is added).
# Padding and end-of-sequence therefore share one id and must be told apart
# with masks rather than by token value.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Start the Weights & Biases run; `runs.log(...)` is called from the training loop.
runs = getwandbrun(cfg)

In [None]:
# checkpoint = {
# 'epoch': epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': SFT_OPTIMIZER.state_dict(),
# 'loss': loss,
# }

In [None]:
# save_dir = "/content/drive/MyDrive/sft-optimizer"
# os.makedirs(save_dir, exist_ok=True)

# filename = f"checkpoint_epoch_{epoch}.pth"
# torch.save(checkpoint, os.path.join(save_dir, filename))

# **REWARD MODEL**

In [None]:
# Summarisation comparison dataset. This notebook uses its 'train' and
# 'valid1' splits and only the 'prompt' column of each.
rl_dataset = load_dataset("CarperAI/openai_summarize_comparisons")

In [None]:
# Draw a random subset of each split (at most 500 train / 50 validation rows).
# The permutation is converted to a list of plain Python ints before it is
# handed to `select`, instead of relying on a tensor being accepted as indices.
train_indices = torch.randperm(len(rl_dataset['train']))[:500].tolist()
rl_dataset_train = rl_dataset['train'].select(train_indices)
valid_indices = torch.randperm(len(rl_dataset['valid1']))[:50].tolist()
rl_dataset_valid = rl_dataset['valid1'].select(valid_indices)

rl_dataset_train, rl_dataset_valid

In [None]:
class RewardModel(nn.Module):
  """GPT-2 backbone with a linear head that scores a sequence with one scalar.

  For every sequence a single hidden state is picked (by default the one at the
  last position) and projected to a reward.
  """

  def __init__(self):
    super().__init__()

    self.base_model = GPT2Model.from_pretrained(cfg.MODEL_NAME)
    # Size the head from the backbone's config instead of hard-coding 768, so a
    # different GPT-2 size can be used. The layer stays inside nn.Sequential so
    # the state-dict keys remain "reward_head.0.*" and old checkpoints load.
    self.reward_head = nn.Sequential(
        nn.Linear(self.base_model.config.hidden_size, 1)
        )

  def forward(self, x, attn_mask, eos_idx=None):
    """Score each sequence in the batch.

    Args:
      x: token ids passed to the backbone as ``input_ids``.
      attn_mask: attention mask passed to the backbone.
      eos_idx: optional per-sequence position whose hidden state is scored;
        when omitted, the last position of every sequence is used.

    Returns:
      Output of the reward head for the selected hidden states, one row per
      sequence.
    """
    h_s = self.base_model(input_ids = x, attention_mask=attn_mask).last_hidden_state
    batch_size, seq_len = h_s.size(0), h_s.size(1)

    if eos_idx is None:
      # Built directly as an integer tensor on the hidden states' device.
      chosen_idx = torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=h_s.device)
    else:
      chosen_idx = torch.as_tensor(eos_idx, dtype=torch.long, device=h_s.device)

    rows = torch.arange(batch_size, dtype=torch.long, device=h_s.device)
    hs = h_s[rows, chosen_idx]
    rewards = self.reward_head(hs)
    return rewards

In [None]:
# Instantiate the reward model (pretrained backbone, freshly initialised head)
# on the configured device; trained weights are loaded into it further below.
reward_model = RewardModel().to(cfg.DEVICE)

In [None]:
# Load the SFT checkpoint onto the CPU. Without map_location the tensors are
# restored to the device they were saved from, which keeps a second copy of the
# weights on the GPU for as long as this dict is alive and fails outright on a
# CPU-only runtime. load_state_dict copies the values onto the model's device.
sft_checkpointer = torch.load(
    "/content/drive/MyDrive/sft-optimizer/checkpoint_epoch_4.pth",
    map_location="cpu",
)

In [None]:
# Restore the SFT weights into the policy model (strict key matching by default).
model.load_state_dict(sft_checkpointer['model_state_dict'])

In [None]:
# Reward-model weights, loaded onto the CPU so no extra copy is held on the GPU
# and the load also works on a CPU-only runtime.
# NOTE(review): the next cell loads this same file again into `checkpoint`;
# this variable is not used anywhere in the visible part of the notebook.
rew_checkpointer = torch.load(
    "/content/drive/MyDrive/checkpoints/rewards_weights",
    map_location="cpu",
)

In [None]:
# Load the reward-model state dict onto the CPU (load_state_dict copies the
# values onto the model's own device afterwards).
checkpoint = torch.load(
    "/content/drive/MyDrive/checkpoints/rewards_weights",
    map_location="cpu",
)

# Drop every "_orig_mod." fragment from the parameter names so the keys match
# the attribute names of the plain RewardModel instance.
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}

reward_model.load_state_dict(new_state_dict)

# **GRPO**

In [None]:
# ---- GRPO hyperparameters ----
# Fall back to the CPU when no GPU is present, matching how configs.DEVICE is chosen.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LR = 2e-5                     # AdamW learning rate for the policy
KL_COEF = 0.01                # weight of the KL term in the GRPO loss
PPO_EPS = 0.2                 # clipping range for the probability ratio
GRAD_ACCUM = 2
MAX_NEW_TOKENS = 54           # generation budget per sampled response
NUM_SAMPLES = 4               # responses sampled per prompt (the GRPO group size)
TEMPERATURE = 0.6             # sampling temperature used for generation
EVAL_INTERVAL = 6000
TRAIN_BATCH_SIZE = 1          # prompts per rollout step
VALID_BATCH_SIZE = 2
MAX_PROMPT_LENGTH = 500       # prompts are truncated/padded to this many tokens
DRIVE_CHECKPOINTER = 10_000
UPDATE_WEIGHTS = 500          # rollout steps collected before each round of updates
MAX_N_STEPS = 10              # number of passes over the training dataloader

In [None]:
class GRPODataset(Dataset):
  """Prompts that are tokenized on access.

  Each item is an ``(input_ids, attention_mask)`` pair produced by the
  notebook-level ``tokenizer``, truncated/padded to ``MAX_PROMPT_LENGTH``.
  """

  def __init__(self, ds):
    # `ds` only needs to support len() and integer indexing.
    self.prompt = ds

  def __len__(self):
    return len(self.prompt)

  def __getitem__(self, index):
    encoding = tokenizer(
        self.prompt[index],
        return_tensors="pt",
        padding = "max_length",
        truncation=True,
        max_length = MAX_PROMPT_LENGTH,
    )
    # The tokenizer returns a batch of one; take its single row.
    return encoding['input_ids'][0], encoding['attention_mask'][0]

In [None]:
# Only the 'prompt' column is needed: responses are sampled from the policy
# during training rather than taken from the dataset.
grpo_trainds = GRPODataset(rl_dataset_train['prompt'])
grpo_validds = GRPODataset(rl_dataset_valid['prompt'])

In [None]:
# Training prompts are reshuffled on every pass.
train_dataloader = DataLoader(grpo_trainds, batch_size = TRAIN_BATCH_SIZE, shuffle = True)
# Validation must read the held-out prompts (it previously wrapped the training
# set) and keeps a fixed order so evaluations are comparable between runs.
valid_dataloader = DataLoader(grpo_validds, batch_size = VALID_BATCH_SIZE, shuffle = False)

In [None]:
import copy

# Reference policy: an independent snapshot of the current policy, put in eval
# mode and excluded from gradient computation so it never gets updated.
frozen_model = copy.deepcopy(model).eval()

for ref_param in frozen_model.parameters():
    ref_param.requires_grad = False

In [None]:
# AdamW over the policy's parameters only; the frozen reference copy and the
# reward model are not passed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [None]:
from tqdm import tqdm

In [None]:
def _empty_rollout_buffer():
    """Return a fresh rollout buffer with one empty list per stored quantity."""
    return {
        "outputs": [],
        "attn_mask": [],
        "prompt_mask": [],
        "advantages": [],
        "old_log_probs": [],
        "loss_mask": [],
    }


def _next_token_log_probs(lm, token_ids, attention_mask):
    """Log-probability `lm` assigns to each token of `token_ids` given its prefix.

    Column t of the result scores token t+1, so the result has one column
    fewer than `token_ids`.
    """
    log_dist = F.log_softmax(lm(token_ids, attention_mask=attention_mask).logits, dim=-1)
    targets = token_ids[:, 1:].unsqueeze(-1)
    return log_dist[:, :-1, :].gather(dim=-1, index=targets).squeeze(-1)


# GRPO training: for every prompt batch sample NUM_SAMPLES responses, score them
# with the reward model, turn the rewards into group-normalised advantages and
# store everything. After every UPDATE_WEIGHTS rollout steps the stored rollouts
# are replayed to update the policy with a clipped surrogate loss plus a KL
# penalty against the frozen reference model.
for epoch in range(MAX_N_STEPS):

    buffer = _empty_rollout_buffer()

    for step, batch in enumerate(tqdm(train_dataloader)):

        ids, msk = batch
        ids, msk = ids.to(cfg.DEVICE), msk.to(cfg.DEVICE)

        # -------- rollout: nothing in this phase needs gradients --------
        with torch.no_grad():

            outputs = model.generate(
                ids,
                attention_mask=msk,
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=MAX_NEW_TOKENS,
                num_return_sequences=NUM_SAMPLES,
                pad_token_id=tokenizer.pad_token_id
            )

            prompt_len = ids.size(-1)
            response = outputs[:, prompt_len:]

            # 1 for response tokens up to and including the first EOS, 0 afterwards
            # (PAD shares the EOS id, so everything after the first EOS is padding).
            response_msk = (torch.cumsum((response == tokenizer.eos_token_id).int(), dim = -1) <= 1).long()

            # Loss mask: prompt positions are 0, response tokens (+ first EOS) are 1.
            prompt_mask = torch.zeros((outputs.size(0), prompt_len), device = ids.device)
            loss_mask = torch.cat((prompt_mask, response_msk), dim = -1)

            # Index of the last unmasked response token of every sequence (-1: count -> index).
            response_eos_length = prompt_len + response_msk.sum(dim = -1) - 1

            # Each prompt's mask is repeated once per sampled response, in the same
            # prompt-major order as the generated sequences.
            attn_mask = torch.cat((msk.repeat_interleave(NUM_SAMPLES, dim = 0), response_msk), dim = -1)

            r = reward_model(outputs, attn_mask, response_eos_length.view(-1)).view(-1, NUM_SAMPLES)

            # Group-relative advantage: standardise rewards within each prompt's samples.
            advantages = (r - r.mean(dim = -1, keepdim = True)) / (r.std(dim = -1, keepdim = True) + 1e-8)
            advantages = advantages.view(-1, 1)            # [B*N, 1]

            # Log-probs under the policy that produced the samples.  [B*N, T-1]
            new_log_probs = _next_token_log_probs(model, outputs, attn_mask)

        buffer['outputs'].append(outputs.detach().cpu())
        buffer['prompt_mask'].append(prompt_mask.detach().cpu())
        buffer['attn_mask'].append(attn_mask.detach().cpu())
        buffer['old_log_probs'].append(new_log_probs.detach().cpu())
        buffer['advantages'].append(advantages.detach().cpu())
        buffer['loss_mask'].append(loss_mask.detach().cpu())

        runs.log({"step":step})

        if (step+1) % UPDATE_WEIGHTS == 0:

            # Every buffer entry is one complete rollout batch. Rollouts can differ in
            # generated length, so they are replayed one per optimizer step, in random
            # order. (The buffer holds Python lists, so entries are picked with plain
            # int indices rather than with a tensor slice.)
            for rollout_idx in torch.randperm(len(buffer['outputs'])).tolist():

                batch_outputs = buffer['outputs'][rollout_idx].to(cfg.DEVICE)
                batch_attn_mask = buffer['attn_mask'][rollout_idx].to(cfg.DEVICE)
                batch_old_log_probs = buffer['old_log_probs'][rollout_idx].to(cfg.DEVICE)
                batch_advantages = buffer['advantages'][rollout_idx].to(cfg.DEVICE)
                batch_loss_mask = buffer['loss_mask'][rollout_idx].to(cfg.DEVICE)

                batch_new_log_probs = _next_token_log_probs(model, batch_outputs, batch_attn_mask)

                # Shifted by one so the mask lines up with the predicted (next) tokens.
                response_tokens = batch_loss_mask[:, 1:]
                n_response_tokens = response_tokens.sum().clamp(min = 1)

                ratio = torch.exp(batch_new_log_probs - batch_old_log_probs)        # [B*N, T-1]
                surrogate = torch.min(
                    ratio * batch_advantages,
                    torch.clamp(ratio, 1 - PPO_EPS, 1 + PPO_EPS) * batch_advantages,
                )

                # Average over response tokens only. The attention mask also covers the
                # prompt, which would let prompt tokens leak into the policy loss.
                policy_loss = - (surrogate * response_tokens).sum() / n_response_tokens

                # ==================== KL-DIVERGENCE ==================== #

                with torch.no_grad():
                    ref_log_probs = _next_token_log_probs(frozen_model, batch_outputs, batch_attn_mask)

                # NOTE(review): plain log-ratio estimate; it can be negative for a batch.
                kl_div_per_token = (batch_new_log_probs - ref_log_probs)     # [B*N, T-1]
                kl_div = (kl_div_per_token * response_tokens).sum() / n_response_tokens

                loss_GRPO = policy_loss + KL_COEF * kl_div

                optimizer.zero_grad()
                loss_GRPO.backward()
                optimizer.step()

                runs.log({"loss-grpo": loss_GRPO.item(), "kl-div": kl_div.item(), "policy-loss": policy_loss.item()})

            buffer = _empty_rollout_buffer()

In [None]:
  # Scratch cell: show the policy log-probs left over from the last update step.
  (batch_new_log_probs)# - ref_log_probs)

In [None]:
# Scratch cell: index of the last rollout step that was executed.
step

In [None]:
# Scratch cell: one token id of the third generated sequence. Position 544 is
# hard-coded and is out of range if generation stopped before that length.
outputs[2][544]

In [None]:
# Scratch cell: token ids (prompt + response) from the last generation call.
outputs