In [None]:
import wandb
import torch
import os
import torch.nn as nn
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm import tqdm
from datasets import load_dataset

In [None]:
# Mount Google Drive at /content/drive (Colab-only); later cells read a checkpoint
# from, and write checkpoints to, paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def getwandbrun(cfgs):
  """Log in to Weights & Biases and start the reward-model run.

  Args:
    cfgs: configuration object. Its `WANDBAPI_KEY` attribute is used for login and
      its remaining public, non-callable attributes are recorded as the run config.

  Returns:
    The run object returned by `wandb.init`.
  """
  # `vars(cfgs)` only sees instance attributes. Settings declared as plain class
  # attributes would be silently dropped, so gather class-level defaults first
  # (base classes before subclasses) and let instance values override them.
  settings = {}
  for klass in reversed(type(cfgs).__mro__):
    settings.update(vars(klass))
  settings.update(getattr(cfgs, "__dict__", {}))

  # Keep only real settings, and never upload the API key as part of the run config.
  run_config = {
      name: value
      for name, value in settings.items()
      if not name.startswith("_") and not callable(value) and name != "WANDBAPI_KEY"
  }

  # An empty key is passed as None so wandb can fall back to its own credential lookup.
  wandb.login(key=cfgs.WANDBAPI_KEY or None)
  run = wandb.init(
      entity="ajheshbasnet-kpriet",
      project="RLVR",
      name = "rewards-runs",
      config=run_config,
  )
  return run

In [None]:
@dataclass
class configs:
  """Hyper-parameters and runtime settings for reward-model training.

  Every setting carries a type annotation so that `@dataclass` registers it as a
  field: instances then hold the values (`vars(cfg)` is populated) and individual
  settings can be overridden at construction time, e.g. `configs(MAX_SEQ_LEN=256)`.
  """
  MAX_SEQ_LEN: int = 512                     # total tokens per example, including the appended EOS
  REWARD_LEARNING_RATE: float = 1e-4         # learning rate for the reward head
  TRANSFORMER_LEARNING_RATE: float = 1e-5    # learning rate for the transformer backbone
  TRAIN_LENGTH: int = 10000                  # number of training pairs sampled
  VALID_LENGTH: int = 2000                   # number of validation pairs sampled
  DRIVE_STEP: int = 7_000                    # checkpoint interval, in training steps
  EVAL_EVERY_STEP: int = 1400                # validation interval, in training steps
  GRADIENT_ACCUM_STEPS: int = 8 # Increased to compensate for smaller batch size
  MODEL_NAME: str = "gpt2"
  # Deliberately NOT annotated: it stays a plain class attribute, so the secret is
  # not a dataclass field and does not appear in `vars(cfg)` / the logged run config.
  # Prefer supplying it at runtime rather than committing a key to the notebook.
  WANDBAPI_KEY = ""
  TRAIN_BATCH_SIZE: int = 8 # Reduced to save VRAM
  VALID_BATCH_SIZE: int = 4
  DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = configs()

In [None]:
# Load the tokenizer that matches cfg.MODEL_NAME.
tokenizer = GPT2Tokenizer.from_pretrained(cfg.MODEL_NAME)
# Pad on the left so position -1 of every padded sequence holds a real token;
# RewardModel.forward reads the hidden state at that position.
tokenizer.padding_side = "left"
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the pretrained causal LM named by cfg.MODEL_NAME and move it to the configured device.
model = GPT2LMHeadModel.from_pretrained(cfg.MODEL_NAME).to(cfg.DEVICE)

In [None]:
# Load a previously saved checkpoint from Drive onto the configured device.
# The next cell reads its 'model_state_dict' entry.
checkpointer = torch.load("/content/drive/MyDrive/checkpoint_epoch_4.pth", map_location=cfg.DEVICE)

In [None]:
# Replace the pretrained weights with the weights stored in the checkpoint.
model.load_state_dict(checkpointer['model_state_dict'])

In [None]:
# Start the Weights & Biases run; the training loop logs metrics through `runs`.
runs = getwandbrun(cfg)

# **REWARD MODEL**

In [None]:
# Load the comparison dataset; later cells use its 'train' and 'valid1' splits and
# the 'prompt', 'chosen' and 'rejected' columns.
rl_dataset = load_dataset("CarperAI/openai_summarize_comparisons")

In [None]:
# Display the loaded dataset object for a quick visual check.
rl_dataset

In [None]:
# Draw a random subset of each split.
# The slice must be applied to the permutation *before* `.select(...)`: slicing the
# result of `.select(...)` yields a plain dict of columns rather than a Dataset, which
# made the train subset a different type from the validation subset.
# Indices are converted to plain Python ints so `.select` receives a list, not a tensor.
train_indices = torch.randperm(len(rl_dataset['train']))[:cfg.TRAIN_LENGTH].tolist()
valid_indices = torch.randperm(len(rl_dataset['valid1']))[:cfg.VALID_LENGTH].tolist()

rl_dataset_train = rl_dataset['train'].select(train_indices)
rl_dataset_valid = rl_dataset['valid1'].select(valid_indices)

In [None]:
class RewardModelDataset(Dataset):
  """Pairs of (prompt + chosen summary, prompt + rejected summary) for reward training.

  Each item is four 1-D tensors of length `cfg.MAX_SEQ_LEN`:
  `(chosen_ids, chosen_mask, rejected_ids, rejected_mask)`. Text is padded/truncated
  to `cfg.MAX_SEQ_LEN - 1` tokens and an EOS token (with mask value 1) is appended.
  """

  def __init__(self, ds):
    """Keep only examples whose prompt+summary token counts fit in `cfg.MAX_SEQ_LEN`.

    Args:
      ds: object indexable by 'prompt', 'chosen' and 'rejected', each yielding a
        sequence of strings.
    """
    self.prompt = []
    self.chosen = []
    self.reject = []
    self.tokenizer = tokenizer
    self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # `len(ds['prompt'])` is the row count whether `ds` is a Dataset or a dict of columns.
    for p, c, r in tqdm(zip(ds['prompt'], ds['chosen'], ds['rejected']), total=len(ds['prompt'])):

      # Tokenize each piece once; the prompt length is shared by both comparisons.
      n_prompt = len(self.tokenizer(p)['input_ids'])
      n_chosen = len(self.tokenizer(c)['input_ids'])
      n_reject = len(self.tokenizer(r)['input_ids'])

      # NOTE(review): this count ignores the '\nTL;DR:  ' separator and the appended
      # EOS, so examples right at the limit can still be truncated in `_encode`.
      if n_prompt + n_chosen <= cfg.MAX_SEQ_LEN and n_prompt + n_reject <= cfg.MAX_SEQ_LEN:
        self.prompt.append(p)
        self.chosen.append(c)
        self.reject.append(r)

  def __len__(self):
    """Number of examples that passed the length filter."""
    return len(self.prompt)

  def _encode(self, text):
    """Tokenize `text` to MAX_SEQ_LEN-1 tokens and append EOS; return (ids, mask)."""
    encoded = self.tokenizer(text, max_length=cfg.MAX_SEQ_LEN-1, return_tensors = 'pt', truncation=True, padding='max_length')
    # The appended EOS is the last position of every sequence and is marked as attended.
    ids = torch.cat((encoded['input_ids'][0], torch.tensor([self.tokenizer.eos_token_id])))
    msk = torch.cat((encoded['attention_mask'][0], torch.tensor([1])))
    return ids, msk

  def __getitem__(self, index):
    """Return `(chosen_ids, chosen_mask, rejected_ids, rejected_mask)` for one example."""
    prompt_chosen = f'{self.prompt[index]}\nTL;DR:  {self.chosen[index]}'
    prompt_reject = f'{self.prompt[index]}\nTL;DR:  {self.reject[index]}'

    # Previously the chosen ids were overwritten with the (extended) attention mask and
    # the chosen mask was left one element short; both sides now go through `_encode`.
    prompt_chosen_ids, prompt_chosen_msk = self._encode(prompt_chosen)
    prompt_reject_ids, prompt_reject_msk = self._encode(prompt_reject)

    return prompt_chosen_ids, prompt_chosen_msk, prompt_reject_ids, prompt_reject_msk

In [None]:
# Build the filtered training set and a shuffled loader over it.
train_ds = RewardModelDataset(rl_dataset_train)
train_loader = DataLoader(train_ds, batch_size= cfg.TRAIN_BATCH_SIZE, shuffle=True)

# Same for the validation subset; the training loop averages the loss over all of
# its batches, so the shuffle order does not affect the reported mean.
valid1 = RewardModelDataset(rl_dataset_valid)
valid_loader = DataLoader(valid1, batch_size=cfg.VALID_BATCH_SIZE, shuffle=True)

In [None]:
class RewardModel(nn.Module):
  """Transformer backbone plus a linear head that scores a sequence with one scalar."""

  def __init__(self, model):
    """
    Args:
      model: language model exposing its backbone as `model.transformer`. If it also
        exposes `model.config.hidden_size`, the head is sized from it; otherwise the
        previous fixed width of 768 is used.
    """
    super().__init__()

    self.base_model = model.transformer
    # Size the head from the backbone instead of hard-coding 768, so a different
    # cfg.MODEL_NAME does not produce a shape mismatch in forward().
    hidden_size = getattr(getattr(model, "config", None), "hidden_size", 768)
    # Kept as nn.Sequential so parameter names (`reward_head.0.*`) are unchanged.
    self.reward_head = nn.Sequential(
        nn.Linear(hidden_size, 1)
        )

  def forward(self, x, attn_mask):
    """Score each sequence in the batch.

    Args:
      x: token ids, passed to the backbone as `input_ids`.
      attn_mask: attention mask, passed to the backbone as `attention_mask`.

    Returns:
      Tensor with one reward per sequence (last dimension of size 1).
    """
    h_s = self.base_model(input_ids = x, attention_mask=attn_mask).last_hidden_state
    # The reward is read from the final position, so inputs must end with a real
    # token (i.e. be left-padded) rather than with padding.
    end_tk = h_s[:, -1, :]
    rewards = self.reward_head(end_tk)
    return rewards

In [None]:
# Wrap the loaded LM's backbone with the reward head and move it to the configured device.
rewardmodel = RewardModel(model).to(cfg.DEVICE)
# Compile the model for faster execution.
# NOTE(review): torch.compile returns a wrapper module; state_dict keys taken from the
# wrapper are typically prefixed with `_orig_mod.` — confirm before loading the saved
# checkpoints into an uncompiled RewardModel.
rewardmodel = torch.compile(rewardmodel)

In [None]:
# Number of passes over the training loader.
REWARD_EPOCHS = 4

# Two parameter groups with separate learning rates.
OPTIMIZER = torch.optim.AdamW([
    {"params": rewardmodel.base_model.parameters(), "lr": cfg.TRANSFORMER_LEARNING_RATE},   # backbone is already trained, so it gets the smaller learning rate
    {"params": rewardmodel.reward_head.parameters(), "lr": cfg.REWARD_LEARNING_RATE},  # the reward head is newly initialised, so it gets the larger learning rate
])

In [None]:
# Override the validation interval declared in `configs` (1400).
# This runs after `getwandbrun(cfg)`, so the config captured when the run was
# created does not reflect this value.
cfg.EVAL_EVERY_STEP = 1000
# Show the number of training batches per epoch.
len(train_loader)

In [None]:
from torch.amp import autocast, GradScaler

# ------------------- Performance Boost (Safe) -------------------
# Only the long-standing TF32 flag is set. The newer `fp32_precision` attribute does not
# exist on older PyTorch releases, and PyTorch documents that the legacy and new TF32
# APIs should not be mixed.
torch.backends.cudnn.allow_tf32 = True

# ------------------- Settings -------------------
GRAD_CLIP_NORM = 2.0   # max gradient norm applied before every optimizer step
SAVE_DIR = "/content/drive/MyDrive/reward-optimizer"

# ------------------- AMP setup -------------------
scaler = GradScaler(device="cuda")


def _pairwise_loss(batch):
    """Return the pairwise preference loss for one (chosen, rejected) batch.

    Chosen and rejected sequences are stacked along the batch dimension so the model
    runs once; the loss is BCE-with-logits on (chosen - rejected) against a target of 1.
    """
    chosen_ids, chosen_msk, reject_ids, reject_msk = (
        t.to(cfg.DEVICE, non_blocking=True) for t in batch
    )
    input_ids = torch.cat((chosen_ids, reject_ids), dim=0)
    attention_mask = torch.cat((chosen_msk, reject_msk), dim=0)

    with autocast(device_type="cuda", dtype=torch.float16):
        rewards = rewardmodel(input_ids, attention_mask)
        chosen_rewards, reject_rewards = rewards.chunk(2, dim=0)
        logits = chosen_rewards - reject_rewards
        return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))


def _optimizer_step():
    """Unscale, clip, step and reset gradients — identical for full and partial windows."""
    scaler.unscale_(OPTIMIZER)
    torch.nn.utils.clip_grad_norm_(rewardmodel.parameters(), GRAD_CLIP_NORM)
    scaler.step(OPTIMIZER)
    scaler.update()
    OPTIMIZER.zero_grad(set_to_none=True)


def _evaluate():
    """Return the mean validation loss (NaN if the loader is empty); leaves the model in train mode."""
    total, batches = 0.0, 0
    rewardmodel.eval()
    with torch.no_grad():
        for valid_batch in valid_loader:
            total += _pairwise_loss(valid_batch).detach().float().item()
            batches += 1
    rewardmodel.train()
    return total / batches if batches else float("nan")


# Steps accumulated since the last checkpoint, carried across epochs.
steps_since_checkpoint = 0

for epoch in range(REWARD_EPOCHS):

    OPTIMIZER.zero_grad(set_to_none=True)
    train_loss = 0.0
    global_rollouts = 0   # batches processed in this epoch
    rewardmodel.train()

    for step, batch in enumerate(train_loader):

        # Scale so gradients summed over an accumulation window equal their mean.
        loss = _pairwise_loss(batch) / cfg.GRADIENT_ACCUM_STEPS

        scaler.scale(loss).backward()

        train_loss += loss.detach()
        global_rollouts += 1

        if (step + 1) % cfg.GRADIENT_ACCUM_STEPS == 0:

            _optimizer_step()

            runs.log({
                "training-reward-loss": train_loss.item(),
                "steps": step + 1
            })

            train_loss = 0.0

        # ---------------------- VALIDATION ----------------------

        if (step + 1) % cfg.EVAL_EVERY_STEP == 0:

            valid_loss = _evaluate()

            runs.log({
                "valid-reward-loss": valid_loss,
                "steps": step + 1
            })

    # ---------------- Handle Final Partial Accumulation ----------------
    # Uses the batch counter rather than `step`, which is undefined for an empty loader.
    # The same parameters and clip norm are used as for full windows, and the optimizer
    # is stepped exactly once (through the scaler).
    if global_rollouts % cfg.GRADIENT_ACCUM_STEPS != 0:
        _optimizer_step()

    # ---------------- Checkpoint ----------------
    # The per-epoch counter restarts every epoch, so testing it for an exact multiple of
    # DRIVE_STEP almost never fired. Save once at least DRIVE_STEP steps have accumulated
    # since the last save, and always after the final epoch so the trained model is kept.
    steps_since_checkpoint += global_rollouts
    is_last_epoch = epoch == REWARD_EPOCHS - 1

    if steps_since_checkpoint >= cfg.DRIVE_STEP or is_last_epoch:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': rewardmodel.state_dict(),
            'optimizer_state_dict': OPTIMIZER.state_dict(),
            'scaler_state_dict': scaler.state_dict()
        }

        os.makedirs(SAVE_DIR, exist_ok=True)

        filename = f"checkpoint_epoch_{epoch+1}.pth"
        torch.save(checkpoint, os.path.join(SAVE_DIR, filename))
        print(f"==================== EPOCH {epoch+1} CHECKPOINTER IS SAVED ====================")

        steps_since_checkpoint = 0


In [None]:
# Ad-hoc check: score one batch with the reward model.
# NOTE(review): `pc` and `pcm` are not defined in any cell visible here — presumably
# chosen ids / chosen mask from a batch; define them or remove this cell.
rewardmodel(pc, pcm)

In [None]:
# Ad-hoc check: score one batch with the reward model.
# NOTE(review): `pr` and `prm` are not defined in any cell visible here — presumably
# rejected ids / rejected mask from a batch; define them or remove this cell.
rewardmodel(pr, prm)