In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch
import torch.nn.functional as F

# Lesson seven: Calculating loss in the GRPO algorithm

This lecture covers details of the GRPO loss function, including policy loss ratio, clipping, KL divergence, and shows in practice how the process uses a 'reference model' - a base model - and then a 'policy model' that's the reference model plus trainable LoRA weights. The stuff in this lesson is for learning and understanding - in practice when doing RFT we won't write and use our own implementation of the loss function... instead we'll let other code, like Predibase code, handle it, since it's always the same, plugging in things that change, like our own reward function(s). This is covered in lesson eight.  

## Initialize the model and tokenizer 

In [12]:
model_str = 'babylm/babyllama-100m-2024'
base_model = AutoModelForCausalLM.from_pretrained(model_str)
tokenizer = AutoTokenizer.from_pretrained(model_str)

# pad on left so we can add new tokens on the right
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

In [13]:
print(base_model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(16000, 512, padding_idx=0)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=512, out_features=1024, bias=False)
          (up_proj): Linear(in_features=512, out_features=1024, bias=False)
          (down_proj): Linear(in_features=1024, out_features=512, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((512,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((512,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((512,), eps=1e-06)
    (rotary_emb)

In [4]:
prompt = 'The quick brown fox jumped over the '

input_ids = tokenizer(prompt, return_tensors='pt')
input_ids

{'input_ids': tensor([[1086, 1617, 2837, 6114, 4749,  551,  196,  154]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [5]:
with torch.no_grad():
    outputs = base_model.generate(
        **input_ids,
        max_new_tokens=2,
        pad_token_id=tokenizer.pad_token_id
    )

outputs

tensor([[1086, 1617, 2837, 6114, 4749,  551,  196,  154, 3407, 1952]])

In [6]:
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_text

'The quick brown fox jumped over the icy ground'

In [7]:
generated_portion = generated_text[len(prompt):]
generated_portion

'icy ground'

In [8]:
print(f"Generated text: {prompt}\033[94m{generated_portion}\033[0m")

Generated text: The quick brown fox jumped over the [94micy ground[0m


## Create reference and policy models

In [14]:
import copy
from peft import LoraConfig, get_peft_model

ref_model = copy.deepcopy(base_model)

# init LoRA
lora_config = LoraConfig(
    r=8, # rank of update matrices
    lora_alpha=32, # alpha scaling factor
    target_modules=['q_proj', 'v_proj'], # apply LoRA here
    lora_dropout=0.1,
    init_lora_weights=False,
    bias='none',
    task_type='CAUSAL_LM'
)

model = get_peft_model(base_model, lora_config)
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(16000, 512, padding_idx=0)
        (layers): ModuleList(
          (0-15): 16 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=512, out_features=512, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=512, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=512, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): Linea

## Calculate policy loss ratio

In [15]:
def prepare_inputs(prompt, completion):
    prompt_tokens = tokenizer(prompt, return_tensors='pt')
    completion_tokens = tokenizer(completion, return_tensors='pt')

    # combined input
    input_ids = torch.cat(
        [
            prompt_tokens['input_ids'],
            completion_tokens['input_ids'],
        ],
        dim=1
    )
    attention_mask = torch.cat(
        [
            prompt_tokens['attention_mask'],
            completion_tokens['attention_mask'],
        ],
        dim=1
    )

    prompt_length = prompt_tokens['input_ids'].shape[1]
    completion_length = completion_tokens['input_ids'].shape[1]
    total_length = prompt_length + completion_length

    # create mask to identify tokens generated by the model
    completion_mask = torch.zeros(total_length, dtype=torch.float32)
    completion_mask[prompt_length:] = 1.0

    return input_ids, attention_mask, completion_mask

In [16]:
def compute_log_probs(model, input_ids, attention_mask):
    outputs = model(input_ids, attention_mask=attention_mask)

    # compute the log prob for each token in the sequence
    # output.logits is the logits for all tokens in the vocab, for each position in the sequence
    log_probs = F.log_softmax(outputs.logits, dim=-1)

    # and extract the log prob for the actual token generated at each position in the sequence
    return log_probs.gather(
        dim=-1,
        index=input_ids.unsqueeze(-1)
    ).squeeze(-1)

And now do the GRPO loss using the prepared inputs and log probs.

In [17]:
def grpo_loss(model, ref_model, prompt, completion, advantage):
    input_ids, attention_mask, completion_mask = prepare_inputs(prompt, completion)

    # model forward
    token_log_probs = compute_log_probs(model, input_ids, attention_mask)
    with torch.no_grad():
        ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask)

    # ratio = p_model / p_ref = exp(log(p_model) - log(p_ref))
    ratio = torch.exp(token_log_probs - ref_token_log_probs)

    # scale ratio by advantage func
    policy_loss = ratio * advantage

    # since we want to maximize reward, make the loss negative since optimizers minimize loss
    per_token_loss = -policy_loss

    # and only compute loss over the output tokens
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    return loss

And call w/ a sample completion and a hard-coded advantage value (which in practice we'll calc w/ reward functions, as in prev lectures).

In [20]:
grpo_loss(model, ref_model, prompt, 'fence and', advantage=2.0)

tensor(-6.3921, grad_fn=<DivBackward0>)

The rest of the lecture shows the math to add clipping and KL divergence. I'm not going to reproduce that code here since it's truly an implementation detail and I'm not digging in to understand the math at this level right now.