# GRPO Training project: teach an LLM to do additions, again

In this notebook, you'll find:
* A basic Transformer with basic tokenizer
* A basic dataset for additions
* A classical pre-trainer, minimizing cross entropy loss
* A Vanilla GRPO

You're not supposed to edit the existing code (you can if you want to...).
You should implement one (or more) of the following:
* GRPO with PPO (the `usual` one)
* RLOO
* ReMax
* DPO
* RAFT
* your own RLHF method!

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [2]:
num_digits = 3

dataset_size = 64_000
train_proportion = 0.9

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Step 1: Construct a tokenizer

In [4]:
pad_token="[PAD]"
eos_token="[EOS]"

In [5]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"

    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [6]:
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens
ntokens

14

In [7]:
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

## Step 2: Create a dataset for arithmetic operations

In [8]:
def sample_datapoint(num_digits = 3):
    a_list = [random.randint(0, 9) for _ in range(num_digits)]
    b_list = [random.randint(0, 9) for _ in range(num_digits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    a_str = "".join([str(x) for x in a_list])
    b_str = "".join([str(x) for x in b_list])
    sum_int = a_int + b_int
    return (a_str + "+" + b_str + "=", str(sum_int))

sample_datapoint(3)

('629+685=', '1314')

In [9]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(num_digits))
data[:4]

[('473+618=', '1091'),
 ('965+073=', '1038'),
 ('606+123=', '729'),
 ('944+297=', '1241')]

In [10]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

In [11]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [None]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken) # projection sur le vocabulaire

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz))) # masque de l'attention

    def forward(self, src):
        # S = sequence lenght = len(src), B = batch size, E = embedding dimension = ninp, V = vocab size = ntoken
        # src: (S, B)
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask # (S, S)

        src = self.input_emb(src) * math.sqrt(self.ninp) # (S, B, E)
        src = self.pos_encoder(src)  # (S, B, E)
        output_enc = self.encoder(src, mask=self.src_mask) # (S, B, E) couches MultiHeadSelfAttention() pytorch
        output_dec = self.decoder(output_enc) # (S, B, V)
        return F.log_softmax(output_dec, dim=-1), output_enc # (S, B, V)

In [16]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)

TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [17]:
print("number of parameters: {}".format(sum([x.numel() for x in model.parameters()])))

number of parameters: 668942


### Useful functions

In [18]:
def generate(model, prompts, new_tokens = 5, mode = "greedy", num_samples = 1, temperature = 0.8):
    input_tensor = torch.repeat_interleave(prompts, repeats = num_samples, dim = 1).to(device) # (S, B*G) G= num generations per prompt
    # (prompt_length, batch_size * num_samples)
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (prompt_length, batch_size * num_samples, ntokens)
        logits = output[-1,:,:] # dernier token (batch_size * num_samples, ntokens)
        if mode == "greedy":
            tokens = torch.argmax(logits, -1).view((1,-1)) # (1, batch_size * num_samples)
        else: # mode == "sampling"
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            tokens = torch.multinomial(probs, num_samples = 1).view((1,-1)) # (1, batch_size * num_samples)
        input_tensor = torch.cat((input_tensor, tokens), 0)
    return input_tensor

In [20]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output[0].tolist())

(tensor([[ 2, 10,  3, 11,  6,  7,  6,  7,  7]]), '2+3=67677')

In [26]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list]) # taille commune des sequences
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x) # padding à gauche
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x))) # padding à droite
    return out, max_length

In [25]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

6
2


(['[PAD][PAD]1+1=', '21+35='], ['2[EOS][PAD]', '56[EOS]'])

In [27]:
def get_batch(split, i, batch_size):
    data = data_train if split == 'train' else data_test

    prompts = [data[i][0] for i in range(i, i + batch_size)]
    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    padded_prompts, prompt_length = pad(encoded_prompts, "prompts")

    answers = [data[i][1] for i in range(i, i + batch_size)]
    encoded_answers = [tokenizer.encode(answer) for answer in answers]
    padded_answers, answers_length = pad(encoded_answers, "answers")

    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, prompt_length, answers_length, prompts, answers

In [28]:
X, Y, prompt_length, answers_length, prompts, answers = get_batch("train", 43, 16)
X.shape, Y.shape, prompt_length, answers_length, prompts[0], answers[0]

(torch.Size([8, 16]), torch.Size([5, 16]), 8, 4, '037+338=', '375')

## Step 4: Evaluate

In [29]:
batch_size = 16

In [31]:
def evaluate(batch_size = batch_size):
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("test", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size) (P, B)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size) (A, B)
            output = generate(model, prompts, answers_length + 1) # (prompt_length + answers_length + 1, batch_size) (P+A, B)
            answers_tokens = output[prompt_length:, :] # (answers_length + 1, batch_size), contains tokens (A, B)
            equality_test = answers_tokens == target_answers # (answers_length + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [32]:
evaluate()

0.0

## Step 5: Train the model, classical approach

### Hyperparameters

In [33]:
epochs = 5
batch_size = 16
learning_rate = 8e-4

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

In [34]:
def train():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        total_loss = 0.
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size) (P, B)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size) (A, B)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size) (P+A, B)
            model.zero_grad()
            output, _ = model(input_tensor) # (prompt_length + answers_length + 1, batch_size, ntokens) # (P+A, B, V) probabilité sur le vocabulaire pour chaque token
            output_answers = output[prompt_length-1:-1,:,:].reshape(-1, ntokens) # ((answers_length + 1) * batch_size, ntokens) # (A*B, V) prob sur le vocab de la réponse
            target_answers = target_answers.view(-1) # (A*B)
            loss = F.cross_entropy(output_answers, target_answers)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % log_interval == 0 and batch > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                            elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [35]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  2.36 | loss  0.09 | perplexity     1.09
|  1200/ 3600 batches | ms/batch  2.41 | loss  0.07 | perplexity     1.08
|  1800/ 3600 batches | ms/batch  2.33 | loss  0.07 | perplexity     1.07
|  2400/ 3600 batches | ms/batch  2.33 | loss  0.07 | perplexity     1.07
|  3000/ 3600 batches | ms/batch  2.36 | loss  0.07 | perplexity     1.07
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 152.85s | test accuracy  0.01
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  2.39 | loss  0.07 | perplexity     1.07
|  1200/ 3600 batches | ms/batch  2.28 | loss  0.07 | perplexity     1.07
|  1800/ 3600 batches | ms

In [36]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

288+799=1087[EOS]	 actual result: 1087
714+051=765[EOS]	 actual result: 765
417+294=711[EOS]	 actual result: 711
869+641=1510[EOS]	 actual result: 1510
925+275=1100[EOS]	 actual result: 1200
532+127=659[EOS]	 actual result: 659
728+077=705[EOS]	 actual result: 805
195+420=615[EOS]	 actual result: 615
556+200=756[EOS]	 actual result: 756
830+397=1227[EOS]	 actual result: 1227
322+759=1081[EOS]	 actual result: 1081
163+111=274[EOS]	 actual result: 274
558+635=1193[EOS]	 actual result: 1193
315+416=721[EOS]	 actual result: 731
442+639=1081[EOS]	 actual result: 1081
088+161=249[EOS]	 actual result: 249
445+142=587[EOS]	 actual result: 587
771+557=1328[EOS]	 actual result: 1328
033+828=761[EOS]	 actual result: 861
179+053=222[EOS]	 actual result: 232


## Step 4 bis: Vanilla GRPO training

### Custom reward functions

In [37]:
def accuracy_reward(output, answer):
    # 1 si output=answer
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return 1. if output == answer else 0.

accuracy_reward("123[EOS][PAD][PAD]", "123"), accuracy_reward("123", "124"),

(1.0, 0.0)

In [39]:
def distance_accuracy_reward(output, answer):
    # ecart relatif entre output et answer
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    int_output = int(output)
    int_answer = int(answer)
    return abs(int_output - int_answer) / max(int_output, int_answer)

distance_accuracy_reward("123[EOS]", "123"), distance_accuracy_reward("123[PAD]", "124"),

(0.0, 0.008064516129032258)

In [40]:
def digit_accuracy_reward(output, answer):
    # ecart relatif entre les digits
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return sum(c1 == c2 for (c1,c2) in zip(output, answer)) / max(len(output), len(answer))

digit_accuracy_reward("123[EOS][PAD][PAD]", "123"), digit_accuracy_reward("123[EOS]", "123"),

(1.0, 1.0)

In [41]:
def reward_format(output):
    # 1 si format avec EOS
    pattern = r"\d+\[EOS\](\[PAD\])*$"
    return 1. if bool(re.match(pattern, output)) else 0.

reward_format("123[EOS][PAD][PAD]"), reward_format("123[EOS]"), reward_format("123"),

(1.0, 1.0, 0.0)

### Hyperparameters

In [None]:
epochs = 20
batch_size = 16 # B
learning_rate = 1e-4
num_samples = 16 # G
temperature = .8

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

reward_fun = digit_accuracy_reward
reward_format = reward_format

In [None]:
def compute_rewards(text_outputs, answers):
    # reward pondérée
    repeated_answers = [answer for answer in answers for _ in range(num_samples)]
    rewards = torch.tensor(
        [0.2 * reward_format(output) + 0.8 * reward_fun(output, answer)
         for output, answer in zip(text_outputs, repeated_answers)],
        dtype=torch.float32,
        device=device
    )
    return rewards # (B*G)

In [54]:
answers = ["1", "22", "3"]
[answer for answer in answers for _ in range(3)]

['1', '1', '1', '22', '22', '22', '3', '3', '3']

In [44]:
def calculate_grpo_advantages(rewards):
    # reshape rewards to group by prompt
    # compute mean and standard deviation for each prompt group
    # rewards (B*G,)
    mean_rewards = rewards.view(-1, num_samples).mean(dim=1) # (B,)
    std_rewards = rewards.view(-1, num_samples).std(dim=1) # (B,)
    # expand the means and stds to match the original flat rewards tensor shape
    mean_rewards = mean_rewards.repeat_interleave(num_samples, dim=0) # (B*G,)
    std_rewards = std_rewards.repeat_interleave(num_samples, dim=0) # (B*G,)
    # normalize rewards to get advantages
    advantages = (rewards - mean_rewards) / (std_rewards + 1e-5) # (B*G,)
    return advantages

In [None]:
def compute_log_probs(model, outputs, prompt_length):
    """ 
    Calcule les log-probabilités (sur tout le vocabulaire) des tokens des réponses.
    """
    logits, _ = model(outputs) #(P+A, B*G, V)
    # logits.shape = (prompt_length + answers_length + 1, batch_size * num_samples, vocab_size)

    # we only need the log probabilities for the new tokens
    # this introduces a shift: the logits for a position are the predictions for the next token
    logits = logits[prompt_length-1:-1, :, :] #(A, B*G, V)
    # logits.shape = (answers_length + 1, batch_size * num_samples, vocab_size)

    # convert raw logits into log probabilities along the vocabulary axis
    log_probs = F.log_softmax(logits, dim=-1) #(A, B*G, V)
    # log_probs.shape = (answers_length + 1, batch_size * num_samples, vocab_size)
    return log_probs #(A, B*G, V)

In [None]:
def compute_loss(advantages, log_probs, responses):
    """ 
    Calcule la loss d'un batch
    """
    # reshape responses from (answers_length + 1, batch_size * num_samples)
    # to (answers_length + 1, batch_size * num_samples, 1) for gathering
    responses = responses.unsqueeze(-1) #(A, B*G, 1)
    # log_probs.shape = (answers_length + 1, batch_size * num_samples, vocab_size) (A, B*G, V)
    # responses.shape = (answers_length + 1, batch_size * num_samples) 
    
    # gather the log probability for each token in responses
    selected_log_probs = log_probs.gather(dim=-1, index=responses)
    # remove the extra last dimension to get back to shape (answers_length + 1, batch_size * num_samples).
    selected_log_probs = selected_log_probs.squeeze(-1) # (A, B*G)

    # normalize
    selected_log_probs = (selected_log_probs - selected_log_probs.mean(-1, keepdim=True)) / (selected_log_probs.std(-1, keepdim=True) + 1e-5) # (A, B*G)

    # advantages.shape = (batch_size * num_samples) (B*G)
    # we use the same advantages for all tokens in the response
    loss = -(advantages.unsqueeze(dim=0) * selected_log_probs).mean()
    return loss

In [58]:
### MON CODE

def compute_loss_DeepSeek(advantages, responses, answers_length, log_probs, log_probs_old, log_probs_ref):
    """ 
    Calcule la loss GRPO (DeepSeek) d'un batch, en utilisant les politiques du modèle, de model_ref et model_old.
    Args:
        advantages: (B*G) avantage de chaque séquence générée
        responses: (A, B*G) token générés par le modèle old pour chaque séquence générée
        answers_length: int longueur des réponses générées (complétée par du padding)
        log_probs: (A, B*G, V) log-probabilités des tokens du vocabulaire, pour chaque étape de la génération par le modèle (mis à jour plusieurs fois par batch)
        log_probs_old: (A, B*G, V) même chose pour le modèle old (mis à jour une fois par batch)
        log_probs_ref: (A, B*G, V) même chose pour le modèle de réference (mis à jour une fois par époque)

    """
    A, B, G = answers_length+1, batch_size, num_samples # dimensions
    eps = 0.2 # clipping coefficient
    beta = 0.4 # KL coefficient

    responses = responses.unsqueeze(-1) # (A, B*G) -> (A, B*G, 1) pour gathering

    # log_probs dimensions: (A, B*G, V)

    # Récupère et reformatte les log-probs des tokens générés par le modèle
    selected_log_probs = log_probs.gather(dim=-1, index=responses) # selectionne uniquement les probabilités des tokens générés
    selected_log_probs = selected_log_probs.squeeze(-1) # (A, B*G)
    selected_log_probs = selected_log_probs.view(A, B, G).permute(1, 2, 0) # reformatte en (B, G, A) pour le calcul de la loss

    # Récupère et reformatte les log-probs des tokens générés par le model_old
    selected_log_probs_old = log_probs_old.gather(dim=-1, index=responses)
    selected_log_probs_old = selected_log_probs_old.squeeze(-1) # (A, B*G)
    selected_log_probs_old = selected_log_probs_old.view(A, B, G).permute(1, 2, 0) # (A, B*G) -> (B, G, A) pour uniformiser les log_probs

    # Récupère et reformatte les log-probs des tokens générés par le model_ref
    selected_log_probs_ref = log_probs_ref.gather(dim=-1, index=responses)
    selected_log_probs_ref = selected_log_probs_ref.squeeze(-1) # (A, B*G)
    selected_log_probs_ref = selected_log_probs_ref.view(A, B, G).permute(1, 2, 0) # reformatte en (B, G, A) pour le calcul de la loss

    # Expension des avantages pour avoir le même format que les politiques
    advantages = advantages.view(B, G) # (B*G)  -> (B, G)
    advantages = advantages.unsqueeze(-1).expand(-1, -1, A) # (B, G) -> (B, G, A) on étend les avantages sur les token pour les mutliplier avec les ratios

    # Calcul des ratios de politiques entre model et model_old sur les token générés
    ratios_old = torch.exp(selected_log_probs - selected_log_probs_old) # (B, G, A) rapport des probabilités sur les tokens générés entre model et model_old
    clipped_ratios = torch.clamp(ratios_old, 1 - eps, 1 + eps) # (B, G, A) clipped ratio

    coeff_1 = ratios_old * advantages # (B, G, A)
    coeff_2 = clipped_ratios * advantages # (B, G, A)

    # Calcul de la KL divergence entre la policy du modèle et celle de model_ref
    KL = torch.exp(selected_log_probs_ref - selected_log_probs) - (selected_log_probs_ref - selected_log_probs) -1 # (B, G, A) Estimateur différentiable de la KL
    
    loss_tensor = torch.minimum(coeff_1, coeff_2) - beta*KL # (B, G, A)
    loss = - loss_tensor.mean(dim=-1).mean(dim=-1).sum() #moyenne selon A puis G puis on fait la somme des moyennes sur G

    return loss
  
    

In [104]:
import copy

class GRPOTrainer:
    def __init__(self, model, ):

        self.model = model
        self.ref_model = None # reference model
        self.learning_rate = 1e-4
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        self.epochs = 20
        self.batch_size = 16 # B
        self.num_samples = 5 # G
        self.temperature = .8
        self.num_iterations=3

    def calculate_grpo_advantages(self, rewards):
        # rewards (B*G,)
        mean_rewards = rewards.view(-1, self.num_samples).mean(dim=1) # (B,)
        std_rewards = rewards.view(-1, self.num_samples).std(dim=1) # (B,)
        # expand the means and stds to match the original flat rewards tensor shape
        mean_rewards = mean_rewards.repeat_interleave(self.num_samples, dim=0) # (B*G,)
        std_rewards = std_rewards.repeat_interleave(self.num_samples, dim=0) # (B*G,)
        # normalize rewards to get advantages
        advantages = (rewards - mean_rewards) / (std_rewards + 1e-5) # (B*G,)
        return advantages
    
    def compute_log_probs(self, model, outputs, prompt_length):
        """ 
        Calcule les log-probabilités (sur tout le vocabulaire) des tokens des réponses.
        """
        logits, _ = model(outputs) #(P+A, B*G, V)

        # we only need the log probabilities for the new tokens
        # this introduces a shift: the logits for a position are the predictions for the next token
        logits = logits[prompt_length-1:-1, :, :] #(A, B*G, V)

        # convert raw logits into log probabilities along the vocabulary axis
        log_probs = F.log_softmax(logits, dim=-1) #(A, B*G, V)
        return log_probs #(A, B*G, V)

    def compute_loss_deepseek(self, advantages, responses, answers_length, log_probs, log_probs_old, log_probs_ref):
        """ 
        Calcule la loss GRPO (DeepSeek) d'un batch, en utilisant les politiques du modèle, de model_ref et model_old.
        Args:
            advantages: (B*G) avantage de chaque séquence générée
            responses: (A, B*G) token générés par le modèle old pour chaque séquence générée
            answers_length: int longueur des réponses générées (complétée par du padding)
            log_probs: (A, B*G, V) log-probabilités des tokens du vocabulaire, pour chaque étape de la génération par le modèle (mis à jour plusieurs fois par batch)
            log_probs_old: (A, B*G, V) même chose pour le modèle old (mis à jour une fois par batch)
            log_probs_ref: (A, B*G, V) même chose pour le modèle de réference (mis à jour une fois par époque)

        """
        A, B, G = answers_length+1, self.batch_size, self.num_samples # dimensions
        eps = 0.2 # clipping coefficient
        beta = 0.4 # KL coefficient

        responses = responses.unsqueeze(-1) # (A, B*G) -> (A, B*G, 1) pour gathering

        # log_probs dimensions: (A, B*G, V)

        # Récupère et reformatte les log-probs des tokens générés par le modèle
        selected_log_probs = log_probs.gather(dim=-1, index=responses) # selectionne uniquement les probabilités des tokens générés
        selected_log_probs = selected_log_probs.squeeze(-1) # (A, B*G)
        selected_log_probs = selected_log_probs.view(A, B, G).permute(1, 2, 0) # reformatte en (B, G, A) pour le calcul de la loss

        # Récupère et reformatte les log-probs des tokens générés par le model_old
        selected_log_probs_old = log_probs_old.gather(dim=-1, index=responses)
        selected_log_probs_old = selected_log_probs_old.squeeze(-1) # (A, B*G)
        selected_log_probs_old = selected_log_probs_old.view(A, B, G).permute(1, 2, 0) # (A, B*G) -> (B, G, A) pour uniformiser les log_probs

        # Récupère et reformatte les log-probs des tokens générés par le model_ref
        selected_log_probs_ref = log_probs_ref.gather(dim=-1, index=responses)
        selected_log_probs_ref = selected_log_probs_ref.squeeze(-1) # (A, B*G)
        selected_log_probs_ref = selected_log_probs_ref.view(A, B, G).permute(1, 2, 0) # reformatte en (B, G, A) pour le calcul de la loss

        # Expension des avantages pour avoir le même format que les politiques
        advantages = advantages.view(B, G) # (B*G)  -> (B, G)
        advantages = advantages.unsqueeze(-1).expand(-1, -1, A) # (B, G) -> (B, G, A) on étend les avantages sur les token pour les mutliplier avec les ratios

        # Calcul des ratios de politiques entre model et model_old sur les token générés
        ratios_old = torch.exp(selected_log_probs - selected_log_probs_old) # (B, G, A) rapport des probabilités sur les tokens générés entre model et model_old
        clipped_ratios = torch.clamp(ratios_old, 1 - eps, 1 + eps) # (B, G, A) clipped ratio

        coeff_1 = ratios_old * advantages # (B, G, A)
        coeff_2 = clipped_ratios * advantages # (B, G, A)

        # Calcul de la KL divergence entre la policy du modèle et celle de model_ref
        KL = torch.exp(selected_log_probs_ref - selected_log_probs) - (selected_log_probs_ref - selected_log_probs) -1 # (B, G, A) Estimateur différentiable de la KL
        
        loss_tensor = torch.minimum(coeff_1, coeff_2) - beta*KL # (B, G, A)
        loss = - loss_tensor.mean(dim=-1).mean(dim=-1).sum() #moyenne selon A puis G puis on fait la somme des moyennes sur G

        return loss

    def train_step(self, prompts, prompt_length, answers_length, questions, answers):
        """
            Do a traing step with multiple iterations on a specific batch.
            Args:
                prompts: prompts to optimize the model on it
        """
        
        self.model.train()

        # generate samples for each prompt in inference_mode as, it is just a sampling step
        #with torch.inference_mode():
        outputs = generate(self.model,
                            prompts,
                            new_tokens = answers_length + 1,
                            mode = "sampling",
                            num_samples = self.num_samples,
                            temperature = temperature) # (P+A, B*G)
        
        responses = outputs[prompt_length:, :] # (A, B*G)

        text_outputs = [tokenizer.decode(outputs[prompt_length:, i].tolist())
                        for i in range(outputs.size(1))]

        # Calcul de log_probs_old log_probs_ref sans backpropagation pour la loss
        self.model.eval()
        self.model_ref.eval()
        with torch.inference_mode():
            log_probs_old = self.compute_log_probs(self.model, outputs, prompt_length) # (A, B*G, V)
            log_probs_ref = self.compute_log_probs(self.model_ref, outputs, prompt_length) # (A, B*G, V)

        # Calcul des avantages
        rewards = compute_rewards(text_outputs, answers) # (B*G)
        advantages = self.calculate_grpo_advantages(rewards) # (B*G)

        for i in range(self.num_iterations): # num_optimizations steps of the model on the same batch
            #print(f"GRPO iteration {i}")
            self.optimizer.zero_grad()  

            # Calcul de log_probs avec le modèle à jour
            log_probs = self.compute_log_probs(self.model, outputs, prompt_length) # (A, B*G, V)
                
            # compute loss
            loss = self.compute_loss_deepseek(advantages, responses, answers_length, log_probs, log_probs_old, log_probs_ref)

            #print(f"Loss iteration {i}, {loss}")

            # optimize
            loss.backward()
            self.optimizer.step()

        return loss.item()
    
    def train(self, verbose = False):

        print(f"{(len(data_train) - 1)//self.batch_size} batches to go!")

        best_test_accuracy = None
        test_accuracy = evaluate()
        print('-' * 89)
        print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
        print('-' * 89)

        # switch eval for train model (enables dropout)
        self.model_ref = copy.deepcopy(model)
        epochs = self.epochs

        for epoch in range(1, epochs+1):
            epoch_start_time = time.time()
            start_time = time.time()

            total_loss=0
            for batch, i in enumerate(range(0, len(data_train) - 1, self.batch_size)):

                # get a batch of prompts and answers
                prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, self.batch_size)
                prompts = prompts.to(device) # (prompt_length, batch_size)

                # perform train_step
                loss = self.train_step(prompts, prompt_length, answers_length, questions, answers)
                total_loss += loss
                print(f"Batch {batch} loss : {loss}")
            
            # Update ref_model
            self.model_ref = copy.deepcopy(model)

            avg_loss = total_loss / len(data_train)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")            

            test_accuracy = evaluate()
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
            print('-' * 89)


Trainer =  GRPOTrainer(model)
Trainer.train()
          

3599 batches to go!
-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.01
-----------------------------------------------------------------------------------------
Batch 0 loss : -0.511246919631958
Batch 1 loss : -0.23518361151218414
Batch 2 loss : -0.16449134051799774
Batch 3 loss : 0.031492091715335846
Batch 4 loss : -0.04103913903236389
Batch 5 loss : 0.23252439498901367
Batch 6 loss : 0.22064754366874695
Batch 7 loss : 0.15894721448421478
Batch 8 loss : 0.08772365003824234
Batch 9 loss : 0.14583761990070343
Batch 10 loss : 0.3208909034729004
Batch 11 loss : 0.3034358024597168
Batch 12 loss : 0.23609769344329834
Batch 13 loss : 0.27937084436416626
Batch 14 loss : 0.17302371561527252
Batch 15 loss : 0.30169597268104553
Batch 16 loss : 0.1210220530629158
Batch 17 loss : 0.31678953766822815
Batch 18 loss : 0.38117867708206177
Batch 19 loss : 0.2225230634212494
Batch 20 loss : 0.3279609680175781
Batch 21 loss : 0

KeyboardInterrupt: 

In [87]:
Trainer =  GRPOTrainer(model)
Trainer.train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.85
-----------------------------------------------------------------------------------------
Batch 0
torch.Size([5, 40, 14]) torch.Size([5, 40, 14])
GRPO iteration 0
 log_probs, log_probs_old, log_probs_ref torch.Size([5, 40, 14]) torch.Size([5, 40, 14]) torch.Size([5, 40, 14])
A, B, G 5 16 16 torch.Size([5, 40])


RuntimeError: shape '[5, 16, 16]' is invalid for input of size 200

In [70]:
train_DeepSeek_GRPO(verbose = False)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.85
-----------------------------------------------------------------------------------------
GRPO iteration 0
GRPO iteration 1


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 14]], which is output 0 of AsStridedBackward0, is at version 36014; expected version 36012 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [None]:
def train_vanilla_GRPO(verbose = False):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)

    # switch eval for train model (enables dropout)
    model.train()

    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):

            # get a batch of prompts and answers
            prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)

            # generate samples for each prompt
            outputs = generate(model,
                               prompts,
                               new_tokens = answers_length + 1,
                               mode = "sampling",
                               num_samples = num_samples,
                               temperature = temperature) # (P+A, B*G)
            # outputs.shape = (prompt_length + answers_length + 1, batch_size * num_samples)
            text_outputs = [tokenizer.decode(outputs[prompt_length:, i].tolist())
                            for i in range(outputs.size(1))]

            # compute rewards
            rewards = compute_rewards(text_outputs, answers) # (B*G)

            # compute advantages
            advantages = calculate_grpo_advantages(rewards) # (B*G)

            # compute log probabilities pour tous les tokens des réponses
            log_probs = compute_log_probs(model, outputs, prompt_length) # (A, B*G, V)

            # compute loss
            responses = outputs[prompt_length:, :] # (A, B*G)
            loss = compute_loss(advantages, log_probs, responses)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % log_interval == 0 and batch > 0:
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f}'.format(batch, len(data_train) // batch_size, elapsed))
                if verbose:
                    print("\nquestion:", questions[0],
                      "\nanswer", answers[0],
                      "\noutput:", text_outputs[:num_samples],
                      "\nreward:", rewards[:num_samples],
                      "\nadvantage:", advantages[:num_samples], "\n")

                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic_vanilla_GRPO.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy


In [48]:
train_vanilla_GRPO(verbose = False)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.85
-----------------------------------------------------------------------------------------


NameError: name 'compute_loss' is not defined

In [353]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

128+705=833[EOS]	 actual result: 833
953+527=1480[EOS]	 actual result: 1480
262+889=1151[EOS]	 actual result: 1151
222+843=1065[EOS]	 actual result: 1065
777+632=1409[EOS]	 actual result: 1409
719+662=1381[EOS]	 actual result: 1381
789+402=1191[EOS]	 actual result: 1191
587+869=1456[EOS]	 actual result: 1456
386+632=1018[EOS]	 actual result: 1018
038+837=875[EOS]	 actual result: 875
200+272=472[EOS]	 actual result: 472
680+684=1364[EOS]	 actual result: 1364
594+534=1128[EOS]	 actual result: 1128
350+593=943[EOS]	 actual result: 943
867+755=1622[EOS]	 actual result: 1622
978+223=1201[EOS]	 actual result: 1201
571+512=1083[EOS]	 actual result: 1083
148+465=613[EOS]	 actual result: 613
873+373=1246[EOS]	 actual result: 1246
700+356=1056[EOS]	 actual result: 1056
