# GRPO Training project: teach an LLM to do additions, again

In this notebook, you'll find:
* A basic Transformer with basic tokenizer
* A basic dataset for additions
* A classical pre-trainer, minimizing cross entropy loss
* A Vanilla GRPO

You are not expected to edit the existing code (although you may if you want to).
You should implement one (or more) of the following:
* GRPO with PPO (the `usual` one)
* RLOO
* ReMax
* DPO
* RAFT
* your own RLHF method!

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [None]:
# Number of digits in each operand of the generated addition problems.
num_digits = 3

# Total number of (prompt, answer) pairs to generate.
dataset_size = 64_000
# Fraction of the generated pairs used for training; the remainder is the test split.
train_proportion = 0.9

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Step 1: Construct a tokenizer

In [None]:
# Special tokens added to the vocabulary: padding and end-of-sequence.
pad_token="[PAD]"
eos_token="[EOS]"

In [None]:
class character_level_tokenizer:
    """
    Character-level tokenizer for addition problems.

    The vocabulary holds the ten digits, "+", "=" and the two special tokens
    `pad_token` and `eos_token`. Every digit/operator is one token; each
    special token is a multi-character string that maps to a single id.
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        # Regex for one character that is not a single-character token.
        # Only single-character tokens go into the class: joining the whole
        # vocabulary would also whitelist the letters and brackets of
        # "[PAD]"/"[EOS]", which have no id of their own.
        # Kept for backward compatibility; clean() relies on _split_re.
        single_chars = "".join(t for t in self.vocab if len(t) == 1)
        self.pattern = f"[^{re.escape(single_chars)}]"
        # Splits a text into vocabulary tokens (longest first, so special
        # tokens win over their individual characters) or, failing that,
        # into single characters.
        alternatives = [re.escape(t) for t in sorted(self.vocab, key=len, reverse=True)]
        self._split_re = re.compile("|".join(alternatives) + "|.", re.DOTALL)

    def clean(self, text):
        """
        removes everything that is not a vocabulary token
        """
        return "".join(t for t in self._split_re.findall(text) if t in self.token_to_id)

    def pre_tokenization(self, text):
        """
        splits the text into vocabulary tokens; any other character is
        returned as a one-character piece
        """
        return self._split_re.findall(text)

    def encode(self, text):
        """
        returns the list of token ids of the cleaned text
        """
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        """
        returns the string obtained by concatenating the tokens of the given ids
        """
        return "".join([self.id_to_token[x] for x in token_list])

In [None]:
# Shared tokenizer instance and vocabulary size used by the rest of the notebook.
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens
ntokens

14

In [None]:
# Round-trip example: characters outside the vocabulary (here the spaces) are dropped by encode().
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

## Step 2: Create a dataset for arithmetic operations

In [None]:
def sample_datapoint(num_digits = 3):
    """Draw one random addition problem.

    Both operands are strings of `num_digits` uniformly random digits, so
    leading zeros are kept in the prompt. Returns the pair
    ("<a>+<b>=", "<a + b>") where the answer is the decimal sum without
    leading zeros.
    """
    operands = []
    # Left operand first, then right operand: one random draw per digit.
    for _ in range(2):
        operands.append("".join(str(random.randint(0, 9)) for _ in range(num_digits)))
    left, right = operands
    total = int(left) + int(right)
    return (left + "+" + right + "=", str(total))

sample_datapoint(3)

('702+620=', '1322')

In [None]:
# Generate the whole dataset of (prompt, answer) pairs in one pass.
data = [sample_datapoint(num_digits) for _ in range(dataset_size)]
data[:4]

[('725+380=', '1105'),
 ('573+282=', '855'),
 ('224+846=', '1070'),
 ('368+009=', '377')]

In [None]:
# The first `train_proportion` of the samples is the training split, the rest is the test split.
data_train, data_test = (
    data[: int(train_proportion * dataset_size)],
    data[int(train_proportion * dataset_size):],
)

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

In [None]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
    where pos is the token position and i is the index of the sine/cosine pair.
    Args:
        d_model: the embed dim (required); may be even or odd.
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [None]:
class TransformerModel(nn.Transformer):
    """Causal (decoder-only style) language model built on the encoder stack of nn.Transformer.

    Args:
        ntoken: vocabulary size.
        ninp: embedding / model dimension.
        nhead: number of attention heads.
        nhid: hidden size of the feed-forward sub-layers.
        nlayers: number of encoder layers.
        dropout: dropout used by the positional encoding. It is not forwarded
            to nn.Transformer, whose layers keep their own default dropout.
    """
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        # Replaces the decoder stack created by nn.Transformer with the output projection.
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and the output projection; zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        """Additive causal mask: 0 on and below the diagonal, -inf above it."""
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        """Compute next-token log-probabilities.

        Args:
            src: token ids; its first dimension is the sequence length used to
                size the causal mask.
        Returns:
            (log-probabilities over the vocabulary, encoder output).
        """
        # Put the mask on the input's device rather than on a module-level
        # global, so the model works wherever its inputs live.
        mask = self._generate_square_subsequent_mask(len(src)).to(src.device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

In [None]:
# Instantiate the language model over the tokenizer vocabulary and move it to the selected device.
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [None]:
# Total number of scalar parameters of the model.
print("number of parameters: {}".format(sum(p.numel() for p in model.parameters())))

number of parameters: 668942


### Useful functions

In [None]:
def generate(model, prompts, new_tokens = 5, mode = "greedy", num_samples = 1, temperature = 0.8):
    """Autoregressively extend each prompt by `new_tokens` tokens.

    Args:
        model: callable returning (log-probabilities, _) for a token tensor.
        prompts: token ids, shape (prompt_length, batch_size).
        new_tokens: number of tokens to append; there is no early stop on [EOS].
        mode: "greedy" takes the argmax; any other value samples from the
            temperature-scaled distribution.
        num_samples: number of continuations per prompt. The copies of one
            prompt are adjacent along the batch axis.
        temperature: divisor applied to the log-probabilities before sampling.
    Returns:
        Token ids, shape (prompt_length + new_tokens, batch_size * num_samples),
        located on the model's device.
    """
    # Follow the model's own device instead of relying on a module-level global.
    model_device = next(model.parameters()).device
    input_tensor = torch.repeat_interleave(prompts, repeats = num_samples, dim = 1).to(model_device)
    # (prompt_length, batch_size * num_samples)
    # Chosen tokens are discrete, so no gradient can flow through generation;
    # disabling autograd avoids building a graph for every decoding step.
    with torch.no_grad():
        for _ in range(new_tokens):
            output, _ = model(input_tensor) # (current_length, batch_size * num_samples, ntokens)
            logits = output[-1,:,:] # (batch_size * num_samples, ntokens)
            if mode == "greedy":
                tokens = torch.argmax(logits, -1).view((1,-1)) # (1, batch_size * num_samples)
            else: # mode == "sampling"
                # Out-of-place division: `logits` is a view of the model output and must not be mutated.
                probs = torch.softmax(logits / temperature, dim=-1)
                tokens = torch.multinomial(probs, num_samples = 1).view((1,-1)) # (1, batch_size * num_samples)
            input_tensor = torch.cat((input_tensor, tokens), 0)
    return input_tensor

In [None]:
# Sanity check on the untrained model: greedy completion of a single prompt.
model.eval()

prompt = "2+3="
# Column vector of token ids, i.e. a batch containing one prompt.
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output[0].tolist())

(tensor([[ 2, 10,  3, 11,  9,  9,  9,  9,  9]], device='cuda:0'), '2+3=99999')

In [None]:
def pad(token_list, type_list = "prompts"):
    """Pad a list of token-id lists to a common length.

    Args:
        token_list: list of lists of token ids.
        type_list: "prompts" pads on the left with [PAD]; "answers" appends
            [EOS] and then pads on the right with [PAD].
    Returns:
        (padded lists, length of the longest input list). Padded answers are
        one token longer than that length because of the appended [EOS].
    Raises:
        ValueError: if `type_list` is neither "prompts" nor "answers".
    """
    if type_list not in ("prompts", "answers"):
        # Previously an unknown value silently produced an empty result.
        raise ValueError(f"type_list must be 'prompts' or 'answers', got {type_list!r}")
    pad_id = tokenizer.token_to_id[pad_token]
    eos_id = tokenizer.token_to_id[eos_token]
    # default=0 keeps an empty batch from raising inside max().
    max_length = max((len(x) for x in token_list), default=0)
    if type_list == "prompts":
        out = [[pad_id] * (max_length - len(x)) + x for x in token_list]
    else:
        out = [x + [eos_id] + [pad_id] * (max_length - len(x)) for x in token_list]
    return out, max_length

In [None]:
# Padding example: prompts are padded on the left, answers get [EOS] and are padded on the right.
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['[PAD][PAD]1+1=', '21+35='], ['2[EOS][PAD]', '56[EOS]'])

In [None]:
def get_batch(split, i, batch_size):
    """Build one padded batch starting at example index `i`.

    Args:
        split: 'train' selects the training split, anything else the test split.
        i: index of the first example of the batch.
        batch_size: maximum number of examples; the last batch of a split may
            be shorter.
    Returns:
        X: padded prompt ids, shape (prompt_length, number of examples).
        Y: padded answer ids, shape (answers_length + 1, number of examples).
        prompt_length, answers_length: longest prompt / answer in the batch.
        prompts, answers: the raw strings of the batch.
    Raises:
        IndexError: if `i` is past the end of the split.
    """
    split_data = data_train if split == 'train' else data_test

    # Slicing clips at the end of the split, so a final partial batch is
    # returned instead of raising on the missing indices.
    batch = split_data[i : i + batch_size]
    if not batch:
        raise IndexError(f"batch start {i} is out of range for split of size {len(split_data)}")

    prompts = [prompt for prompt, _ in batch]
    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    padded_prompts, prompt_length = pad(encoded_prompts, "prompts")

    answers = [answer for _, answer in batch]
    encoded_answers = [tokenizer.encode(answer) for answer in answers]
    padded_answers, answers_length = pad(encoded_answers, "answers")

    # Stack along dim 1 so that sequences run along dim 0 (sequence-first layout).
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, prompt_length, answers_length, prompts, answers

In [None]:
# Example batch: 16 training examples starting at index 43.
X, Y, prompt_length, answers_length, prompts, answers = get_batch("train", 43, 16)
X.shape, Y.shape, prompt_length, answers_length, prompts[0], answers[0]

(torch.Size([8, 16]), torch.Size([5, 16]), 8, 4, '775+390=', '1165')

## Step 4: Evaluate

In [None]:
# Batch size; also captured as the default argument of evaluate() defined below.
batch_size = 16

In [None]:
def evaluate(batch_size = batch_size):
    """Return the exact-match accuracy of greedy decoding on the test split.

    A prediction counts as correct only when every generated token (answer
    digits, [EOS] and trailing [PAD]) equals the padded target.
    The model is put in eval mode for the measurement and then restored to
    the mode it was in, so that calling this from a training loop does not
    silently disable dropout for the rest of training.
    """
    was_training = model.training
    # Evaluation mode disables dropout.
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for i in range(0, len(data_test) - 1, batch_size):
                prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("test", i, batch_size)
                prompts = prompts.to(device) # (prompt_length, batch_size)
                target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
                output = generate(model, prompts, answers_length + 1) # (prompt_length + answers_length + 1, batch_size)
                answers_tokens = output[prompt_length:, :] # (answers_length + 1, batch_size), contains tokens
                equality_test = answers_tokens == target_answers # (answers_length + 1, batch_size), contains boolean values
                # A column is correct only if all of its tokens match.
                correct += int(torch.all(equality_test, dim=0).sum())
    finally:
        model.train(was_training)
    # Accumulating a plain int keeps the division valid even when no batch was run.
    return correct / len(data_test)

In [None]:
# Baseline accuracy of the untrained model.
evaluate()

0.0

## Step 5: Train the model, classical approach

### Hyperparameters

In [None]:
# Number of passes over the training split.
epochs = 5
# Number of examples per optimisation step.
batch_size = 16
# Learning rate given to AdamW.
learning_rate = 8e-4

reporting_per_epoch = 5
# Reporting period, expressed in training examples (not in batches).
log_interval = len(data_train) // (reporting_per_epoch + 1)
# The training loop tests the example offset against log_interval, so the period must fall on batch boundaries.
assert(log_interval % batch_size == 0)

In [None]:
def train():
    """Supervised pre-training: minimise the cross entropy of the answer tokens.

    Reports the running loss every `log_interval` training examples, measures
    the test accuracy after each epoch and saves the model to "arithmetic.pt"
    whenever that accuracy improves.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)

    # log_interval counts examples; losses and timings are accumulated per batch.
    batches_per_report = max(1, log_interval // batch_size)

    for epoch in range(1, epochs+1):
        # evaluate() calls model.eval(); make sure dropout is active again for every epoch.
        model.train()
        epoch_start_time = time.time()
        total_loss = 0.
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size)
            model.zero_grad()
            output, _ = model(input_tensor) # (prompt_length + answers_length + 1, batch_size, ntokens)
            # The output at position t predicts token t + 1, hence the shift by one.
            output_answers = output[prompt_length-1:-1,:,:].reshape(-1, ntokens) # ((answers_length + 1) * batch_size, ntokens)
            target_answers = target_answers.reshape(-1)
            loss = F.cross_entropy(output_answers, target_answers)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % log_interval == 0 and batch > 0:
                # Average over the batches seen since the last report.
                cur_loss = total_loss / batches_per_report
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                            elapsed * 1000 / batches_per_report, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far
        # (higher is better, unlike a validation loss).
        if best_test_accuracy is None or test_accuracy > best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [None]:
# Run the supervised pre-training.
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.10 | loss  0.01 | perplexity     1.01
|  1200/ 3600 batches | ms/batch  1.12 | loss  0.00 | perplexity     1.00
|  1800/ 3600 batches | ms/batch  1.10 | loss  0.00 | perplexity     1.00
|  2400/ 3600 batches | ms/batch  1.09 | loss  0.00 | perplexity     1.00
|  3000/ 3600 batches | ms/batch  1.10 | loss  0.00 | perplexity     1.00
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 72.81s | test accuracy  1.00
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.09 | loss  0.00 | perplexity     1.00
|  1200/ 3600 batches | ms/batch  1.10 | loss  0.00 | perplexity     1.00
|  1800/ 3600 batches | ms/

KeyboardInterrupt: 

In [None]:
# Qualitative check: greedy completions for the first 20 test problems next to the expected sums.
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    # Generate the answer digits plus one extra token for [EOS].
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

592+554=1144[EOS]	 actual result: 1146
852+172=1022[EOS]	 actual result: 1024
760+734=1494[EOS]	 actual result: 1494
295+549=844[EOS]	 actual result: 844
158+181=334[EOS]	 actual result: 339
995+086=1084[EOS]	 actual result: 1081
195+339=534[EOS]	 actual result: 534
526+962=1484[EOS]	 actual result: 1488
044+318=364[EOS]	 actual result: 362
885+887=1774[EOS]	 actual result: 1772
207+852=1052[EOS]	 actual result: 1059
250+793=1044[EOS]	 actual result: 1043
625+900=1522[EOS]	 actual result: 1525
331+957=1282[EOS]	 actual result: 1288
545+039=582[EOS]	 actual result: 584
685+742=1424[EOS]	 actual result: 1427
371+900=1274[EOS]	 actual result: 1271
592+954=1544[EOS]	 actual result: 1546
546+266=812[EOS]	 actual result: 812
375+894=1264[EOS]	 actual result: 1269


## Step 5 bis: Vanilla GRPO training

### Custom reward functions

In [None]:
def accuracy_reward(output, answer):
    """Return 1.0 when the decoded output equals the answer exactly, else 0.0.

    Every "[EOS]" marker and any trailing run of "[PAD]" markers are removed
    from the output before the comparison.
    """
    without_eos = re.sub(r"\[EOS\]", "", output)
    candidate = re.sub(r"(\[PAD\])*$", "", without_eos)
    if candidate == answer:
        return 1.
    return 0.

accuracy_reward("123[EOS][PAD][PAD]", "123"), accuracy_reward("123", "124"),

(1.0, 0.0)

In [None]:
def distance_accuracy_reward(output, answer):
    """Reward in [0, 1] that grows as the predicted number gets closer to the answer.

    "[EOS]" markers and trailing "[PAD]" markers are stripped from the output
    first. The reward is 1 - |output - answer| / max(output, answer): 1.0 for
    an exact match, decreasing towards 0.0 as the relative error grows.
    An output that is not a plain non-negative integer (empty, or containing
    "+", "=" or a stray marker) gets 0.0 instead of raising.

    Note: this used to return the relative distance itself (0.0 for a perfect
    answer), which would reward wrong answers when maximised.
    """
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    # Sampled outputs are arbitrary token sequences and may not parse as a number.
    if not output.isdigit():
        return 0.
    int_output = int(output)
    int_answer = int(answer)
    largest = max(int_output, int_answer)
    # Both values are 0: exact match, and the relative distance would divide by zero.
    if largest == 0:
        return 1.
    return 1. - abs(int_output - int_answer) / largest

distance_accuracy_reward("123[EOS]", "123"), distance_accuracy_reward("123[PAD]", "124"),

(0.0, 0.008064516129032258)

In [None]:
def digit_accuracy_reward(output, answer):
    """Fraction of positions at which the output and the answer share the same character.

    "[EOS]" markers and trailing "[PAD]" markers are stripped from the output
    first. Characters are compared position by position from the left and the
    number of matches is divided by the length of the longer string.
    """
    without_eos = re.sub(r"\[EOS\]", "", output)
    candidate = re.sub(r"(\[PAD\])*$", "", without_eos)
    matches = 0
    for got, expected in zip(candidate, answer):
        if got == expected:
            matches += 1
    longest = max(len(candidate), len(answer))
    return matches / longest

digit_accuracy_reward("123[EOS][PAD][PAD]", "123"), digit_accuracy_reward("123[EOS]", "123"),

(1.0, 1.0)

In [None]:
def reward_format(output):
    """Return 1.0 when the output is digits, then "[EOS]", then only "[PAD]" markers; else 0.0."""
    well_formed = re.match(r"\d+\[EOS\](\[PAD\])*$", output) is not None
    return 1. if well_formed else 0.

reward_format("123[EOS][PAD][PAD]"), reward_format("123[EOS]"), reward_format("123"),

(1.0, 1.0, 0.0)

### Hyperparameters

In [None]:
# Number of passes over the training split.
epochs = 20
# Number of prompts per optimisation step.
batch_size = 16
# Learning rate given to AdamW.
learning_rate = 1e-4
# Number of sampled completions per prompt (the GRPO group size).
num_samples = 16
# Sampling temperature passed to generate().
temperature = .8

reporting_per_epoch = 5
# Reporting period, expressed in training examples (not in batches).
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

# Task reward used by compute_rewards().
reward_fun = digit_accuracy_reward
# No-op self-assignment: the format reward stays the function defined above.
reward_format = reward_format

In [None]:
def compute_rewards(text_outputs, answers):
    """Score every sampled completion.

    `answers` holds one reference per prompt, while `text_outputs` holds
    `num_samples` consecutive completions per prompt. Each score is
    0.2 * format reward + 0.8 * task reward; the scores are returned as a
    float32 tensor on `device`.
    """
    # Repeat each reference so that it lines up with the completions of its prompt.
    expanded_answers = (answer for answer in answers for _ in range(num_samples))
    scores = []
    for output, answer in zip(text_outputs, expanded_answers):
        scores.append(0.2 * reward_format(output) + 0.8 * reward_fun(output, answer))
    return torch.tensor(scores, dtype=torch.float32, device=device)

In [None]:
def calculate_grpo_advantages(rewards):
    """Standardise the rewards within each group of `num_samples` completions.

    `rewards` is a flat tensor in which the completions of one prompt are
    adjacent. Each reward has its group mean subtracted and is divided by
    the group standard deviation (plus 1e-5 for numerical stability). The
    result has the same flat layout as the input.
    """
    # One row per prompt, one column per sampled completion.
    grouped = rewards.view(-1, num_samples)
    group_mean = grouped.mean(dim=1, keepdim=True)
    group_std = grouped.std(dim=1, keepdim=True)
    standardised = (grouped - group_mean) / (group_std + 1e-5)
    return standardised.reshape(-1)

In [None]:
def compute_log_probs(model, outputs, prompt_length):
    """Log-probabilities the model assigns to each position of the generated part.

    `outputs` contains the prompt followed by the generated tokens. The
    prediction made at position t is the distribution of token t + 1, so the
    slice starts one step before the end of the prompt and drops the last
    position. The result keeps the full vocabulary axis:
    (answers_length + 1, batch_size * num_samples, vocab_size).
    """
    predictions, _ = model(outputs)
    # Shift by one: predictions at prompt_length - 1 .. end - 1 cover the new tokens.
    answer_positions = slice(prompt_length - 1, -1)
    # Normalise along the vocabulary axis.
    return F.log_softmax(predictions[answer_positions], dim=-1)

In [None]:
def compute_loss(advantages, log_probs, responses):
    """Vanilla policy-gradient loss weighted by per-response advantages.

    Args:
        advantages: (batch_size * num_samples,) one advantage per response.
        log_probs: (answers_length + 1, batch_size * num_samples, vocab_size).
        responses: (answers_length + 1, batch_size * num_samples) sampled token ids.
    Returns:
        Scalar loss: minus the mean of advantage * normalised log-probability.
    """
    # Log-probability of each sampled token: (answers_length + 1, batch_size * num_samples).
    token_log_probs = torch.gather(log_probs, -1, responses.unsqueeze(-1)).squeeze(-1)

    # Standardise the log-probabilities along the last axis.
    centred = token_log_probs - token_log_probs.mean(-1, keepdim=True)
    scale = token_log_probs.std(-1, keepdim=True) + 1e-5
    normalised = centred / scale

    # The same advantage applies to every token of a response.
    weighted = advantages.unsqueeze(dim=0) * normalised
    return -weighted.mean()

In [None]:
def train_vanilla_GRPO(verbose = False):
    """Train the model with the vanilla GRPO objective.

    For every batch of prompts, `num_samples` completions are sampled per
    prompt, scored with compute_rewards(), turned into group-relative
    advantages and used in the policy-gradient loss of compute_loss().
    The test accuracy is measured after each epoch and the model is saved to
    "arithmetic_vanilla_GRPO.pt" whenever it improves.

    Args:
        verbose: when True, also print one sampled group at each progress report.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)

    # log_interval counts examples; timings are reported per batch.
    batches_per_report = max(1, log_interval // batch_size)

    for epoch in range(1, epochs+1):
        # evaluate() calls model.eval(); switch back to train mode (enables
        # dropout) at the start of every epoch, not only before the first one.
        model.train()

        epoch_start_time = time.time()
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):

            # get a batch of prompts and answers
            prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)

            # generate samples for each prompt; sampled tokens are discrete,
            # so no autograd graph is needed for this pass
            with torch.no_grad():
                outputs = generate(model,
                                   prompts,
                                   new_tokens = answers_length + 1,
                                   mode = "sampling",
                                   num_samples = num_samples,
                                   temperature = temperature)
            # outputs.shape = (prompt_length + answers_length + 1, batch_size * num_samples)
            text_outputs = [tokenizer.decode(outputs[prompt_length:, j].tolist())
                            for j in range(outputs.size(1))]

            # compute rewards
            rewards = compute_rewards(text_outputs, answers)

            # compute advantages
            advantages = calculate_grpo_advantages(rewards)

            # compute log probabilities (this pass carries the gradient)
            log_probs = compute_log_probs(model, outputs, prompt_length)

            # compute loss
            responses = outputs[prompt_length:, :]
            loss = compute_loss(advantages, log_probs, responses)

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % log_interval == 0 and batch > 0:
                elapsed = time.time() - start_time
                # elapsed is in seconds for the whole reporting window
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f}'.format(batch, len(data_train) // batch_size,
                                                                        elapsed * 1000 / batches_per_report))
                if verbose:
                    print("\nquestion:", questions[0],
                      "\nanswer", answers[0],
                      "\noutput:", text_outputs[:num_samples],
                      "\nreward:", rewards[:num_samples],
                      "\nadvantage:", advantages[:num_samples], "\n")

                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far
        # (higher is better, unlike a validation loss).
        if best_test_accuracy is None or test_accuracy > best_test_accuracy:
            with open("arithmetic_vanilla_GRPO.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy


In [None]:
# Run vanilla GRPO on top of the pre-trained model.
train_vanilla_GRPO(verbose = False)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.93
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch 832.02
|  1200/ 3600 batches | ms/batch 815.49


KeyboardInterrupt: 

## GRPO

In [None]:
def compute_loss_grpo(advantages, log_probs, log_probs_old, log_probs_ref, responses, beta, eps):
    """GRPO loss: PPO-style clipped surrogate plus a KL penalty towards a reference policy.

    Args:
        advantages: (batch_size * num_samples,) one advantage per sampled response.
        log_probs: (answers_length + 1, batch_size * num_samples, vocab_size)
            log-probabilities under the policy being optimised.
        log_probs_old: same shape, log-probabilities under the policy that
            produced the samples; treated as a constant.
        log_probs_ref: same shape, log-probabilities under the reference
            policy; treated as a constant.
        responses: (answers_length + 1, batch_size * num_samples) sampled token ids.
        beta: weight of the KL penalty.
        eps: clipping range of the probability ratio.
    Returns:
        Scalar loss to minimise.
    """
    # reshape responses to (answers_length + 1, batch_size * num_samples, 1) for gathering
    index = responses.unsqueeze(-1)
    # gather the log probability of each sampled token, then drop the extra
    # last dimension: (answers_length + 1, batch_size * num_samples)
    selected_log_probs = log_probs.gather(dim=-1, index=index).squeeze(-1)
    selected_log_probs_old = log_probs_old.gather(dim=-1, index=index).squeeze(-1).detach()
    selected_log_probs_ref = log_probs_ref.gather(dim=-1, index=index).squeeze(-1).detach()

    # The inputs are already log-probabilities (<= 0). Taking
    # log(clamp(x, min=1e-10)) of them would replace every value by the
    # constant log(1e-10), giving a ratio of 1, a KL of 0 and a zero gradient.

    # we use the same advantages for all tokens in the response
    advantages = advantages.unsqueeze(dim=0) # (1, batch_size * num_samples)

    # KL estimate: pi_ref / pi - log(pi_ref / pi) - 1, which is always >= 0
    log_ratio_ref = selected_log_probs_ref - selected_log_probs
    div_KL = log_ratio_ref.exp() - log_ratio_ref - 1

    # probability ratio between the current policy and the sampling policy
    pi_norm = (selected_log_probs - selected_log_probs_old).exp()
    surrogate = torch.min(pi_norm * advantages, torch.clamp(pi_norm, 1 - eps, 1 + eps) * advantages)
    loss_oi = - (surrogate - beta * div_KL) # (answers_length + 1, batch_size * num_samples)
    # average over the tokens of each response, then over the responses
    loss = loss_oi.mean(dim=0).mean()
    return loss

In [None]:
# Print one sampled group at each progress report of the GRPO loop below.
verbose = True

# Number of passes over the training split.
epochs = 20
# Number of prompts per optimisation step.
batch_size = 16
# Learning rate given to AdamW in the GRPO loop below.
learning_rate = 1e-4
# Number of sampled completions per prompt (the GRPO group size).
num_samples = 16
# Sampling temperature passed to generate().
temperature = .8

reporting_per_epoch = 5
# Reporting period, expressed in training examples (not in batches).
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

# Task reward used by compute_rewards().
reward_fun = digit_accuracy_reward
# No-op self-assignment: the format reward stays the function defined above.
reward_format = reward_format

mu = 1 # number of optimization steps (iterations per batch)
beta = 0.04 # penalty factor
eps = 0.8 # clip

#paper
# Values noted for reference; `lr` and `num_samples_paper` are not read by the
# GRPO loop below, which uses `learning_rate` and `num_samples`.
lr = 1e-6
beta = 0.04
num_samples_paper = 64

In [None]:
# GRPO training loop with a clipped surrogate and a KL penalty.
# - model:     policy being optimised
# - model_old: frozen copy of the policy refreshed before every batch; it samples
#              the completions and provides the denominator of the probability ratio
# - model_ref: frozen copy of the policy refreshed at the start of every epoch;
#              it is the anchor of the KL penalty
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

best_test_accuracy = None
test_accuracy = evaluate()
print('-' * 89)
print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
print('-' * 89)

# The frozen copies are allocated once and only have their weights refreshed
# afterwards; building a new Transformer for every batch is needlessly slow.
# The architecture arguments must match those used to build `model`.
model_ref = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8).to(device)
for param in model_ref.parameters():
    param.requires_grad = False

model_old = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8).to(device)
for param in model_old.parameters():
    param.requires_grad = False
# NOTE(review): both copies stay in the default train mode, as before, so
# dropout is active when sampling and when computing old/reference
# log-probabilities; confirm whether eval mode is wanted here.

# log_interval counts examples; timings are reported per batch.
batches_per_report = max(1, log_interval // batch_size)

for epoch in range(1, epochs+1):
    # evaluate() calls model.eval(); switch back to train mode (enables dropout)
    # at the start of every epoch, not only before the first one.
    model.train()

    # refresh the reference policy once per epoch
    model_ref.load_state_dict(model.state_dict())

    epoch_start_time = time.time()
    start_time = time.time()
    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        # get a batch of prompts and answers
        prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)

        # snapshot the current policy before it is updated on this batch
        model_old.load_state_dict(model.state_dict())

        prompts = prompts.to(device) # (prompt_length, batch_size)

        with torch.no_grad():
            # generate samples for each prompt
            outputs_old = generate(model_old,
                                prompts,
                                new_tokens = answers_length + 1,
                                mode = "sampling",
                                num_samples = num_samples,
                                temperature = temperature)
            # outputs.shape = (prompt_length + answers_length + 1, batch_size * num_samples)

            log_probs_old = compute_log_probs(model_old, outputs_old, prompt_length) # (answers_length + 1, batch_size * num_samples, vocab_size)
            log_probs_ref = compute_log_probs(model_ref, outputs_old, prompt_length) # (answers_length + 1, batch_size * num_samples, vocab_size)

        # one decoded string per sampled completion (batch_size * num_samples of them)
        text_outputs = [tokenizer.decode(outputs_old[prompt_length:, j].tolist())
                        for j in range(outputs_old.size(1))]

        responses_old = outputs_old[prompt_length:, :] # (answers_length + 1, batch_size * num_samples)

        # compute rewards
        rewards = compute_rewards(text_outputs, answers) # (batch_size * num_samples)

        # compute advantages
        advantages = calculate_grpo_advantages(rewards) # (batch_size * num_samples)

        # mu optimisation steps on the same batch of samples
        for it in range(mu):
            log_probs = compute_log_probs(model, outputs_old, prompt_length)

            loss = compute_loss_grpo(advantages, log_probs, log_probs_old, log_probs_ref, responses_old, beta, eps)

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if i % log_interval == 0 and batch > 0:
            elapsed = time.time() - start_time
            # elapsed is in seconds for the whole reporting window
            print('| {:5d}/{:5d} batches | ms/batch {:5.2f}'.format(batch, len(data_train) // batch_size,
                                                                    elapsed * 1000 / batches_per_report))
            if verbose:
                print("\nloss:", loss.item(),
                    "\nquestion:", questions[0],
                  "\nanswer", answers[0],
                  "\noutput:", text_outputs[:num_samples],
                  "\nreward:", rewards[:num_samples],
                  "\nadvantage:", advantages[:num_samples], "\n")

            start_time = time.time()

    test_accuracy = evaluate()
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
    print('-' * 89)
    # Save the model if the test accuracy is the best we've seen so far
    # (higher is better, unlike a validation loss).
    if best_test_accuracy is None or test_accuracy > best_test_accuracy:
        with open("arithmetic_GRPO.pt", 'wb') as f:
            torch.save(model, f)
        best_test_accuracy = test_accuracy

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.54
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch 84.51

loss: -1.7462298274040222e-08 
question: 297+142= 
answer 439 
output: ['1162[EOS]', '586[EOS][PAD]', '527[EOS][PAD]', '119[EOS][PAD]', '474[EOS][PAD]', '1107[EOS]', '1108[EOS]', '5999[EOS]', '329[EOS][PAD]', '308[EOS][PAD]', '538[EOS][PAD]', '421[EOS][PAD]', '6177[EOS]', '419[EOS]9', '380[EOS][PAD]', '1171[EOS]'] 
reward: tensor([0.2000, 0.2000, 0.2000, 0.4667, 0.4667, 0.2000, 0.2000, 0.4000, 0.4667,
        0.2000, 0.4667, 0.4667, 0.2000, 0.4000, 0.2000, 0.2000],
       device='cuda:0') 
advantage: tensor([-0.8428, -0.8428, -0.8428,  1.2318,  1.2318, -0.8428, -0.8428,  0.7132,
         1.2318, -0.8428,  1.2318,  1.2318, -0.8428,  0.7132, -0.8428, -0.8428],
       device='cuda:0') 

|  1200/ 3600 batches | ms/batch 84.43

loss: -4