# GRPO Training project: teach an LLM to do additions, again

In this notebook, you'll find:
* A basic Transformer with basic tokenizer
* A basic dataset for additions
* A classical pre-trainer, minimizing cross entropy loss
* A Vanilla GRPO

You're not supposed to edit the existing code (you can if you want to...).
You should implement one (or more) of the following:
* GRPO with PPO (the `usual` one)
* RLOO
* ReMax
* DPO
* RAFT
* your own RLHF method!

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [2]:
num_digits = 3

dataset_size = 64_000
train_proportion = 0.9

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Step 1: Construct a tokenizer

In [4]:
pad_token="[PAD]"
eos_token="[EOS]"

In [7]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"

    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [8]:
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens
ntokens

14

In [9]:
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

## Step 2: Create a dataset for arithmetic operations

In [10]:
def sample_datapoint(num_digits = 3):
    a_list = [random.randint(0, 9) for _ in range(num_digits)]
    b_list = [random.randint(0, 9) for _ in range(num_digits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    a_str = "".join([str(x) for x in a_list])
    b_str = "".join([str(x) for x in b_list])
    sum_int = a_int + b_int
    return (a_str + "+" + b_str + "=", str(sum_int))

sample_datapoint(3)

('616+989=', '1605')

In [11]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(num_digits))
data[:4]

[('681+848=', '1529'),
 ('034+937=', '971'),
 ('912+897=', '1809'),
 ('222+748=', '970')]

In [12]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

In [13]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [14]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

In [15]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [16]:
print("number of parameters: {}".format(sum([x.numel() for x in model.parameters()])))

number of parameters: 668942


### Useful functions

In [17]:
def generate(model, prompts, new_tokens = 5, mode = "greedy", num_samples = 1, temperature = 0.8):
    input_tensor = torch.repeat_interleave(prompts, repeats = num_samples, dim = 1).to(device)
    # (prompt_length, batch_size * num_samples)
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (prompt_length, batch_size * num_samples, ntokens)
        logits = output[-1,:,:] # (batch_size * num_samples, ntokens)
        if mode == "greedy":
            tokens = torch.argmax(logits, -1).view((1,-1)) # (1, batch_size * num_samples)
        else: # mode == "sampling"
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            tokens = torch.multinomial(probs, num_samples = 1).view((1,-1)) # (1, batch_size * num_samples)
        input_tensor = torch.cat((input_tensor, tokens), 0)
    return input_tensor

In [18]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output[0].tolist())

(tensor([[ 2, 10,  3, 11, 12, 10, 10, 12, 10]], device='cuda:0'),
 '2+3=[PAD]++[PAD]+')

In [19]:
prompt_tensor

tensor([[ 2],
        [10],
        [ 3],
        [11]])

In [20]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x)))
    return out, max_length

In [21]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['[PAD][PAD]1+1=', '21+35='], ['2[EOS][PAD]', '56[EOS]'])

In [22]:
def get_batch(split, i, batch_size):
    data = data_train if split == 'train' else data_test

    prompts = [data[i][0] for i in range(i, i + batch_size)]
    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    padded_prompts, prompt_length = pad(encoded_prompts, "prompts")

    answers = [data[i][1] for i in range(i, i + batch_size)]
    encoded_answers = [tokenizer.encode(answer) for answer in answers]
    padded_answers, answers_length = pad(encoded_answers, "answers")

    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, prompt_length, answers_length, prompts, answers

In [23]:
X, Y, prompt_length, answers_length, prompts, answers = get_batch("train", 43, 16)
X.shape, Y.shape, prompt_length, answers_length, prompts[0], answers[0]

(torch.Size([8, 16]), torch.Size([5, 16]), 8, 4, '661+105=', '766')

## Step 4: Evaluate

In [24]:
batch_size = 16

In [25]:
def evaluate(batch_size = batch_size):
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("test", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            output = generate(model, prompts, answers_length + 1) # (prompt_length + answers_length + 1, batch_size)
            answers_tokens = output[prompt_length:, :] # (answers_length + 1, batch_size), contains tokens
            equality_test = answers_tokens == target_answers # (answers_length + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [26]:
evaluate()

0.0

## Step 5: Train the model, classical approach

### Hyperparameters

In [27]:
epochs = 5
batch_size = 16
learning_rate = 8e-4

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

In [28]:
def train():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        total_loss = 0.
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size)
            model.zero_grad()
            output, _ = model(input_tensor) # (prompt_length + answers_length + 1, batch_size, ntokens)
            output_answers = output[prompt_length-1:-1,:,:].reshape(-1, ntokens) # ((answers_length + 1) * batch_size, ntokens)
            target_answers = target_answers.view(-1)
            loss = F.cross_entropy(output_answers, target_answers)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % log_interval == 0 and batch > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                            elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [29]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.92 | loss  0.09 | perplexity     1.10
|  1200/ 3600 batches | ms/batch  1.86 | loss  0.07 | perplexity     1.08
|  1800/ 3600 batches | ms/batch  1.81 | loss  0.07 | perplexity     1.07
|  2400/ 3600 batches | ms/batch  1.80 | loss  0.07 | perplexity     1.07
|  3000/ 3600 batches | ms/batch  1.77 | loss  0.07 | perplexity     1.07
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 122.32s | test accuracy  0.01
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.78 | loss  0.07 | perplexity     1.07
|  1200/ 3600 batches | ms/batch  1.75 | loss  0.06 | perplexity     1.07
|  1800/ 3600 batches | ms

In [30]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

540+644=1184[EOS]	 actual result: 1184
687+185=872[EOS]	 actual result: 872
504+792=1295[EOS]	 actual result: 1296
887+025=912[EOS]	 actual result: 912
200+564=775[EOS]	 actual result: 764
256+441=697[EOS]	 actual result: 697
660+030=681[EOS]	 actual result: 690
796+187=983[EOS]	 actual result: 983
599+686=1284[EOS]	 actual result: 1285
507+187=684[EOS]	 actual result: 694
163+583=749[EOS]	 actual result: 746
224+304=528[EOS]	 actual result: 528
790+006=899[EOS]	 actual result: 796
268+223=491[EOS]	 actual result: 491
151+191=343[EOS]	 actual result: 342
559+183=742[EOS]	 actual result: 742
138+947=1084[EOS]	 actual result: 1085
225+967=1191[EOS]	 actual result: 1192
689+042=739[EOS]	 actual result: 731
809+127=935[EOS]	 actual result: 936


## Step 4 bis: Vanilla GRPO training

### Custom reward functions

In [31]:
def accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return 1. if output == answer else 0.

accuracy_reward("123[EOS][PAD][PAD]", "123"), accuracy_reward("123", "124"),

(1.0, 0.0)

In [214]:
def distance_accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])"
    output = re.sub(pattern, "", output)
    if len(output) > 0:
        int_output = int(output)
    else:
        return 1
    int_answer = int(answer)
    return abs(int_output - int_answer) / max(int_output, int_answer)

distance_accuracy_reward("182[PAD]1", "123"), distance_accuracy_reward("123[PAD]", "124"),

(0.9324546952224053, 0.008064516129032258)

In [211]:
def digit_accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return sum(c1 == c2 for (c1,c2) in zip(output, answer)) / max(len(output), len(answer))

digit_accuracy_reward("123[EOS][PAD][PAD]", "123"), digit_accuracy_reward("123[EOS]", "123"),

(1.0, 1.0)

In [34]:
def reward_format(output):
    pattern = r"\d+\[EOS\](\[PAD\])*$"
    return 1. if bool(re.match(pattern, output)) else 0.

reward_format("123[EOS][PAD][PAD]"), reward_format("123[EOS]"), reward_format("123"),

(1.0, 1.0, 0.0)

### Hyperparameters

In [35]:
epochs = 20
batch_size = 16
learning_rate = 1e-4
num_samples = 16
temperature = .8

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

reward_fun = digit_accuracy_reward
reward_format = reward_format

In [36]:
def compute_rewards(text_outputs, answers):
    repeated_answers = [answer for answer in answers for _ in range(num_samples)]
    rewards = torch.tensor(
        [0.2 * reward_format(output) + 0.8 * reward_fun(output, answer)
         for output, answer in zip(text_outputs, repeated_answers)],
        dtype=torch.float32,
        device=device
    )
    return rewards

In [37]:
def calculate_grpo_advantages(rewards):
    # reshape rewards to group by prompt
    # compute mean and standard deviation for each prompt group
    mean_rewards = rewards.view(-1, num_samples).mean(dim=1)
    std_rewards = rewards.view(-1, num_samples).std(dim=1)
    # expand the means and stds to match the original flat rewards tensor shape
    mean_rewards = mean_rewards.repeat_interleave(num_samples, dim=0)
    std_rewards = std_rewards.repeat_interleave(num_samples, dim=0)
    # normalize rewards to get advantages
    advantages = (rewards - mean_rewards) / (std_rewards + 1e-5)
    return advantages

In [38]:
def compute_log_probs(model, outputs, prompt_length):
    logits, _ = model(outputs)
    # logits.shape = (prompt_length + answers_length + 1, batch_size * num_samples, vocab_size)

    # we only need the log probabilities for the new tokens
    # this introduces a shift: the logits for a position are the predictions for the next token
    logits = logits[prompt_length-1:-1, :, :]
    # logits.shape = (answers_length + 1, batch_size * num_samples, vocab_size)

    # convert raw logits into log probabilities along the vocabulary axis
    log_probs = F.log_softmax(logits, dim=-1)
    # log_probs.shape = (answers_length + 1, batch_size * num_samples, vocab_size)
    return log_probs

In [39]:
def compute_loss(advantages, log_probs, responses):
    # reshape responses from (answers_length + 1, batch_size * num_samples)
    # to (answers_length + 1, batch_size * num_samples, 1) for gathering
    responses = responses.unsqueeze(-1)
    # log_probs.shape = (answers_length + 1, batch_size * num_samples, vocab_size)
    # responses.shape = (answers_length + 1, batch_size * num_samples)
    # gather the log probability for each token in responses
    selected_log_probs = log_probs.gather(dim=-1, index=responses)
    # remove the extra last dimension to get back to shape (answers_length + 1, batch_size * num_samples).
    selected_log_probs = selected_log_probs.squeeze(-1)

    # normalize
    selected_log_probs = (selected_log_probs - selected_log_probs.mean(-1, keepdim=True)) / (selected_log_probs.std(-1, keepdim=True) + 1e-5)

    # advantages.shape = (batch_size * num_samples)
    # we use the same advantages for all tokens in the response
    loss = -(advantages.unsqueeze(dim=0) * selected_log_probs).mean()
    return loss

In [40]:
def train_vanilla_GRPO(verbose = False):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)

    # switch eval for train model (enables dropout)
    model.train()

    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):

            # get a batch of prompts and answers
            prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)

            # generate samples for each prompt
            outputs = generate(model,
                               prompts,
                               new_tokens = answers_length + 1,
                               mode = "sampling",
                               num_samples = num_samples,
                               temperature = temperature)
            # outputs.shape = (prompt_length + answers_length + 1, batch_size * num_samples)
            text_outputs = [tokenizer.decode(outputs[prompt_length:, i].tolist())
                            for i in range(outputs.size(1))]

            # compute rewards
            rewards = compute_rewards(text_outputs, answers)

            # compute advantages
            advantages = calculate_grpo_advantages(rewards)

            # compute log probabilities
            log_probs = compute_log_probs(model, outputs, prompt_length)
            # compute loss
            responses = outputs[prompt_length:, :]
            loss = compute_loss(advantages, log_probs, responses)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % log_interval == 0 and batch > 0:
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f}'.format(batch, len(data_train) // batch_size, elapsed))
                if verbose:
                    print("\nquestion:", questions[0],
                      "\nanswer", answers[0],
                      "\noutput:", text_outputs[:num_samples],
                      "\nreward:", rewards[:num_samples],
                      "\nadvantage:", advantages[:num_samples], "\n")

                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic_vanilla_GRPO.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy


In [127]:
train_vanilla_GRPO(verbose = False)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.97
-----------------------------------------------------------------------------------------


In [39]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

596+042=638[EOS]	 actual result: 638
361+026=387[EOS]	 actual result: 387
476+468=944[EOS]	 actual result: 944
183+460=643[EOS]	 actual result: 643
181+472=653[EOS]	 actual result: 653
197+314=511[EOS]	 actual result: 511
107+344=451[EOS]	 actual result: 451
463+717=1180[EOS]	 actual result: 1180
853+302=1155[EOS]	 actual result: 1155
733+425=1158[EOS]	 actual result: 1158
666+642=1308[EOS]	 actual result: 1308
429+503=932[EOS]	 actual result: 932
669+763=1432[EOS]	 actual result: 1432
099+807=906[EOS]	 actual result: 906
693+680=1373[EOS]	 actual result: 1373
969+035=1004[EOS]	 actual result: 1004
266+340=606[EOS]	 actual result: 606
349+083=432[EOS]	 actual result: 432
279+313=592[EOS]	 actual result: 592
804+300=1104[EOS]	 actual result: 1104


## DPO (Direct Preference Optimization)
The goal is to implement the RL training framework as in [Direct Preference Optimization: 
Your Language Model is Secretly a Reward Mode](https://arxiv.org/pdf/2305.18290).


The reason why I've chosen this method is because it appeared to me to be the most "original" one, in the sense that we don't learn directly a reward model upon which we will learn a policy. But we will rather learns a policy that implicitly learn to maximize this reward.g

In [41]:
epochs = 2
batch_size = 16
learning_rate = 1e-4
num_samples = 16
temperature = .8
epsilon = 0.2
reporting_per_epoch = 5
n_trajectoris = 10000
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

reward_fun = digit_accuracy_reward
reward_format = reward_format

In [330]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)
with open("arithmetic.pt", 'rb') as f:
    model = torch.load(f, weights_only = False)


### 1. Generate a dataset of pairs

We'll generate a dataset, where a sample = [prompt, answer 1, answer2].

The answer 1 is better than answer 2 according to the function compute_rewards.

In [183]:
def compute_rewards_dpo(text_outputs, answers):
    repeated_answers = [answer for answer in answers for _ in range(num_samples)]
    rewards = torch.tensor(
        [distance_accuracy_reward(output, answer)
         for output, answer in zip(text_outputs, repeated_answers)],
        dtype=torch.float32,
        device=device
    )
    return rewards

In [216]:
from tqdm import tqdm
def generate_dataset(model):
    dataset = []
    for batch, i in tqdm(enumerate(range(0, len(data_train) - 1, batch_size))):
        prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)
        prompts = prompts.to(device) # prompt_length, batch_size
        outputs_ref = generate(model,
                               prompts,
                               new_tokens = answers_length + 1,
                               mode = "sampling",
                               num_samples = 2,
                               temperature = temperature)
        text_outputs_ref = [tokenizer.decode(outputs_ref[prompt_length:, i].tolist())
                            for i in range(outputs_ref.size(1))]
        rewards = compute_rewards_dpo(text_outputs_ref, answers) 

        j = 0
        for i in range(0, len(text_outputs_ref), 2):
            if rewards[i] > rewards[i + 1]: # This is not really a reward, more like a score.
                left = outputs_ref[prompt_length:, i + 1]
                right = outputs_ref[prompt_length:, i ]
            else:
                left = outputs_ref[prompt_length:, i ]
                right = outputs_ref[prompt_length:, i + 1]
            dataset.append([prompts[:, j], left, right])
            j += 1
    return dataset

In [217]:
dataset_dpo = generate_dataset(model)

3600it [03:50, 15.64it/s]


In [218]:
i = random.randint(0, len(dataset_dpo))
print(tokenizer.decode(dataset_dpo[i ][0].tolist()))
print(tokenizer.decode(dataset_dpo[i ][1].tolist()))
print(tokenizer.decode(dataset_dpo[i ][2].tolist()))

473+484=
960[EOS][PAD]
955[EOS][PAD]


### 2. Compute the loss for DPO

The loss is defined only in terms of policy (ie, transformers models).

Given:
* a policy of reference $ \pi_{ref} $
* a policy $ \pi_{\theta} $ that we learn.
* a dataset $\cal{D}(x, y_b, y_w) $

We can formulate a maximum likelihood objective for $ \pi_{\theta} $ as :
$\mathcal{L}_{\text{DPO}}(\pi_{\theta}; \pi_{\text{ref}}) =
- \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[
\log \sigma \left( \beta \log \frac{\pi_{\theta}(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_{\theta}(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right)
\right]$ght]$ht]$ri$$]




In [219]:
def compute_loss_dpo(log_probs, log_probs_ref, beta = 1):
    return -torch.mean(torch.log(torch.sigmoid(beta * (log_probs[:, 0] - log_probs_ref[: , 0]) - beta * (log_probs[:, 1] - log_probs_ref[:, 1]))))

### 3. Train

In [220]:
import copy 
def train_DPO(verbose = False):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    
    # switch eval for train model (enables dropout)
    model.train()
    j = 0
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        start_time = time.time()
        for i in tqdm(range(0, len(dataset_dpo), batch_size)):
            prompts = torch.tensor([], device = device, dtype = torch.long)
            answers1 = torch.tensor([], device = device, dtype = torch.long)
            answers2 = torch.tensor([], device = device, dtype = torch.long)
            for j in range(i, i + batch_size):
                prompts = torch.cat((prompts, dataset_dpo[j][0].unsqueeze(1)), dim = 1)

                answers1 = torch.cat((answers1, dataset_dpo[j][1].unsqueeze(1)), dim = 1)
                answers2 = torch.cat((answers2, dataset_dpo[j][2].unsqueeze(1)), dim = 1)
            prompts = prompts.view(-1, batch_size)
            answers1 = answers1.view(-1, batch_size)
            answers2 = answers2.view(-1, batch_size)
            # get a batch of prompts and answers
            input_tensor = torch.repeat_interleave(prompts, repeats = 2, dim = 1).to(device)
            outputs = generate(model,
                               prompts,
                               new_tokens = answers_length + 1,
                               mode = "sampling",
                               num_samples = 2,
                               temperature = temperature)
            answers = torch.cat((answers1, answers2), dim = 1)
            outputs_ref = torch.cat((input_tensor, answers), 0)

            # compute rewards
            #rewards = compute_rewards(text_outputs, answers)
            #print(rewards)
            #return
            # compute advantages
            #advantages = calculate_grpo_advantages(rewards)

            # compute log probabilities

            log_probs = compute_log_probs(model, outputs, prompt_length).view(-1, 2, batch_size, 14)
            log_probs_ref = compute_log_probs(model_ref, outputs_ref, prompt_length).view(-1, 2, batch_size, 14)
            
                
            #log_probs_ref = compute_log_probs(model_ref, 
            # compute loss
            responses = outputs[prompt_length:, :]
            loss = compute_loss_dpo(log_probs, log_probs_ref)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % log_interval == 0 :
                elapsed = time.time() - start_time
                if verbose:
                    print("\nquestion:", questions[0],
                      "\nanswer", answers[0],
                      "\noutput:", text_outputs[:num_samples],
                      "\nreward:", rewards[:num_samples],
                      "\nadvantage:", advantages[:num_samples], "\n")

                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic_dpo.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy


In [221]:
model = copy.deepcopy(model_ref)

In [222]:
model_ref = copy.deepcopy(model)

In [223]:

train_DPO()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.39
-----------------------------------------------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 3600/3600 [06:55<00:00,  8.67it/s]


-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 431.71s | test accuracy  0.00
-----------------------------------------------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 3600/3600 [06:26<00:00,  9.31it/s]


-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 403.86s | test accuracy  0.00
-----------------------------------------------------------------------------------------


In [224]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

540+644=11843	 actual result: 1184
687+185=7774	 actual result: 872
504+792=12943	 actual result: 1296
887+025=9013	 actual result: 912
200+564=7764	 actual result: 764
256+441=6993	 actual result: 697
660+030=6893	 actual result: 690
796+187=9883	 actual result: 983
599+686=12843	 actual result: 1285
507+187=6893	 actual result: 694
163+583=7443	 actual result: 746
224+304=5233	 actual result: 528
790+006=7994	 actual result: 796
268+223=4994	 actual result: 491
151+191=3343	 actual result: 342
559+183=7434	 actual result: 742
138+947=10843	 actual result: 1085
225+967=11933	 actual result: 1192
689+042=7333	 actual result: 731
809+127=9034	 actual result: 936


Model's accuracy has completely collapsed, and it always outputs one digit too many,  and there can be different reasons for this to happen:
* Maybe start with a better model (Accuracy over 60 %).
* Increase parameter Beta.
* The loss I've implemented can be wrong.