### GRPO

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

# Use bfloat16 only when a CUDA device is present AND it reports bf16 support.
# Checking is_available() first short-circuits the bf16 query on CPU-only machines.
bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()


from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# is_available must be *called*: the bare function object is always truthy, which
# previously forced device='cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model(model_name):
  """Load ``model_name`` as a causal LM placed on ``device``, plus its tokenizer.

  Weights are requested as bfloat16 when the global ``bf16`` flag is set;
  otherwise ``torch_dtype="auto"`` is passed through. The tokenizer's pad token
  is aliased to its EOS token, so padding and end-of-sequence share one id.

  Returns:
    ``(model, tokenizer)``
  """
  weight_dtype = torch.bfloat16 if bf16 else "auto"
  model = AutoModelForCausalLM.from_pretrained(
      model_name, device_map=device, torch_dtype=weight_dtype
  )
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  tokenizer.pad_token = tokenizer.eos_token  # EOS doubles as the padding token
  return model, tokenizer

# Policy model to be trained; a second copy (the reference model) is loaded later
# from the same checkpoint name.
model_name = 'Qwen/Qwen2.5-1.5B-Instruct'
model, tokenizer = load_model(model_name)
# Redundant: load_model already aliases pad_token to eos_token. Kept as-is.
tokenizer.pad_token = tokenizer.eos_token





In [None]:
!git clone https://github.com/open-thought/tiny-grpo.git

In [None]:
from dataclasses import dataclass

@dataclass
class Config:
  """Hyper-parameters shared by the rollout and GRPO update cells."""

  batch_size: int = 8        # questions per dataloader batch
  group_size: int = 8        # completions sampled per question (one GRPO group)
  epsilon: float = 1e-6      # added to the reward std when normalising advantages
  exp_epoch: int = 1         # update epochs per batch of collected experience
  top_p: float = 1.0         # nucleus-sampling threshold for generation
  temperature: float = 0.7   # sampling temperature for generation
  max_length: int = 512      # GenerationConfig max_length (prompt + completion)
  do_sample: bool = True     # sampling switch for generation
  mini_batch: int = 4        # experiences drawn per update epoch

config = Config()


In [None]:
import re
import json
from pathlib import Path
from typing import Iterator, Optional, Any
from collections.abc import Callable
from torch.utils.data import DataLoader



def reward_fn(output: list, ground_truth_ans: Any) -> torch.Tensor:
  """Score each decoded completion by the content of its ``<answer>...</answer>`` tag.

  Rewards:
    1.0  -- tag content (whitespace-stripped) equals the ground truth exactly
    0.5  -- the ground truth appears somewhere inside the tag
    0.01 -- a tag is present but the answer is wrong
    0    -- no answer tag at all

  Args:
    output: decoded completion strings.
    ground_truth_ans: expected answer; compared as a string, so ints work too.

  Returns:
    Float tensor of shape ``(len(output), 1)``, one row per completion.
  """
  target = str(ground_truth_ans).strip()
  # Size by the actual number of completions rather than config.group_size, so a
  # different-sized input neither overflows nor leaves phantom zero rows.
  returns = torch.zeros(len(output), 1, dtype=torch.float)

  for i, completion in enumerate(output):
    # DOTALL lets the answer span newlines; the lazy group takes the first tag.
    answer_match = re.search(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    if answer_match is None:
      continue  # no answer tag -> reward stays 0

    # strip() (not strip(' ')) so "<answer>\n42\n</answer>" still counts as exact.
    answer = answer_match.group(1).strip()
    if answer == target:
      returns[i] = 1.0
    elif target in answer:
      returns[i] = 0.5
    else:
      returns[i] = 0.01

  return returns



In [None]:
def get_logprobs(logits, target):
  """Log-probability the model assigns to each next token of ``target``.

  Args:
    logits: model logits, indexed as ``[batch, position, vocab]``.
    target: the token ids that produced ``logits``, indexed as ``[batch, position]``.

  Returns:
    ``[batch, position - 1]`` tensor whose entry ``t`` is
    ``log p(target[:, t + 1] | target[:, :t + 1])``. The batch dim is kept even
    when batch == 1.
  """
  logits = logits[:, :-1, :]
  next_tokens = target[:, 1:].unsqueeze(-1)
  # Raw logits are not log-probs: normalise with log_softmax(x)[y] == x[y] - logsumexp(x).
  # Gathering first avoids keeping a full vocab-sized log-prob tensor around.
  token_logits = torch.gather(logits, dim=-1, index=next_tokens).squeeze(-1)
  return token_logits - torch.logsumexp(logits, dim=-1)


In [None]:
import json
from pathlib import Path
from typing import Iterator, Optional, Any
from collections.abc import Callable
from torch.utils.data import DataLoader


def read_jsonl(path : str | Path) -> Iterator:
  with open(path, 'r') as f:
    # data = [json.loads(line) for line in f]
    for line in f:
      yield(json.loads(line))

def load_prompt(path: str | Path, check_fn: Optional[Callable[[Any], bool]] = None) -> list[Any]:
  """Return every row of the JSONL file at ``path`` that passes ``check_fn``.

  Args:
    path: JSON Lines file, read through ``read_jsonl``.
    check_fn: predicate applied to each parsed row. ``None`` keeps every row;
      previously ``None`` matched the annotation but crashed when called.

  Returns:
    List of the kept rows, in file order. The old ``-> str`` annotation was wrong.
  """
  return [row for row in read_jsonl(path) if check_fn is None or check_fn(row)]

# Keep only tasks whose question is under 128 characters and whose `num_terms`
# and `num_digits` fields are both <= 3.
# NOTE(review): absolute path assumes the repo lives in /home/ubuntu, while the
# `git clone` cell clones into the notebook's working directory. Confirm they match.
loaded_prompt = load_prompt(
    '/home/ubuntu/tiny-grpo/data/math_tasks.jsonl',
    lambda x: len(x['question']) < 128
    and x['num_terms'] <=3
    and x['num_digits'] <=3)

print('total prompt size', len(loaded_prompt))

# NOTE: rollout() defines its own identical local copy of this prompt. This
# module-level one does not appear to be read by the code in this section.
system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>."""


# Shuffled batches of `config.batch_size` task dicts; drop_last discards a final
# partial batch. The training loop reads batch['question'] / batch['answer'].
dataloader = DataLoader(loaded_prompt, batch_size=config.batch_size, pin_memory=False, drop_last=True, shuffle=True)

In [None]:
import torch

@torch.no_grad()
def rollout(model, q, a):
  """Generate ``config.group_size`` completions for one question and score them.

  Args:
    model: causal LM to sample from; it is put in eval mode here.
    q: question text, sent as the user message.
    a: ground-truth answer, forwarded to ``reward_fn``.

  Returns:
    output: prompt + generated token ids, one row per completion (prompt repeated).
    rewards: result of ``reward_fn`` on the decoded completions.
    action_mask: bool tensor shaped like ``output``. It is True only on generated
      positions whose token is not the pad id.
  """
  model.eval()
  system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>."""

  template = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": q},
  ]

  # `add_generation_prompt` is the documented flag that appends the assistant-turn
  # header, so the model replies as the assistant. The previous keyword
  # (`add_generation_token`) is not an apply_chat_template parameter, so no header
  # was ever added.
  prompt_template = tokenizer.apply_chat_template(template, add_generation_prompt=True, tokenize=False)
  tokens = tokenizer(prompt_template, padding=True, padding_side='left', return_tensors='pt').to(device)
  prompt_len = tokens['input_ids'].shape[1]

  # One copy of the prompt per group member.
  group_inputs = {
    'input_ids': tokens['input_ids'].repeat(config.group_size, 1),
    'attention_mask': tokens['attention_mask'].repeat(config.group_size, 1),
  }
  generation_config = GenerationConfig(
        do_sample=config.do_sample,  # honour the config flag instead of hard-coding True
        top_p=config.top_p,
        temperature=config.temperature,
        max_length=config.max_length,
        pad_token_id=tokenizer.pad_token_id,
    )
  output = model.generate(**group_inputs, generation_config=generation_config)

  # Score only the generated part. The pad id is a special token, so
  # skip_special_tokens drops trailing padding.
  response_tokens = output[:, prompt_len:]
  decoded_output = tokenizer.batch_decode(response_tokens, skip_special_tokens=True)
  rewards = reward_fn(decoded_output, ground_truth_ans=a)

  # pad_token is aliased to eos_token in this notebook, so the EOS that ends each
  # completion is masked out together with the padding.
  action_mask = output != tokenizer.pad_token_id
  action_mask[:, :prompt_len] = False

  return output, rewards, action_mask



In [None]:
# Second copy of the same checkpoint. It is only run under torch.no_grad() to get
# reference log-probs for the KL penalty and is never handed to the optimizer.
reference_model, _ = load_model(model_name)
print(f"Memory reserved before clearing: {torch.cuda.memory_reserved()/1e9}")

In [None]:
# make experience.
import torch
from dataclasses import dataclass

@dataclass
class Experience:
  """Tensors collected for one rollout group, kept together for device moves.

  Annotations use the ``torch.Tensor`` type. The previous ``torch.tensor`` is the
  tensor *factory function*, not a type. Shapes are not enforced here.
  """
  logprobs: torch.Tensor       # policy log-probs at rollout time
  logprobs_ref : torch.Tensor  # reference-model log-probs
  advantages: torch.Tensor
  action_mask : torch.Tensor   # True where a token should contribute to the loss
  output : torch.Tensor        # prompt + generated token ids
  rewards : torch.Tensor

  def to(self, device: str) -> None:
    """Move every tensor field to ``device`` in place; other fields are left alone."""
    for key, val in vars(self).items():
      if isinstance(val, torch.Tensor):
        setattr(self, key, val.to(device=device))


In [None]:
# Precision used for autocast in the update step: bfloat16 when supported, else float32.
if bf16:
  dtype = torch.bfloat16
else:
  dtype = torch.float32
dtype

In [None]:
import gc
from contextlib import nullcontext  # used below but was never imported

pad_token_id = tokenizer.pad_token_id
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

# Mixed-precision context for forward passes: a no-op on CPU, autocast to `dtype` otherwise.
ctx = (nullcontext() if device=='cpu' else torch.amp.autocast(device_type=device, dtype=dtype))

print(f"Memory reserved before clearing: {torch.cuda.memory_reserved()/1e9}")

In [None]:
from dataclasses import dataclass, fields
from typing import Optional

import torch
import torch.nn.functional as F

def zero_pad_sequences(
    sequences: list[torch.Tensor], side: str = "left"
) -> torch.Tensor:
    """Zero-pad tensors to a common last-dimension length and stack them.

    Args:
        sequences: tensors that agree on every dimension except the last.
        side: ``"left"`` pads at the start of the last dimension, ``"right"`` at the end.

    Returns:
        ``torch.stack`` of the padded tensors, with a new leading dimension of
        size ``len(sequences)``.
    """
    assert side in ("left", "right")
    # F.pad's (left, right) pair acts on the LAST dimension, so the target length
    # must be measured there too. Measuring size(0) only coincides for 1-D inputs.
    max_len = max(seq.size(-1) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(-1)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)

In [None]:
def join_experience_batch(items: list[Experience]) -> Experience:
    """Collate several ``Experience`` objects into one batched ``Experience``.

    Each tensor field goes through ``zero_pad_sequences(..., "left")``, which
    stacks the items. Items built by the training loop hold 2-D per-group
    tensors, which stack to 3-D. Those are flattened to
    ``(len(items) * group, ...)`` so the result can be fed straight to
    ``model.forward``. A field that is ``None`` on any item becomes ``None``.

    NOTE(review): padding uses 0, not ``pad_token_id``, so padded ``output``
    positions are not excluded by an ``output != pad_token_id`` attention mask.
    Confirm before using this collate for training.
    """
    batch_data = {}
    keys = (
        "logprobs",
        "logprobs_ref",
        "output",
        "advantages",
        "action_mask",
        # Experience requires `rewards`; omitting it made the constructor raise TypeError.
        "rewards",
    )
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if all(v is not None for v in vals):
            data = zero_pad_sequences(vals, "left")
            if data.dim() > 2:
                data = data.flatten(0, 1)  # merge item and group dims into one batch dim
        else:
            data = None
        batch_data[key] = data
    return Experience(**batch_data)

In [None]:
def get_data(data, mini_bsz):
    """Draw up to ``mini_bsz`` distinct items from ``data`` uniformly at random.

    Sampling is without replacement, matching the shuffled DataLoader this
    replaced. A single call never repeats an experience. If ``data`` has fewer
    than ``mini_bsz`` items, all of them are returned in random order, and empty
    ``data`` yields ``[]``.
    """
    picks = torch.randperm(len(data))[:mini_bsz]
    return [data[i] for i in picks.tolist()]

def approx_kl_divergence(
    log_probs: torch.Tensor,
    log_probs_ref: torch.Tensor,
    action_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Per-token k3 estimate of KL(policy || reference).

    With ``d = log_probs_ref - log_probs`` (both cast to float32), returns
    ``exp(d) - d - 1`` elementwise. Where ``action_mask`` is falsy, ``d`` is
    zeroed, which makes the estimate exactly 0 there.
    See http://joschu.net/blog/kl-approx.html.
    """
    diff = torch.sub(log_probs_ref.float(), log_probs.float())
    diff = diff if action_mask is None else diff * action_mask
    return torch.exp(diff) - diff - 1

def grpo_loss(logprobs, exp, clip_eps=0.2, kl_coeff=0.01):
  """Clipped policy-gradient loss plus a KL penalty towards the reference model.

  Args:
    logprobs: current-policy log-probs of the sampled tokens, same layout as ``exp.logprobs``.
    exp: ``Experience`` providing rollout-time ``logprobs``, ``logprobs_ref``,
      ``advantages`` and ``action_mask``.
    clip_eps: clip range for the probability ratio.
    kl_coeff: weight of the KL term.

  Returns:
    Per-sequence loss: the ``action_mask``-weighted mean over the last (token) dim.
  """
  ratio = (logprobs - exp.logprobs).exp()
  surr1 = ratio * exp.advantages
  surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * exp.advantages

  kl = approx_kl_divergence(logprobs, exp.logprobs_ref, exp.action_mask)
  loss = -torch.min(surr1, surr2) + kl_coeff * kl

  mask = exp.action_mask
  # clamp(min=1): a completion made only of pad/EOS tokens has no action tokens.
  # Without the clamp its 0/0 becomes NaN and poisons the batch mean.
  return (loss * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)





In [None]:
from torch.nn.utils import clip_grad_norm_

# Gradient-norm clipping threshold; clip_grad_norm_ and max_norm were used below
# but never defined in the notebook.
max_norm = 1.0

cnt = 0
loss_list = []     # per-step loss as Python floats
rewards_list = []  # per-step mean reward as Python floats

for batch in dataloader:
  # ---- phase 1: collect experience (no gradients) ----
  exp_list = []
  print(f" {cnt} Before starting one batch logits: {torch.cuda.memory_reserved()/1e9}")
  for q,a in zip(batch['question'], batch['answer']):
    print(f" {cnt} start questions {torch.cuda.memory_reserved()/1e9}")
    print(q,a)
    output, rewards, action_mask = rollout(model, q,a)
    # Group-relative advantage: standardise rewards within this question's group.
    advantages = (rewards - rewards.mean(dim=0))/ (rewards.std(dim=0) + config.epsilon)

    attention_mask = output != tokenizer.pad_token_id
    # Compute old/reference log-probs under the same `ctx` as the update step, so
    # the ratio starts at ~1 for on-policy data instead of reflecting a precision mismatch.
    with torch.no_grad(), ctx:
      logits = model.forward(output,attention_mask=attention_mask).logits
      logprobs = get_logprobs(logits, output)
    del logits
    gc.collect()
    torch.cuda.empty_cache()

    with torch.no_grad(), ctx:
      logits_ref = reference_model.forward(output, attention_mask=attention_mask).logits
      logprobs_ref = get_logprobs(logits_ref, output)
    del logits_ref
    gc.collect()
    torch.cuda.empty_cache()

    experience = Experience(
        logprobs=logprobs,
        logprobs_ref=logprobs_ref,
        advantages=advantages,
        action_mask=action_mask[:, 1:],  # align with next-token log-probs
        output=output,
        rewards=rewards
    )
    experience.to('cpu')
    exp_list.append(experience)

  print(f" {cnt} After moving experience to cpu: {torch.cuda.memory_reserved()/1e9}")

  # ---- phase 2: policy update ----
  # rollout() switches the model to eval mode; switch back before taking gradient steps.
  model.train()
  for e in range(config.exp_epoch):
    exp_dataloader = get_data(exp_list, config.mini_batch)
    for exp in exp_dataloader:
      exp.to(device)
      attention_mask = exp.output != pad_token_id
      with ctx:
        logits = model.forward(exp.output, attention_mask=attention_mask).logits
        logprobs = get_logprobs(logits, exp.output)
        loss = grpo_loss(logprobs, exp).mean()  # mean across the group's sequences

      optimizer.zero_grad()
      loss.backward()
      clip_grad_norm_(model.parameters(), max_norm=max_norm)
      optimizer.step()

      # Store plain floats. Appending the tensors kept GPU memory and the loss's
      # grad_fn chain alive for the whole run.
      loss_value = loss.item()
      mean_reward = exp.rewards.mean().item()
      exp.to('cpu')  # release this experience's GPU copies until it is sampled again

      del logits, logprobs, loss
      gc.collect()
      torch.cuda.empty_cache()

      print(f" {cnt} After doing step: {torch.cuda.memory_reserved()/1e9}")
      print(f'LOSS {cnt}', loss_value)
      print(f'AVG REWARDS: {mean_reward}')
      loss_list.append(loss_value)
      rewards_list.append(mean_reward)
      cnt+=1

