# Transformers & Reinforcement Learning


In this lab, you'll learn about **preference alignment** and tuning transformer models with **reinforcement learning** (RL). Specifically, we'll cover:

1. **Supervised Fine-Tuning (SFT)**
2. **Preference Optimization**
    - [Direct Preference Optimization (DPO)](https://arxiv.org/abs/2305.18290)
3. **RL Model Tuning**
    - with a simplified [GRPO](https://arxiv.org/abs/2501.12948)


## Overview of Model Preparation

When creating a chat agent, we usually start with a randomly initialized transformer model and train it on trillions of tokens using the language modeling objective.
To be more precise:
+ We fix the architecture of our chatbot (currently one of the most popular ones is [Llama-like transformer decoder](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)).  
  The main points of this architecture are as follows:
    + **Embedding Layer** (the vocabulary of recent models often exceeds 100K tokens)
    + a stack of **Transformer Blocks**, each consisting of
        - **RMSNorm** (replaced LayerNorm due to comparable stability and faster computation)
        - [Grouped Query Attention (GQA)](https://arxiv.org/abs/2305.13245)
        - **Residual Connection**
        - **RMSNorm**
        - **SwiGLU** $V((\mathrm{sigmoid}(Gx) \cdot Gx)\cdot Wx)$
        - **Residual Connection**
    + **Final Norm**
    + **Language Modeling Head**
+ We prepare the **pre-training dataset** (this step can be done in parallel with the previous one).  
  We mainly focus on getting a lot of high-quality data that covers all topics that we would like our model to know.  
  In particular, if we want our model to be multilingual, we should introduce the multilingual data in the pre-training dataset.  
  We also make sure that the desired knowledge and abilities are presented in various forms to enhance model recall abilities (see [Physics of Language Models: Part 3.1, Knowledge Storage and Extraction](https://arxiv.org/abs/2309.14316)).  
  It is also worth noting that **the way we show the text pieces to the model can significantly impact its performance**
    - [In-context Pretraining: Language Modeling Beyond Document Boundaries](https://arxiv.org/abs/2310.10638)
    - [Analysing The Impact of Sequence Composition on Language Model Pre-Training](https://arxiv.org/abs/2402.13991).
+ We train our model using the **language modeling loss** on the prepared data. The pre-training stage is usually long and **can utilize over $15$ trillion tokens** (see the [Meta blog post on Llama 3.1](https://ai.meta.com/blog/meta-llama-3-1/)).
+ After pre-training, our model can have trouble following instructions.  
  For example, it can answer the question and then start writing a Wikipedia article.  
  To adjust the model one usually uses **Supervised Fine-Tuning (SFT)**.  
+ After SFT, we have a model that should be able to follow instructions, but the style of its answers may not match human preferences.  
  To adjust the model here, one can train reward models that reflect human preferences and utilize a reinforcement learning algorithm to tune the model (usually done in several stages).  
  The other approach is to use simpler methods like
    - [Direct Preference Optimization (DPO)](https://arxiv.org/abs/2305.18290)
    - [Odds Ratio Preference Optimization (ORPO)](https://arxiv.org/abs/2403.07691)
+ Enabling models to use additional inference-time compute to improve their answers has recently gained a lot of attention.  
  Teaching the model to "think" can be achieved in several ways; it is currently interleaved into the SFT stage or, in the purest variant, done right after pre-training.  
  However, the latter approach can result in a model whose thoughts are hard for humans to understand and mix several languages (see [DeepSeek-R1-Zero](https://arxiv.org/abs/2501.12948)).  
  The whole procedure can be achieved using reinforcement learning algorithms such as [GRPO](https://arxiv.org/abs/2501.12948) or via supervised fine-tuning on reasoning traces (see [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393)).
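
To make the feed-forward half of the transformer block listed above more concrete, here is a minimal NumPy sketch of RMSNorm followed by SwiGLU and a residual connection. The matrix names `G`, `W`, `V` follow the formula $V((\mathrm{sigmoid}(Gx) \cdot Gx)\cdot Wx)$; the helper names, the toy sizes, and the `eps` value are assumptions made for illustration and are not taken from any particular model implementation.

```python
import numpy as np


def rms_norm(x: np.ndarray, gain: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Scale each row vector to (approximately) unit root-mean-square, then apply a gain."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gain


def swiglu(x: np.ndarray, G: np.ndarray, W: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Compute V((sigmoid(Gx) * Gx) * Wx) for a batch of row vectors x."""
    gate = x @ G.T                        # Gx, shape (tokens, d_ff)
    silu = gate / (1.0 + np.exp(-gate))   # sigmoid(Gx) * Gx
    return (silu * (x @ W.T)) @ V.T       # back to shape (tokens, d_model)


rng = np.random.default_rng(0)
d_model, d_ff, tokens = 8, 16, 4          # toy sizes, chosen only for illustration
x = rng.normal(size=(tokens, d_model))
G = rng.normal(size=(d_ff, d_model)) * 0.1
W = rng.normal(size=(d_ff, d_model)) * 0.1
V = rng.normal(size=(d_model, d_ff)) * 0.1
gain = np.ones(d_model)

# Second half of a transformer block: RMSNorm -> SwiGLU -> residual connection.
y = x + swiglu(rms_norm(x, gain), G, W, V)
print(y.shape)  # (4, 8)
```

The attention half of the block follows the same pre-norm pattern, with GQA in place of SwiGLU.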


There are also other approaches. For example, instead of training the model from scratch, we can try to prune and distill a larger one into a smaller one (see [LLM Pruning and Distillation in Practice: The Minitron Approach](https://arxiv.org/abs/2408.11796)).



## Supervised Fine Tuning for Instruction Following
The idea behind SFT is simple: we prepare data consisting of **(instruction, response) pairs**,  
where the instruction is usually a question asked by a human and the response is the desired model response.
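
As a small illustration, the sketch below turns one (instruction, response) pair into a list of chat messages and renders it in the ChatML layout used later in this lab. The functions `pair_to_messages` and `render_chatml` are hypothetical helpers written for this example; in practice the tokenizer's `apply_chat_template` performs the rendering.

```python
from typing import Dict, List


def pair_to_messages(instruction: str, response: str) -> List[Dict[str, str]]:
    """Wrap an (instruction, response) pair as a two-turn conversation."""
    return [
        {"role": "user", "content": instruction},
        {"role": "assistant", "content": response},
    ]


def render_chatml(messages: List[Dict[str, str]]) -> str:
    """Hand-written ChatML rendering, for illustration only."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


messages = pair_to_messages("Hey!", "Hey! How can I help you?")
print(render_chatml(messages))
```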

Instruction-tuning data can be generated in several ways:
- manually - humans provide questions and answer them
- automatically - given a large enough pre-trained model, we can use a handful of few-shot examples to generate more training samples and then heuristically filter them (see [Self-Instruct: Aligning Language Models with Self-Generated Instructions](https://arxiv.org/abs/2212.10560)).  
 Of course, we can make this process iterative: as the model gets better at following instructions, we can use the new version to generate more instruction-tuning data.  
 In fact, the iterative approach resembles the [Expert Iteration](https://arxiv.org/pdf/1705.08439) algorithm from RL.

Below is Hugging Face (HF) code that performs instruction tuning on a small base (non-instruction-tuned) model.

In [1]:
# NOTE we do not need gcfs
!pip3 install transformers==4.49 torch datasets trl==0.16.1 accelerate sentencepiece protobuf



In [2]:
from trl import SFTConfig, SFTTrainer, setup_chat_format
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from datasets import load_dataset
import torch
import json
import numpy as np
import string
from typing import Callable, List, Dict
# old method for disabling wandb
import os
os.environ["WANDB_DISABLED"] = "true"


In [3]:
# We download the data
data = load_dataset("HuggingFaceTB/everyday-conversations-llama3.1-2k")

sample = json.dumps(data["train_sft"][0], indent=2)
print(sample)

{
  "topic": "Shopping",
  "subtopic": "Budgeting",
  "subsubtopic": "Tracking expenses",
  "full_topic": "Shopping/Budgeting/Tracking expenses",
  "prompt": "Generate a very simple multi-turn conversation between a User and an AI Assistant about Shopping/Budgeting/Tracking expenses. The conversation should start with a basic greeting like \"Hello\" or \"Hi\" and be straightforward. Include 3-4 short exchanges. The AI should give brief, clear answers. The User should ask simple questions.\n\nStart the conversation like this:\n\nUser: [Greeting]\n\nAI: Hello! How can I help you today?\n\nUser: [Continue with a simple question or statement]\n\nAI: [Respond briefly and clearly]\n\nUser: [Ask a follow-up question or make another simple statement]\n\nAI: [Provide a final helpful response]\n\nMake sure the entire conversation remains very simple and easy to understand, focusing on basic topics or requests.",
  "completion": "User: Hi\n\nAI: Hello! How can I help you today?\n\nUser: I'm tryin

In [4]:
# We download a "smol" model

DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"

model_name = (
    "BEE-spoke-data/smol_llama-101M-GQA"  # note that this model is really "smol"
)
tokenizer_name = model_name


model = AutoModelForCausalLM.from_pretrained(model_name, device_map=DEVICE_NAME)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

# we let Hugging Face set the chat format for us

# NOTE setup_chat_format
# "Adds special tokens to the tokenizer, e.g. <|im_start|> and <|im_end|>, to indicate the start and end of a conversation."
# "Resizes the model’s embedding layer to accommodate the new tokens."
# "Sets the chat_template of the tokenizer, which is used to format the input data into a chat-like format. The default is chatml from OpenAI."
# For more see https://huggingface.co/docs/trl/en/sft_trainer#add-special-tokens-for-chat-format
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

example_chat = [
    {"role": "user", "content": "Hey!"},
    {"role": "assistant", "content": "Hey! How can I help you?"},
    {"role": "user", "content": "How to install gentoo?"},
]

formatted_text = tokenizer.apply_chat_template(example_chat, tokenize=False)
print(formatted_text)

<|im_start|>user
Hey!<|im_end|>
<|im_start|>assistant
Hey! How can I help you?<|im_end|>
<|im_start|>user
How to install gentoo?<|im_end|>



In [5]:
def test_model_on_simple_example(example: List[Dict[str, str]]):
    tokens = tokenizer.apply_chat_template(
        example, return_tensors="pt", add_generation_prompt=True
    )

    attention_mask = torch.ones_like(tokens).to(model.device).bool()
    output = model.generate(
        tokens.to(model.device), attention_mask=attention_mask, max_new_tokens=256, do_sample=False
    )

    print(tokenizer.decode(output[0]))


test_model_on_simple_example(example_chat)

<|im_start|> user
Hey! <|im_end|> 
 <|im_start|> assistant
Hey! How can I help you? <|im_end|> 
 <|im_start|> user
How to install gentoo? <|im_end|> 
 <|im_start|> assistant
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?!("I am not a gentoo"):
Hey! How can I help you?


In [6]:
## adapted from https://github.com/huggingface/smol-course/blob/main/1_instruction_tuning/notebooks/sft_finetuning_example.ipynb
# NOTE You can read more about the SFTTrainer here https://huggingface.co/docs/trl/en/sft_trainer


### TODO - Experiment with fine-tuning parameters, test the importance of learning rate and weight decay

## see https://huggingface.co/docs/trl/en/sft_trainer#trl.SFTConfig
sft_config = SFTConfig(
    output_dir="./sft_part",
    max_steps=300,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # in general we would like to have gradient_accumulation_steps*per_device_train_batch_size >= 32
    max_seq_length=2048,
    learning_rate=5e-5,  # make it lower if the model breaks
    logging_steps=10,
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,  # eval after each 100 steps
    use_mps_device=False,
    hub_model_id=None,
    packing=False,  # False - do not add many examples in the same context
    weight_decay=0.0,  # Used in continued-training and pre-training, here we do not want to lose general model knowledge
    gradient_checkpointing=True,  # can reduce VRAM requirements
    bf16=True,  # remember that we want mixed precision training (that is, we want to aggregate gradients in higher precision)
    report_to="none" # disable wandb
)

# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=data["train_sft"],
    processing_class=tokenizer,
    eval_dataset=data["test_sft"],
)
# !!!NOTE IF THE CODE ABOVE FAILS WITH no field named "text" exception then click Runtime->Restart Session and run all!!

In [7]:
trainer.train()

  ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# NOTE
# Our model is very "smol" and can have trouble producing longer coherent text
test_model_on_simple_example(example_chat)

<|im_start|> user
Hey! <|im_end|> 
 <|im_start|> assistant
Hey! How can I help you? <|im_end|> 
 <|im_start|> user
How to install gentoo? <|im_end|> 
 <|im_start|> assistant
To install gentoo, you'll need a web browser or a web browser-specific software. You can download it from the official website or by contacting the website's authorized user or authorized browser support. <|im_end|>


## Preference Optimization

One may also want the model to follow a specific style,  
for example, to provide helpful, non-toxic, and exhaustive responses.  
To achieve this, we can train the model to assign higher probability to desirable responses and lower probability to undesirable ones.  
In this lab, we will use DPO.  


Some people believe that preference optimization can result in lowered model performance.  
That is, a model subjected to preference optimization can perform worse on benchmarks and simply refuse to complete some tasks.  
What do you think?

### DPO
Suppose that we have a reward model $r$ and a language model $\pi_{\mathrm{base}}$, and we want to create $\pi_r$ that is not too far from $\pi_{\mathrm{base}}$ but also generates responses with high rewards.  
One way to achieve this is to minimize the following loss:
$$
    -\left[\mathbb{E}_{x \sim D, y \sim \pi_r(\cdot|x)}[r(y, x)] - \beta \mathbb{D}_{\mathrm{KL}}(\pi_r(y|x) || \pi_{\mathrm{base}}(y|x))\right]
$$

The DPO authors show that the closed-form solution is
$$
\pi_r(y|x) = \frac{\pi_\mathrm{base}(y|x)  e^{\frac{r(y, x)}{\beta}}}{\sum_{z}{\pi_\mathrm{base}(z|x)  e^{\frac{r(z, x)}{\beta}}}}
$$
As they note, the sum in the denominator is intractable. However, consider the Bradley-Terry preference model:
$$
p(y_1 > y_2 | x) = \sigma(r(y_1, x) - r(y_2, x)) = \frac{1}{1+e^{-r(y_1, x) + r(y_2, x)}} =
\frac{e^{r(y_1, x)}}{e^{r(y_1, x)}+e^{r(y_2, x)}}
$$
and take the logarithm of both sides of the closed-form solution
$$
\pi_r(y|x) = \frac{\pi_\mathrm{base}(y|x)  e^{\frac{r(y, x)}{\beta}}}{\sum_{z}{\pi_\mathrm{base}(z|x)  e^{\frac{r(z, x)}{\beta}}}}
$$
which gives
$$
\log(\pi_r(y|x)) = \log(\pi_\mathrm{base}(y|x)) + \frac{r(y, x)}{\beta} - \log\left(\sum_{z}{\pi_\mathrm{base}(z|x)  e^{\frac{r(z, x)}{\beta}}}\right)
$$
Solving for $r(y, x)$, we obtain
$$
r(y, x) = \beta \log\left(\frac{\pi_r(y|x)}{\pi_\mathrm{base}(y|x)}\right) + \beta \log\left(\sum_{z}{\pi_\mathrm{base}(z|x)  e^{\frac{r(z, x)}{\beta}}}\right)
$$
and, denoting the sum by $Z(x)$,
$$
r(y, x) = \beta \log\left(\frac{\pi_r(y|x)}{\pi_\mathrm{base}(y|x)}\right) + \beta \log\left(Z(x)\right)
$$
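
The relation between $r$, $\pi_r$, $\pi_\mathrm{base}$ and $Z(x)$ can be checked numerically on a toy problem. The sketch below assumes a single prompt with only three possible responses; the probabilities, rewards, and $\beta$ are made-up values chosen for illustration.

```python
import math

beta = 0.5
pi_base = {"a": 0.5, "b": 0.3, "c": 0.2}  # toy base policy over three responses
reward = {"a": 1.0, "b": 0.0, "c": 2.0}   # toy rewards r(y, x)

# Partition function Z(x) = sum_z pi_base(z|x) * exp(r(z, x) / beta)
Z = sum(pi_base[y] * math.exp(reward[y] / beta) for y in pi_base)

# Closed-form optimum pi_r(y|x)
pi_r = {y: pi_base[y] * math.exp(reward[y] / beta) / Z for y in pi_base}

# Recover the reward from the two policies and Z(x)
recovered = {
    y: beta * math.log(pi_r[y] / pi_base[y]) + beta * math.log(Z) for y in pi_base
}

for y in pi_base:
    print(y, round(pi_r[y], 4), round(recovered[y], 4))
```

With only three responses the sum over $z$ is trivial to compute; for a language model it ranges over all possible sequences, which is why it is intractable.
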
Substituting this expression for $r$ into the Bradley-Terry model, we have
$$
p(y_1 > y_2 | x) = \sigma\left(\beta \log\left(\frac{\pi_r(y_1|x)}{\pi_\mathrm{base}(y_1|x)}\right) + \beta \log\left(Z(x)\right) - \beta \log\left(\frac{\pi_r(y_2|x)}{\pi_\mathrm{base}(y_2|x)}\right) - \beta \log\left(Z(x)\right) \right)
$$
which, after the $\beta \log\left(Z(x)\right)$ terms cancel, equals
$$
p(y_1 > y_2 | x) = \sigma\left(\beta \log\left(\frac{\pi_r(y_1|x)}{\pi_\mathrm{base}(y_1|x)}\right) - \beta \log\left(\frac{\pi_r(y_2|x)}{\pi_\mathrm{base}(y_2|x)}\right) \right)
$$

We can therefore optimize the following loss over the parameters of $\pi_r$, where $y_1$ is the preferred response and $y_2$ the rejected one:
$$
- \mathbb{E}_{x, y_1, y_2 \sim D}\left[\log\left(p(y_1 > y_2 | x)\right)\right]
$$
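
For a single preference pair, this loss can be sketched in a few lines of plain Python. The function `dpo_loss` and its argument names are hypothetical; the inputs stand for the summed log-probabilities of the preferred ($y_1$) and rejected ($y_2$) responses under $\pi_r$ and $\pi_\mathrm{base}$. This is an illustration of the formula, not the TRL implementation.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """-log sigmoid(beta * (log-ratio of chosen - log-ratio of rejected))."""
    margin = beta * (
        (policy_chosen_logp - ref_chosen_logp)
        - (policy_rejected_logp - ref_rejected_logp)
    )
    # -log(sigmoid(margin)) rewritten as log(1 + exp(-margin))
    return math.log1p(math.exp(-margin))


# Before any update the policy equals the reference, so the margin is 0
# and the loss is log(2), roughly 0.6931.
print(round(dpo_loss(-10.0, -12.0, -10.0, -12.0), 4))  # 0.6931

# Raising the chosen log-probability and lowering the rejected one reduces the loss.
print(dpo_loss(-9.0, -13.0, -10.0, -12.0) < math.log(2))  # True
```

The value $\log 2 \approx 0.6931$ is what one should expect at the very first training step, when $\pi_r$ is still identical to $\pi_\mathrm{base}$.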

Below we show how to use the DPO trainer from the TRL library.

In [None]:
## adapted from https://github.com/huggingface/smol-course/blob/main/2_preference_alignment/notebooks/dpo_finetuning_example.ipynb
from trl import DPOTrainer, DPOConfig

# NOTE a trick to download only a fragment of the dataset
dataset = load_dataset(path="trl-lib/ultrafeedback_binarized", split=['train[:2%]', "test[:2%]"])

### TODO - Experiment with fine-tuning parameters

training_args = DPOConfig(
    # Training batch size per GPU
    per_device_train_batch_size=8,
    # Number of micro-batches whose gradients are accumulated before each optimizer update
    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
    gradient_accumulation_steps=2,
    # Saves memory by storing subset of activations during forward pass
    # Instead recomputes them during backward pass
    gradient_checkpointing=True,
    # Base learning rate for training
    learning_rate=1e-5,
    # Learning rate schedule - 'cosine' gradually decreases LR following cosine curve
    lr_scheduler_type="cosine",
    # Total number of training steps
    max_steps=100,
    # Disables model checkpointing during training
    save_strategy="no",
    # How often to log training metrics
    logging_steps=1,
    # Directory to save model outputs
    output_dir="smol_dpo_output",
    # Number of steps for learning rate warmup
    warmup_steps=50,
    # Use bfloat16 precision for faster training
    bf16=True,
    # Disable wandb/tensorboard logging
    report_to="none",
    # Keep all columns in dataset even if not used
    remove_unused_columns=False,
    # Enable MPS (Metal Performance Shaders) for Mac devices
    use_mps_device=False,
    # Model ID for HuggingFace Hub uploads
    hub_model_id=None,
    # DPO-specific parameter that controls how far the model may deviate from the reference model
    # Higher values keep the model closer to the reference; lower values (like 0.1) allow larger deviations
    beta=0.1,
    # Maximum length of the input prompt in tokens
    max_prompt_length=512,
    # Maximum combined length of prompt + response in tokens
    max_length=1024,
)

trainer = DPOTrainer(
    # The model to be trained
    model=model,
    # Training configuration from above
    args=training_args,
    # Dataset containing preferred/rejected response pairs
    train_dataset=dataset[0],
    eval_dataset=dataset[1],
    # Tokenizer for processing inputs
    processing_class=tokenizer,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).



Step,Training Loss
1,0.6931
2,0.6931
3,0.6862
4,0.7092
5,0.7521
6,0.7021
7,0.6736
8,0.6618
9,0.705
10,0.7635


KeyboardInterrupt: 

In [None]:
test_model_on_simple_example(example_chat)



<|im_start|> user
Hey! <|im_end|> 
 <|im_start|> assistant
Hey! How can I help you? <|im_end|> 
 <|im_start|> user
How to install gentoo? <|im_end|> 
 <|im_start|> assistant
To install gentoo, you'll need a web browser or a web browser-specific software. You can download it from the official website of your choice. <|im_end|>


## RL training

We will implement a simple (incomplete and unoptimized) version of GRPO and try to teach our model to follow a specific style in its responses (the model may be too small for harder tasks).

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import numpy as np
from typing import Callable, List, Dict
from transformers import PreTrainedModel, PreTrainedTokenizer

DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"

model_name = (
    "BEE-spoke-data/smol_llama-101M-GQA"  # note that this model is really "smol"
)
tokenizer_name = model_name

In [None]:
!wget https://gist.githubusercontent.com/creikey/42d23d1eec6d764e8a1d9fe7e56915c6/raw/b07de0068850166378bc3b008f9b655ef169d354/top-1000-nouns.txt -O top-1000-nouns.txt

--2025-04-06 21:36:22--  https://gist.githubusercontent.com/creikey/42d23d1eec6d764e8a1d9fe7e56915c6/raw/b07de0068850166378bc3b008f9b655ef169d354/top-1000-nouns.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7257 (7.1K) [text/plain]
Saving to: ‘top-1000-nouns.txt’


2025-04-06 21:36:22 (82.8 MB/s) - ‘top-1000-nouns.txt’ saved [7257/7257]



In [None]:
from typing import Tuple
from functools import partial


def get_word_list(filepath: str):
    with open(filepath, "r") as f:
        lines = f.readlines()

    data = [l.strip() for l in lines if len(l.strip()) > 0]

    return data


def generate_task(seed: int, topics: List[str]):
    """
    We have only three tasks.
    CAPS - the model should generate some text using uppercase characters
    DIGITS - the model should output digits
    LOWERCASE - the model should generate some text using lowercase characters
    """

    prev_rnd_state = np.random.get_state()
    np.random.seed(seed=seed)
    mode = np.random.choice(["CAPS", "DIGITS", "LOWERCASE"]).item()
    topic = np.random.choice(topics).item()
    text = f"Text in {mode} style about {topic}:"
    np.random.set_state(prev_rnd_state)
    return text, mode


generate_task_with_common_topic = partial(generate_task, topics=get_word_list(filepath="top-1000-nouns.txt"))




def grade_response(response: str, gt):
    response_length = max(1, len(response))
    if gt == "CAPS":
        return sum([1 for c in response if c.isupper()]) / response_length
    elif gt == "DIGITS":
        return sum([1 for c in response if c.isdigit()]) / response_length
    elif gt == "LOWERCASE":
        return sum([1 for c in response if c.islower()]) / response_length
    else:
        raise ValueError()


def grade_many_responses(responses: List[str], gt: List[str]):
    assert len(responses) == len(gt)
    results = []

    for r, g in zip(responses, gt):
        results.append(grade_response(r, g))

    return results


class DataGenerator:
    def __init__(
        self,
        gen_fn: Callable[[int], Tuple[str, str]],
        tokenizer: PreTrainedTokenizer,
        seed: int = 0,
    ):
        self.seed = seed
        self.gen_fn = gen_fn
        self.tokenizer = tokenizer

    def get_batch(self, batch_size: int):
        data_question = []
        data_answer = []
        for i in range(batch_size):
            question, answer = self.gen_fn(seed=self.seed)
            data_question.append(question)
            data_answer.append(answer)
            self.seed += 1

        tokenized_data = self.tokenizer(
            data_question, padding_side="left", padding="longest", return_tensors="pt"
        )

        return tokenized_data, data_answer, data_question

In [None]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
tokenizer.pad_token_id = tokenizer.eos_token_id
data_generator = DataGenerator(
    gen_fn=generate_task_with_common_topic, tokenizer=tokenizer
)

print(data_generator.get_batch(32)[-1][0])

Text in CAPS style about background:


In [None]:
@torch.no_grad()
def generate_k_responses(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    data_batch: Dict[str, torch.Tensor],
    k: int,
    temperature: float = 1.0,
    topp: float = 0.95,
):
    input_ids = data_batch["input_ids"].to(model.device)
    attention_mask = data_batch["attention_mask"].to(model.device)

    batch, seq = input_ids.shape

    input_ids = input_ids.repeat_interleave(repeats=k, dim=0)
    attention_mask = attention_mask.repeat_interleave(repeats=k, dim=0)

    assert attention_mask.shape == input_ids.shape

    input_plus_output = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        top_p=topp,
        temperature=temperature,
        stop_strings=["\n", r"</s>"],
        tokenizer=tokenizer,
        max_new_tokens=32,
    )
    output = input_plus_output[:, input_ids.shape[1] :]
    output_text = tokenizer.batch_decode(output)

    output = output.reshape(batch, k, -1)
    input_ids = input_ids.reshape(batch, k, seq)
    attention_mask = attention_mask.reshape(batch, k, seq)

    return (input_ids, attention_mask, output), output_text


model = AutoModelForCausalLM.from_pretrained(model_name).to(torch.device(DEVICE_NAME))
data_batch = data_generator.get_batch(8)
print(data_batch)
generate_k_responses(model=model, tokenizer=tokenizer, data_batch=data_batch[0], k=4)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


({'input_ids': tensor([[    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  5120,
         29901],
        [    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048,  8099,
         29901],
        [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  6554,
         29901],
        [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  4259,
         29901],
        [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048, 15354,
         29901],
        [    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048, 15483,
         29901],
        [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  1302,
         29901],
        [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  3458,
         29901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1

((tensor([[[    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  5120,
            29901],
           [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  5120,
            29901],
           [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  5120,
            29901],
           [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  5120,
            29901]],
  
          [[    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048,  8099,
            29901],
           [    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048,  8099,
            29901],
           [    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048,  8099,
            29901],
           [    2,     1,  3992,   297,   315,  3301, 29903,  3114,  1048,  8099,
            29901]],
  
          [[    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,  6554,
            29901],
           [    1,  3992,   297,   360,  6259,  1806, 29903,  3114,  1048,

Below is simplified RL code.  
It shows the main idea of GRPO: generate several completions for each question and compute the group-normalized reward.  
The code differs from GRPO in the following ways:  
* no KL divergence penalty against a periodically saved reference model
* no PPO-style sample reuse (here each sample is used for only one gradient step)

Feel free to extend this code; note, however, that Colab resources are highly limited.
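
Before reading the full training step, it may help to see the group normalization in isolation. The sketch below uses plain Python on one group of rewards; `group_normalize` is a hypothetical helper, and it uses the sample standard deviation because that is what `torch.Tensor.std` computes by default.

```python
import statistics
from typing import List


def group_normalize(rewards: List[float], eps: float = 1e-5) -> List[float]:
    """Subtract the group mean and divide by the group standard deviation."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards) + eps  # sample std (n - 1 in the denominator)
    return [(r - mean) / std for r in rewards]


# Rewards of k = 4 completions sampled for the same question.
group = [0.1, 0.4, 0.4, 0.9]
advantages = group_normalize(group)
print([round(a, 3) for a in advantages])

# If all completions receive the same reward, every advantage is zero
# and the group contributes no gradient signal.
print(group_normalize([0.5, 0.5, 0.5, 0.5]))  # [0.0, 0.0, 0.0, 0.0]
```

Completions that score above the group average get a positive weight and those below get a negative one, so no separate value model is needed to form a baseline.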

In [None]:
from tqdm import trange
import copy


def training_step(
    model: PreTrainedModel,
    optimizer,
    grouped_input: torch.Tensor,
    grouped_output: torch.Tensor,
    grouped_attention_mask_input: torch.Tensor,
    grouped_attention_mask_output: torch.Tensor,
    grouped_rewards: torch.Tensor,
):
    """
    grouped_input - of shape (batch, num_samples, input_seq_length)
    grouped_output - of shape (batch, num_samples, output_seq_length)
    grouped_attention_mask_input - of shape (batch, num_samples, input_seq_length)
    grouped_attention_mask_output - of shape (batch, num_samples, output_seq_length)
    grouped_rewards - of shape (batch, num_samples)

    Implements a highly simplified version of GRPO
    Differences
    * no PPO part (fully online algorithm that uses each sample exactly once)
    * no KL divergence (the model may break)
    """

    def get_model_log_probs(model: PreTrainedModel, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        assert len(logits.shape) == 3  # batch * k, seq, vocab
        logits = logits[
            :, (input_seq - 1) : -1
        ]  # we only give points for the predicted text
        log_probs = torch.log_softmax(logits, dim=-1)

        log_probs = torch.gather(log_probs, dim=-1, index=input_ids[:, input_seq:, None])[
            ..., 0
        ]
        assert len(log_probs.shape) == 2

        grouped_log_probs = log_probs.reshape(batch, k, output_seq)
        grouped_log_probs = grouped_log_probs * grouped_attention_mask_output
        return grouped_log_probs

    _, _, input_seq = grouped_input.shape
    _, _, output_seq = grouped_output.shape

    all_tokens = torch.cat([grouped_input, grouped_output], dim=-1)
    attn_mask = torch.cat(
        [grouped_attention_mask_input, grouped_attention_mask_output], dim=-1
    )
    assert attn_mask.shape == all_tokens.shape
    batch, k, seq = all_tokens.shape

    all_tokens = all_tokens.reshape(batch * k, seq)
    attn_mask = attn_mask.reshape(batch * k, seq)

    assert all_tokens.shape == attn_mask.shape

    grouped_log_probs = get_model_log_probs(model=model,
                                            input_ids=all_tokens,
                                            attention_mask=attn_mask)


    grouped_rewards_mean = grouped_rewards.mean(dim=1, keepdim=True)
    grouped_rewards_std = grouped_rewards.std(dim=1, keepdim=True) + 1e-5

    normalized_rewards = (grouped_rewards - grouped_rewards_mean) / grouped_rewards_std

    grouped_attention_mask_output_reduced = grouped_attention_mask_output.sum(-1)

    denominator =  torch.maximum(grouped_attention_mask_output_reduced,
                                 torch.ones_like(grouped_attention_mask_output_reduced))

    grouped_log_probs = grouped_log_probs.sum(-1) / denominator

    assert normalized_rewards.shape == grouped_log_probs.shape

    loss = -(grouped_log_probs * normalized_rewards)

    loss = loss.mean()

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    return loss.detach().cpu().item()


def training_loop(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    data_generator: DataGenerator,
    solution_grader: Callable,
    iterations: int,
    preserve_model_each_n_iter: int,
    batch_size: int = 8,
    k: int = 16,
    lr=2.5e-5,
):

    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)

    with torch.amp.autocast(dtype=torch.bfloat16, device_type=model.device.type):

        reward_history = []

        bar = trange(0, iterations, 1, desc="Training")

        for _ in bar:
            question_batch, solution_batch, _ = data_generator.get_batch(
                batch_size=batch_size
            )

            (
                grouped_input_ids,
                grouped_attention_mask,
                grouped_responses_tokens,
            ), response_texts = generate_k_responses(
                model=model, tokenizer=tokenizer, data_batch=question_batch, k=k
            )

            # Repeat each reference solution k times so that it lines up
            # with the k responses generated for the same question.
            solution_batch_repeated = [
                sb for sb in solution_batch for _ in range(k)
            ]

            grading_results = solution_grader(response_texts, solution_batch_repeated)
            grading_results = torch.tensor(
                grading_results, dtype=torch.float32, device=model.device
            )
            grouped_grading_results = grading_results.reshape(batch_size, k)

            output_attention_mask = (
                grouped_responses_tokens != tokenizer.pad_token_id
            ).to(grouped_attention_mask.dtype)

            per_problem_performance = grouped_grading_results.mean(1)
            reward_history.append(per_problem_performance.mean().cpu().item())

            loss = training_step(
                model=model,
                optimizer=optimizer,
                grouped_input=grouped_input_ids,
                grouped_output=grouped_responses_tokens,
                grouped_attention_mask_input=grouped_attention_mask,
                grouped_attention_mask_output=output_attention_mask,
                grouped_rewards=grouped_grading_results,
            )

            bar.set_description(
                desc=(
                    f"Training: loss {loss:.2f} "
                    f"reward {reward_history[-1]:.5f} "
                    f"reward-last-10 {np.mean(reward_history[-10:]):.5f} "
                    f"reward-last-100: {np.mean(reward_history[-100:]):.5f} "
                )
            )


training_loop(
    model=model,
    tokenizer=tokenizer,
    data_generator=data_generator,
    solution_grader=grade_many_responses,
    iterations=1000,
    preserve_model_each_n_iter=100,
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Training:   0%|          | 0/1000 [00:00<?, ?it/s]
Training: loss 0.04 reward 0.29749 reward-last-10 0.29749 reward-last-100: 0.29749 :   0%|          | 1/1000 [00:04<1:08:15,  4.10s/it]
Training: loss -0.09 reward 0.14454 reward-last-10 0.22102 reward-last-100: 0.22102 :   0%|          | 2/1000 [00:06<49:27,  2.97s/it]
Training: loss -0.03 reward 0.20071 reward-last-10 0.21425 reward-last-100: 0.21425 :   0%|          | 3/1000 [00:08<45:44,  2.75s/it]
Training: loss -0.11 reward 0.32285 reward-last-10 0.24140 reward-last-100: 0.24140 :   0%|          | 4/1000 [00:11<43:01,  2.59s/it]
Training: loss -0.22 reward 0.17525 reward-last-10 0.22817 reward-last-
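
The loss values in the log above hover around zero and are often negative. The following is a minimal sketch of why, assuming the same group normalization as in `training_step` but written with the standard library on a single toy group. The helpers `normalize_group` and `group_loss` are hypothetical names introduced here, and the reward and log-probability values are made up for illustration.

```python
from statistics import mean, stdev


def normalize_group(rewards, eps=1e-5):
    """Standardize the rewards of one group of k responses."""
    mu = mean(rewards)
    # statistics.stdev is the sample std (n - 1 in the denominator),
    # which is also the default of torch.Tensor.std.
    sigma = stdev(rewards) + eps
    return [(r - mu) / sigma for r in rewards]


def group_loss(avg_log_probs, rewards):
    """Negative mean of (length-normalized log-prob * normalized reward)."""
    advantages = normalize_group(rewards)
    return -mean([lp * a for lp, a in zip(avg_log_probs, advantages)])


# Made-up rewards for k = 4 responses to a single question.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = normalize_group(rewards)

# Normalized rewards are zero-mean within the group, so a policy that assigns
# the same log-probability to every response gets a loss of (almost) zero.
uniform_loss = group_loss([-1.0, -1.0, -1.0, -1.0], rewards)

# If the rewarded responses are already the more likely ones, the loss is negative.
aligned_loss = group_loss([-0.5, -2.0, -1.5, -1.0], rewards)

# A group whose rewards are all equal yields zero normalized rewards,
# so it contributes nothing to the gradient.
flat_advantages = normalize_group([1.0, 1.0, 1.0, 1.0])

print(advantages, uniform_loss, aligned_loss, flat_advantages)
```

Because the objective is centered in this way, its value is not a progress signal on its own; the reward columns in the progress bar are the quantity to watch.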

In [None]:
## TODO: Read the implementation of the grading function and compare the outputs of the original and RL-tuned models. Note that our loss has no KL-divergence term.
## Investigate the training run: what might be the main obstacle for the model? Does the base model understand the task, or is it just guessing responses?
data_batch = data_generator.get_batch(8)
# The last element of the batch is shown as the question text in the output below.
question = data_batch[-1]

# generate_k_responses returns (token tensors, response texts); keep only the texts.
response = generate_k_responses(model=model, tokenizer=tokenizer, data_batch=data_batch[0], k=1)[-1]

# Pair each question with the single (k=1) response generated for it.
qr = list(zip(question, response))

print(json.dumps(qr, indent=2))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[
  [
    "Text in LOWERCASE style about murder:",
    "sustainable and environmentally sustainable and environmentally sustainable and environmentally sustainable and environmentally sustainable and environmentally sustain"
  ],
  [
    "Text in CAPS style about stock:",
    "INTERPRETERMISATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONAL"
  ],
  [
    "Text in CAPS style about initiative:",
    "INTERPRETERMISATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONAL"
  ],
  [
    "Text in CAPS style about licence:",
    "INTERPRETERMISATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONALIZATIONAL"
  ],
  [
    "Text in LOWERCASE style about justice:",
    "sustainable and environmentally sustainable and environmentally sustainable and environmentally sustainable and environmentally sustainable and environmentally sustain"
  ],
  [
    "Text in DIGITS style about report:",
    "0000000000000000000000000000000"
  ],
 