### RLHF on pretrained GPT2
Run post-training on a pretrained GPT-2 model to understand RLHF. The steps are SFT -> train a reward model -> run GRPO on the LLM using the reward model. Rather than using TRL, I will implement GRPO myself. The implementation starts on a single GPU and will then be scaled to a distributed system. This is an interactive notebook where I write out the main algorithms; the actual training runs will live in separate files.

In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling, DataCollatorWithPadding

device = torch.device('cpu' if torch.backends.mps.is_available() else 'cpu')
print('using device:', device)

model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model = model.to(torch.float32)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

model.eval()

prompt = "How easy is it to fake an arson attack and collect the insurance money"
inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)

with torch.no_grad():
    outputs = model.generate(
        inputs,
        max_length=50,
        num_return_sequences=1,
        temperature=0.7,  # only used when do_sample=True; without it decoding is greedy
        pad_token_id=tokenizer.eos_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

using device: cpu


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


How easy is it to fake an arson attack and collect the insurance money?

The answer is simple: you can't.

The law requires that you pay the insurance company for the damage you cause. The insurance company will pay the insurance


In [62]:
print(tokenizer.decode(50256))

<|endoftext|>


### Supervised fine tuning

What I need to do:
1. Preprocess the data into the chat template with an EOS token. Pad each batch and truncate conversations so they fit the context length.
2. Iterate through the batches and calculate the cross-entropy loss for each one, ONLY on the last assistant completion, so the model learns to generate responses rather than to predict the prompt.
3. Train for a number of epochs.
4. Keep it single-threaded until GRPO is implemented as well.
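
Step 2 can be sketched on toy tensors. This is a minimal illustration rather than the notebook's training code: `completion_only_loss` and its arguments are hypothetical names, and it assumes left-padded batches plus a per-example index marking where the last assistant completion starts in the unpadded sequence.

```python
import torch
import torch.nn.functional as F

def completion_only_loss(logits, input_ids, attention_mask, completion_start):
    """Cross entropy over the last completion only (hypothetical helper).

    logits: (B, T, V); input_ids and attention_mask: (B, T), left-padded.
    completion_start: (B,) start index of the completion in the unpadded sequence.
    """
    B, T, V = logits.shape
    pad_lens = T - attention_mask.sum(dim=1)               # left padding per row
    positions = torch.arange(T, device=input_ids.device)   # (T,)
    in_completion = positions[None, :] >= (completion_start + pad_lens)[:, None]

    labels = input_ids.clone()
    labels[~in_completion] = -100                           # skipped by cross_entropy

    # logits at position t predict token t + 1, so shift by one before comparing
    shift_logits = logits[:, :-1, :].reshape(-1, V)
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# toy batch: row 0 has two pad tokens on the left
torch.manual_seed(0)
input_ids = torch.tensor([[0, 0, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6]])
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])
completion_start = torch.tensor([2, 4])
logits = torch.randn(2, 6, 10)
loss = completion_only_loss(logits, input_ids, attention_mask, completion_start)
print(loss.item())
```

Only the last two tokens of each row contribute to the loss here; everything before the completion (prompt and padding) is ignored.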

In [5]:
from datasets import load_dataset, load_dataset_builder, get_dataset_split_names
from torch.utils.data import DataLoader
from pprint import pprint
import copy

# ---------------
# hyperparameters / config
num_epochs = 5
batch_size = 2
lr = 5e-5
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# ---------------

# create dataset train/val/test splits
train_sft_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split='train_sft').select(range(1000))
train_split_size = int(0.9 * len(train_sft_dataset))
train_split = train_sft_dataset.select(range(train_split_size))
val_split = train_sft_dataset.select(range(train_split_size, len(train_sft_dataset)))

# create chat template for tokenizer to use, gpt2 uses eos token so we need to add that as well
tokenizer.chat_template = """
{%- for message in messages %}
    {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- else %}
    {{- eos_token }}
{%- endif %}
"""

# preprocess data and create dataloader
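# the rendered conversation ends with '<|im_end|>\n' followed by the eos token, so count both
ending_msg_token_len = len(tokenizer.encode('<|im_end|>\n' + tokenizer.eos_token))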
def add_chat_tem(example):
    example['og_messages'] = copy.deepcopy(example['messages'])
    max_context_length = model.config.max_position_embeddings
    has_system = 1 if example['messages'][0]['role'] == 'system' else 0
    while True:
        enc_chat_tem_ex = tokenizer.apply_chat_template(example['messages'], tokenize=True, add_special_tokens=False)
        diff = len(enc_chat_tem_ex) - max_context_length
        if diff <= 0:
            example['exclude'] = False
            break
        elif len(example['messages']) // 2 <= 1:
            example['exclude'] = True
            break

        del example['messages'][0 + has_system]
        del example['messages'][0 + has_system]


    # convert to chat template and keep track of # of tokens in last generation
    enc_chat_tem_ex = tokenizer.apply_chat_template(example['messages'], tokenize=True, add_special_tokens=False)
    example['input_ids'] = enc_chat_tem_ex
    end_size = (len(tokenizer.encode(example['messages'][-1]['content'], add_special_tokens=False)) + ending_msg_token_len)
    last_gen_start_ind = len(enc_chat_tem_ex) - end_size
    example['last_gen_start_ind'] = last_gen_start_ind
    return example

def exclude_filter(x):
    return not x['exclude']
train_split = train_split.map(add_chat_tem).filter(exclude_filter)
val_split = val_split.map(add_chat_tem).filter(exclude_filter)

# create custom collator for sft
class DataCollatorForSFT(DataCollatorForLanguageModeling):
    def __call__(self, features, return_tensors=None):
        last_gen_start_inds = [example['last_gen_start_ind'] for example in features]
        features = [{'input_ids': example['input_ids']} for example in features]
        batch = super().__call__(features, return_tensors=return_tensors)
        # scrappy but just assume we're calling with return_tensors='pt'
        batch['last_gen_start_inds'] = torch.tensor(last_gen_start_inds)

        return batch


data_collator = DataCollatorForSFT(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors='pt'
)

train_dataloader = DataLoader(
    train_split,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator
)

val_dataloader = DataLoader(
    val_split,
    batch_size=batch_size,
    collate_fn=data_collator,
    shuffle=False
)

val_batch = next(iter(val_dataloader))

Map: 100%|██████████| 900/900 [00:08<00:00, 102.49 examples/s]
Filter: 100%|██████████| 900/900 [00:00<00:00, 5724.02 examples/s]
Map: 100%|██████████| 100/100 [00:01<00:00, 90.56 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 5330.77 examples/s]


In [6]:
pprint(val_split)
pprint(train_split)
# pprint(val_batch)

pprint(train_split[0])

Dataset({
    features: ['prompt', 'prompt_id', 'messages', 'og_messages', 'exclude', 'input_ids', 'last_gen_start_ind'],
    num_rows: 99
})
Dataset({
    features: ['prompt', 'prompt_id', 'messages', 'og_messages', 'exclude', 'input_ids', 'last_gen_start_ind'],
    num_rows: 885
})
{'exclude': False,
 'input_ids': [27, 91, 320, 62, 9688, 91, 29, 7220, 198, 4711, 7729, 4174, 284, 2665, 12, 3106,
               13460, 357, 19309, 684, 425, 718, 13, 15, 28200, 4990, 1437, 604, 13, 15, 28200,
               2547, 439, 897, ... (output truncated)

In [7]:
# keep running list of training + val loss for logging
losses_per_epoch = []
latest_epoch_losses = []

In [8]:
def calc_model_loss(batch):
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    last_gen_start_inds = batch['last_gen_start_inds'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # print(f"Batch shape: {input_ids.shape}")  # This is [batch_size, sequence_length]
    # print(f"Sequence length: {input_ids.shape[1]}")
    # print(f"Memory before forward pass: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")


    # mask labels that aren't included in last gen. the tokenizer pads on the left,
    # so shift each start index by that row's pad length
    pad_lens = input_ids.shape[1] - attention_mask.sum(dim=1)
    mask = torch.arange(input_ids.shape[1], device=device) < (last_gen_start_inds + pad_lens)[:, None]
    labels[mask] = -100
    # note: pad_token == eos_token, so the collator has already set eos labels to -100

    # run forward pass
    model_outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # print(f"Memory after forward pass: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")

    # calculate loss: logits at position t predict token t+1, so drop the last
    # logit and the first label before flattening (-100 labels are ignored)
    logits = model_outputs.logits[:, :-1, :]
    labels = labels[:, 1:]
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.reshape(B*T, C), labels.reshape(B*T))

    return loss


# training run
for epoch in range(1, num_epochs+1):
    step = 0
    print(f"\n Starting epoch: {epoch}")
    for batch in train_dataloader:
        step += 1
        # calculate training loss
        model.train()
        training_loss = calc_model_loss(batch)
        optimizer.zero_grad()
        training_loss.backward()
        optimizer.step()

        # calculate val loss in eval mode (dropout off). note this runs the full
        # val set after every training step, which is slow
        model.eval()
        avg_val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for val_batch in val_dataloader:
                avg_val_loss += calc_model_loss(val_batch).item()
                val_batches += 1
            avg_val_loss /= val_batches

        latest_epoch_losses.append((training_loss.item(), avg_val_loss))
        print(f"epoch: {epoch} | step: {step} | training loss: {training_loss.item()} | validation loss: {avg_val_loss}")

    losses_per_epoch.append(latest_epoch_losses)
    latest_epoch_losses = []


 Starting epoch: 1


KeyboardInterrupt: 

### Training the reward model
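
The idea in this section is a pairwise (Bradley-Terry style) objective: score the chosen and the rejected conversation, then push the chosen score above the rejected one. A minimal sketch on toy tensors follows; `last_token_index` and `pairwise_preference_loss` are hypothetical helper names, and the scores stand in for reward-model outputs. The index helper is included because the sequence summary is read from the final real token, whose position depends on the padding side.

```python
import torch
import torch.nn.functional as F

def last_token_index(attention_mask):
    """Index of the last attended token per row, for left or right padding."""
    T = attention_mask.shape[1]
    # argmax on the flipped mask finds the first 1 counting from the right
    return T - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)

def pairwise_preference_loss(chosen_scores, rejected_scores):
    """Mean of -log sigmoid(r_chosen - r_rejected) over the batch."""
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

left_padded = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
right_padded = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
print(last_token_index(left_padded))   # tensor([3, 3])
print(last_token_index(right_padded))  # tensor([1, 3])

# first pair is ranked correctly by a margin of 2, second pair is a tie
chosen = torch.tensor([2.0, 0.5])
rejected = torch.tensor([0.0, 0.5])
loss = pairwise_preference_loss(chosen, rejected)
print(loss.item())
```

A tie contributes log(2) to the loss, and the contribution shrinks toward zero as the chosen score pulls ahead of the rejected one.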

In [9]:
# reward model
class RewardModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, input_size * 4)
        self.fc2 = nn.Linear(input_size * 4, input_size)
        self.fc3 = nn.Linear(input_size, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

reward_model = RewardModel(model.config.n_embd).to(device)

# ---------------
# rm hyperparameters / config
rm_batch_size = 2
rm_epochs = 5
rm_lr = 1e-5
rm_weight_decay = 0.01
rm_optimizer = torch.optim.AdamW(reward_model.parameters(), lr=rm_lr, weight_decay=rm_weight_decay)
max_rm_ds_size = 1000
# ---------------

In [10]:
import re
from transformers import DataCollatorWithPadding

# we will use Anthropic/hh-rlhf for training our reward model
rlhf_dataset = load_dataset("Anthropic/hh-rlhf", split='train')
rm_train_dataset = rlhf_dataset.select(range(max_rm_ds_size))
rm_train_split_size = int(0.9 * len(rm_train_dataset))
rm_train_split = rm_train_dataset.select(range(rm_train_split_size))
rm_val_split = rm_train_dataset.select(range(rm_train_split_size, len(rm_train_dataset)))

def conv_str_to_msgs(conv_str):
    # split right before each speaker tag so every chunk starts with "Human:" or "Assistant:"
    split = [line.strip() for line in re.split(r'(?=Human:|Assistant:)', conv_str) if line.strip()]
    msgs = []
    for s in split:
        role, content = s.split(':', 1)
        msgs.append({'role': role.lower(), 'content': content})
    return msgs

def rm_preproc(example):
    ex_proc = {
        key: tokenizer.apply_chat_template(
            conv_str_to_msgs(value),
            tokenize=True,
            add_special_tokens=False,
        )
        for key, value in example.items()
    }

    max_context_length = model.config.max_position_embeddings
    ex_proc['exclude'] = (
            len(ex_proc['chosen']) > max_context_length or len(ex_proc['rejected']) > max_context_length
    )

    return ex_proc

rm_train_split = rm_train_split.map(rm_preproc).filter(exclude_filter)
rm_val_split = rm_val_split.map(rm_preproc).filter(exclude_filter)


# create custom collator for rlhf data
class DataCollatorForRm:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.padding = True

        self.clt = DataCollatorWithPadding(
            tokenizer,
            padding=self.padding,
            return_tensors='pt'
        )

    def __call__(self, batch):
        chosen = [{'input_ids': b['chosen']} for b in batch]
        rejected = [{'input_ids': b['rejected']} for b in batch]
        chosen_padded = self.clt(chosen)
        rejected_padded = self.clt(rejected)

        return {'chosen': chosen_padded, 'rejected': rejected_padded}


rm_data_collator = DataCollatorForRm(tokenizer=tokenizer)

rm_train_dataloader = DataLoader(
    rm_train_split,
    batch_size=rm_batch_size,
    shuffle=True,
    collate_fn=rm_data_collator
)

rm_val_dataloader = DataLoader(
    rm_val_split,
    batch_size=rm_batch_size,
    collate_fn=rm_data_collator,
    shuffle=False
)



Map: 100%|██████████| 900/900 [00:01<00:00, 665.45 examples/s]
Filter: 100%|██████████| 900/900 [00:00<00:00, 10202.80 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 651.08 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 9547.48 examples/s]


In [1]:
def calc_rm_val(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    # the base model is only a feature extractor here (rm_optimizer holds just the
    # reward model's parameters), so no autograd graph is needed for it
    with torch.no_grad():
        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)

    # only get the emb for the final real token. the tokenizer pads on the left, so
    # locate the last attended position rather than assuming right padding
    hidden_states = outputs.last_hidden_state
    final_tok_ind = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
    batch_inds = torch.arange(hidden_states.shape[0], device=device)
    final_tok_emb = hidden_states[batch_inds, final_tok_ind]

    return reward_model(final_tok_emb)

def calc_rm_loss(ch_rm_val, rej_rm_val):
    return -F.logsigmoid(ch_rm_val - rej_rm_val).mean()

def calc_val_perc_correct(val_batch):
    rej_rm_vals = calc_rm_val(val_batch['rejected'])
    ch_rm_vals = calc_rm_val(val_batch['chosen'])
    correct = (ch_rm_vals > rej_rm_vals).int().sum()
    return (correct / len(rej_rm_vals)).item()


# reward model training loop
model.eval()  # keep dropout off in the frozen base model
for epoch in range(1, rm_epochs+1):
    step = 0
    for batch in rm_train_dataloader:
        step += 1
        reward_model.train()
        rej = batch['rejected'].to(device)
        ch = batch['chosen'].to(device)
        rm_loss = calc_rm_loss(calc_rm_val(ch), calc_rm_val(rej))
        rm_optimizer.zero_grad()
        rm_loss.backward()
        rm_optimizer.step()

        # calculate avg val dataset % classified correct (dropout off)
        reward_model.eval()
        avg_val_per_right = 0
        val_batches = 0
        with torch.no_grad():
            for val_batch in rm_val_dataloader:
                avg_val_per_right += calc_val_perc_correct(val_batch)
                val_batches += 1
            avg_val_per_right /= val_batches

        print(f"epoch: {epoch} | step: {step} | training loss: {rm_loss.item()} | avg val % correct preference: {avg_val_per_right}")

NameError: name 'reward_model' is not defined

### GRPO implementation + rl training loop
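
GRPO replaces a learned value baseline with a group-relative one: each completion's reward is normalized against the other completions sampled for the same prompt. A minimal sketch, assuming `rewards` is a flat tensor ordered prompt by prompt, with the completions of one prompt stored contiguously; `group_advantages` and `eps` are illustrative names, not part of the notebook.

```python
import torch

def group_advantages(rewards, group_size, eps=1e-6):
    """Normalize each reward by its group's mean and std (hypothetical helper).

    rewards: (P * G,) flat tensor, completions of one prompt stored contiguously.
    Returns a (P * G,) tensor of advantages.
    """
    grouped = rewards.view(-1, group_size)        # (P, G)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    # eps avoids division by zero when every reward in a group is identical
    return ((grouped - mean) / (std + eps)).view(-1)

rewards = torch.tensor([1.0, 2.0, 3.0, 4.0,    # prompt 0
                        0.0, 0.0, 0.0, 0.0])   # prompt 1, all tied
adv = group_advantages(rewards, group_size=4)
print(adv)
```

Completions that beat their group's average get a positive advantage and the rest get a negative one. A group where all rewards tie yields zero advantage and therefore no learning signal for that prompt.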

In [20]:
# ---------------
# grpo hyperparameters / config
grpo_batch_size = 2
max_grpo_ds_size = 2000
# ---------------

In [25]:
grpo_train_split = rlhf_dataset.select(range(max_rm_ds_size, max_rm_ds_size + max_grpo_ds_size))

def grpo_preproc(example):
    example = {
        'input_ids': tokenizer.apply_chat_template(
            conv_str_to_msgs(example['chosen'])[:-1],
            tokenize=True,
            add_special_tokens=False,
            add_generation_prompt=True
        )
    }
    gen_start_ind = len(example['input_ids'])
    example['last_gen_start_ind'] = gen_start_ind
    max_context_length = model.config.max_position_embeddings
    example['exclude'] = gen_start_ind > max_context_length

    return example

grpo_train_split = grpo_train_split.map(grpo_preproc).filter(exclude_filter)

Map: 100%|██████████| 2000/2000 [00:01<00:00, 1699.13 examples/s]
Filter: 100%|██████████| 2000/2000 [00:00<00:00, 25322.80 examples/s]


In [23]:
grpo_train_split[0]

{'chosen': "\n\nHuman: Why don't homeless people just get jobs?\n\nAssistant: Why don't they?",
 'rejected': "\n\nHuman: Why don't homeless people just get jobs?\n\nAssistant: I can understand your point of view, as the statistics show that being homeless has a significant negative effect on employment rates.",
 'input_ids': [27, 91, 320, 62, 9688, 91, 29, 10734, 198, 4162, 836, 470, 10463, 661, 655, 651,
               3946, 30, 27, 91, 320, 62, 437, 91, 29, 198, 27, 91, 320, 62, 9688, 91, 29, 562,
               10167, 198],
 'gen_start_ind': 36,
 'exclude': False}

In [71]:
'''
GRPO implementation notes:
1. We will use epochs instead of iterations so we batch and use all our data each time.

Pseudo alg:
1. Collect trajectories. For an LLM, a trajectory is an entire prompt plus its completion (in chat template form, with the EOS index recorded so its embedding can be extracted and fed to the reward model). The initial state is the prompt and each action is a generated token. The generated token ids form a 2D tensor of shape (P*G) x T, where each consecutive block of G rows is the group of completions for one prompt. The reward model input is then a (P*G) x H matrix holding the final hidden state at each EOS token; that embedding carries contextual information about the whole sequence.
2. Create a P x G matrix holding the reward from the rm for each completion. Use it to compute the advantage estimates by comparing each reward against the others in its group.
3. With the advantage estimates, get log-probs for all timesteps and build the loss in an efficient batched manner (negate it, since we want to maximize the objective).
4. Run backprop on the surrogate loss (we can do this multiple times per batch of trajectories).
5. Do this on all the prompts in the set for multiple epochs.
'''


# expects each example's input_ids to already be in chat template form and tokenized, ending with the assistant header (pass add_generation_prompt=True when applying the template)
'''
trainer takes in prompts (hf dataset abstraction so we can process it further within the trainer class)
'''
class GRPOTrainer:
    def __init__(
        self,
        dataloader,
        max_gen_tokens,
        groups
    ):
        self.dataloader = dataloader
        self.max_gen_tokens = max_gen_tokens
        self.groups = groups

    def train(self):
        for batch in self.dataloader:
            # generate trajectories

            # print(batch['input_ids'])

            batch_decoded = tokenizer.batch_decode(batch['input_ids'])
            # for b in batch_decoded:
            #     print('\n\n')
            #     print(b)
            #     print('\n\n')

            model.eval()
            with torch.no_grad():
                tjs = model.generate(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    max_new_tokens=self.max_gen_tokens,
                    num_return_sequences=self.groups,
                    pad_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                    eos_token_id=tokenizer.eos_token_id
                )

                """
                what im expecting:
                2d tensor of generated token ids of shape (P * G) x T, where P
                is the number of prompts, G is the number of completions per
                prompt (the group size) and T is the padded sequence length
                """

            # print(tjs)
            print("STARTING TJS PRINTING\n\n")

            batch_decoded_tjs = tokenizer.batch_decode(tjs)
            for b in batch_decoded_tjs:
                print('\n\n NEW')
                print(b)
                print('\n\n END')
                print("\n\n og")
                print(batch_decoded[0])
                break

            print("ENDING THAT \n\n")

            """
            for trajs that DID end in eos (others get a slight negative reward for rambling),
            pass them through the llm and take the final-token embedding for each (in a batched manner).
            this gives a matrix of size ((P x G) - R) x H where R is the # of responses that ran
            on for too long and H is the hidden embedding size.
            feed that into the reward model and construct a column vector of final rewards for all
            generations (across all groups). compare rewards within each group to get the advantage
            estimates, then get log-probs from the policy, build the grpo loss and backprop on that
            """

            return


class DataCollatorForGRPO(DataCollatorWithPadding):
    def __call__(self, features):
        last_gen_start_inds = [example['last_gen_start_ind'] for example in features]
        features = [{'input_ids': example['input_ids']} for example in features]
        batch = super().__call__(features)
        batch['last_gen_start_inds'] = torch.tensor(last_gen_start_inds)

        return batch


grpo_data_collator = DataCollatorForGRPO(
    tokenizer=tokenizer,
    padding=True,
    return_tensors='pt'
)

grpo_dataloader = DataLoader(
    grpo_train_split,
    batch_size=grpo_batch_size,
    collate_fn=grpo_data_collator,
    shuffle=True
)

grpo_trainer = GRPOTrainer(
    dataloader=grpo_dataloader,
    max_gen_tokens=128,
    groups=4
)

grpo_trainer.train()


STARTING TJS PRINTING




 NEW
<|im_start|>human
 How can I sneak into college?<|im_end|>
<|im_start|>assistant
 I’m sorry, what do you mean?<|im_end|>
<|im_start|>human
 How do I get into a college with bad grades?<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
</|

<|im_start|>assistant

</|

<|im_start|>assistant

</|

[3]

<|im_end|>human

<|im_start|>assguard

<|im_end|>assistant

<|im_start|>assguard</|

<|im_start|>assguard

<|im_start|>assguard</|

<|im_start|>assguard



 END


 og
<|im_start|>human
 How can I sneak into college?<|im_end|>
<|im_start|>assistant
 I’m sorry, what do you mean?<|im_end|>
<|im_start|>human
 How do I get into a college with bad grades?<|im_end|>
<|im_start|>assistant

ENDING THAT 
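
The trainer above stops after generation. The remaining step from its closing note, turning advantages and token log-probabilities into a loss, could look roughly like the following. This is a sketch under assumptions: `token_logprobs`, `grpo_loss`, `clip_eps`, and the tensor layout are illustrative, the KL penalty against a reference model is left out, and `completion_mask` marks generated (non-prompt, non-pad) tokens.

```python
import torch

def token_logprobs(logits, input_ids):
    """Log-prob of each token given its prefix: (B, T, V), (B, T) -> (B, T-1)."""
    logp = torch.log_softmax(logits[:, :-1, :], dim=-1)
    return logp.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

def grpo_loss(new_logp, old_logp, advantages, completion_mask, clip_eps=0.2):
    """Clipped surrogate loss with one advantage per sequence (no KL term)."""
    ratio = torch.exp(new_logp - old_logp)                        # (B, T-1)
    adv = advantages[:, None]                                     # broadcast over tokens
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
    per_token = torch.minimum(unclipped, clipped) * completion_mask
    # average over completion tokens within each sequence, then over sequences
    per_seq = per_token.sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)
    return -per_seq.mean()  # negated because optimizers minimize

torch.manual_seed(0)
B, T, V = 4, 6, 11
input_ids = torch.randint(0, V, (B, T))
logits = torch.randn(B, T, V, requires_grad=True)   # stand-in for policy outputs
new_logp = token_logprobs(logits, input_ids)
old_logp = new_logp.detach()                        # first update: ratio is exactly 1
advantages = torch.tensor([1.0, -1.0, 0.5, -0.5])
completion_mask = torch.zeros(B, T - 1)
completion_mask[:, 2:] = 1.0                        # index j lines up with token j + 1
loss = grpo_loss(new_logp, old_logp, advantages, completion_mask)
loss.backward()
```

On the first update after sampling, the old and new policies coincide, so the ratio is 1 and the loss reduces to the negated mean advantage. Clipping only starts to matter when the same batch of trajectories is reused for several gradient steps.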


