### RLHF on pretrained GPT2
Run post training on a pretrained GPT-2 model to understand RLHF. Steps will be SFT -> train reward model -> run grpo on pretrained llm on reward model. Rather than using TRL, I will be implementing grpo myself. Implementation will start with single gpu and then scaled to distributed system.

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling, DataCollatorWithPadding

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print('using device:', device)

model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

model.eval()

prompt = "The usual weather in California is"
inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)

with torch.no_grad():
    outputs = model.generate(
        inputs,
        max_length=10,
        num_return_sequences=1,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

using device: mps


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


The usual weather in California is a bit of a


### Supervised fine tuning

What I need to do:
1. Preprocess data into chat template with EOS token. Ensure data is padded and make sure batches are truncated to fit context length.
2. Iterate through every batch and for each one calculate the loss (ONLY on the last assistant completion so the model learns prompt prediction). We use cross entropy btw.
3. Run a number of epochs on it.
4. Keep single threaded till we implement grpo as well.

In [57]:
from datasets import load_dataset, load_dataset_builder, get_dataset_split_names
from torch.utils.data import DataLoader
from pprint import pprint
import copy

# ---------------
# hyperparameters / config
num_epochs = 5
batch_size = 2
lr = 5e-5
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# ---------------

# create dataset train/val/test splits
train_sft_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split='train_sft').select(range(1000))
train_split_size = int(0.9 * len(train_sft_dataset))
train_split = train_sft_dataset.select(range(train_split_size))
val_split = train_sft_dataset.select(range(train_split_size, len(train_sft_dataset)))

# create chat template for tokenizer to use, gpt2 uses eos token so we need to add that as well
tokenizer.chat_template = """
{%- for message in messages %}
    {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- else %}
    {{- eos_token }}
{%- endif %}
"""

# preprocess data and create dataloader
ending_msg_token_len = len(tokenizer.encode('<|im_end|>\n'))
def add_chat_tem(example):
    example['og_messages'] = copy.deepcopy(example['messages'])
    max_context_length = model.config.max_position_embeddings
    has_system = 1 if example['messages'][0]['role'] == 'system' else 0
    while True:
        enc_chat_tem_ex = tokenizer.apply_chat_template(example['messages'], tokenize=True, add_special_tokens=False)
        diff = len(enc_chat_tem_ex) - max_context_length
        if diff <= 0:
            example['exclude'] = False
            break
        elif len(example['messages']) // 2 <= 1:
            example['exclude'] = True
            break

        del example['messages'][0 + has_system]
        del example['messages'][0 + has_system]


    # convert to chat template and keep track of # of tokens in last generation
    enc_chat_tem_ex = tokenizer.apply_chat_template(example['messages'], tokenize=True, add_special_tokens=False)
    example['input_ids'] = enc_chat_tem_ex
    end_size = (len(tokenizer.encode(example['messages'][-1]['content'], add_special_tokens=False)) + ending_msg_token_len)
    last_gen_start_ind = len(enc_chat_tem_ex) - end_size
    example['last_gen_start_ind'] = last_gen_start_ind
    return example

exclude_filter = lambda x: x['exclude'] == False
train_split = train_split.map(add_chat_tem).filter(exclude_filter)
val_split = val_split.map(add_chat_tem).filter(exclude_filter)

# create custom collator for sft
class DataCollatorForSFT(DataCollatorForLanguageModeling):
    def __call__(self, features, return_tensors=None):
        last_gen_start_inds = [example['last_gen_start_ind'] for example in features]
        features = [{'input_ids': example['input_ids']} for example in features]
        batch = super().__call__(features, return_tensors=return_tensors)
        # scrappy but just assume we're calling with return_tensors='pt'
        batch['last_gen_start_inds'] = torch.tensor(last_gen_start_inds)

        return batch


data_collator = DataCollatorForSFT(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors='pt'
)

train_dataloader = DataLoader(
    train_split,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator
)

val_dataloader = DataLoader(
    val_split,
    batch_size=batch_size,
    collate_fn=data_collator,
    shuffle=False
)

val_batch = next(iter(val_dataloader))

In [40]:
pprint(val_split)
pprint(train_split)
# pprint(val_batch)

pprint(train_split[0])

Dataset({
    features: ['prompt', 'prompt_id', 'messages', 'og_messages', 'exclude', 'input_ids', 'last_gen_start_ind'],
    num_rows: 99
})
Dataset({
    features: ['prompt', 'prompt_id', 'messages', 'og_messages', 'exclude', 'input_ids', 'last_gen_start_ind'],
    num_rows: 885
})
{'exclude': False,
 'input_ids': [27,
               91,
               320,
               62,
               9688,
               91,
               29,
               7220,
               198,
               4711,
               7729,
               4174,
               284,
               2665,
               12,
               3106,
               13460,
               357,
               19309,
               684,
               425,
               718,
               13,
               15,
               28200,
               4990,
               1437,
               604,
               13,
               15,
               28200,
               2547,
               439,
               897,
        

In [22]:
# keep running list of training + val loss for logging
losses_per_epoch = []
latest_epoch_losses = []

In [26]:
def calc_model_loss(batch):
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    last_gen_start_inds = batch['last_gen_start_inds'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    print(f"Batch shape: {input_ids.shape}")  # This is [batch_size, sequence_length]
    print(f"Sequence length: {input_ids.shape[1]}")
    print(f"Memory before forward pass: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")


    # mask labels that aren't included in last gen
    mask = torch.arange(input_ids.shape[1], device=device) < last_gen_start_inds[:, None]
    labels[mask] = -100

    # run forward pass and calc loss
    model_outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Memory after forward pass: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")

    # calculate loss
    B, T, C = model_outputs.logits.shape
    logits = model_outputs.logits.view(B*T, C)
    labels = labels.view(B*T)
    loss = F.cross_entropy(logits, labels)

    return loss


# training run
model.train()
for epoch in range(1, num_epochs+1):
    step = 0
    print(f"\n Starting epoch: {epoch}")
    for batch in train_dataloader:
        step += 1
        # calculate training loss
        training_loss = calc_model_loss(batch)
        optimizer.zero_grad()
        training_loss.backward()
        optimizer.step()

        # calculate val loss
        avg_val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for val_batch in val_dataloader:
                avg_val_loss += calc_model_loss(val_batch).item()
                val_batches += 1
            avg_val_loss /= val_batches

        latest_epoch_losses.append((training_loss.item(), avg_val_loss))
        print(f"epoch: {epoch} | step: {step} | training loss: {training_loss} | validation loss: {avg_val_loss}")

    losses_per_epoch.append(latest_epoch_losses)
    latest_epoch_losses = []


 Starting epoch: 1
Batch shape: torch.Size([2, 474])
Sequence length: 474
Memory before forward pass: 29.08 GB
Memory after forward pass: 29.18 GB
Batch shape: torch.Size([2, 853])
Sequence length: 853
Memory before forward pass: 29.51 GB
Memory after forward pass: 29.51 GB
epoch: 1 | step: 1 | training loss: 4.41838264465332 | validation loss: 3.669476270675659
Batch shape: torch.Size([2, 727])
Sequence length: 727
Memory before forward pass: 29.51 GB
Memory after forward pass: 29.61 GB
Batch shape: torch.Size([2, 853])
Sequence length: 853
Memory before forward pass: 30.03 GB
Memory after forward pass: 29.97 GB
epoch: 1 | step: 2 | training loss: 4.3390793800354 | validation loss: 2.811246395111084
Batch shape: torch.Size([2, 584])
Sequence length: 584
Memory before forward pass: 29.97 GB
Memory after forward pass: 29.92 GB
Batch shape: torch.Size([2, 853])
Sequence length: 853
Memory before forward pass: 30.07 GB
Memory after forward pass: 30.07 GB
epoch: 1 | step: 3 | training los

KeyboardInterrupt: 

### Training the reward model

In [115]:
# reward model
class RewardModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, input_size * 4)
        self.fc2 = nn.Linear(input_size * 4, input_size)
        self.fc3 = nn.Linear(input_size, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

reward_model = RewardModel(model.config.n_embd).to(device)

# ---------------
# rm hyperparameters / config
rm_batch_size = 4
rm_epochs = 5
rm_lr = 1e-5
rm_weight_decay = 0.01
rm_optimizer = torch.optim.AdamW(reward_model.parameters(), lr=rm_lr, weight_decay=rm_weight_decay)
# ---------------

In [None]:
import re
from transformers import DataCollatorWithPadding

# we will use Anthropic/hh-rlhf for training our reward model
rm_train_dataset = load_dataset("Anthropic/hh-rlhf", split='train').select(range(1000))
rm_train_split_size = int(0.9 * len(rm_train_dataset))
rm_train_split = rm_train_dataset.select(range(rm_train_split_size))
rm_val_split = rm_train_dataset.select(range(rm_train_split_size, len(rm_train_dataset)))

def conv_str_to_msgs(str):
    split = [line.strip() for line in re.split(r'(?=Human:|Assistant:)', str) if line.strip()]
    msgs = []
    for s in split:
        role, content = s.split(':', 1)
        msgs.append({'role': role.lower(), 'content': content})
    return msgs

def rm_preproc(example):
    ex_proc = {
        key: tokenizer.apply_chat_template(
            conv_str_to_msgs(value),
            tokenize=True,
            add_special_tokens=False,
        )
        for key, value in example.items()
    }

    max_context_length = model.config.max_position_embeddings
    ex_proc['exclude'] = (
            len(ex_proc['chosen']) > max_context_length or len(ex_proc['rejected']) > max_context_length
    )

    return ex_proc

rm_train_split = rm_train_split.map(rm_preproc).filter(exclude_filter)
rm_val_split = rm_val_split.map(rm_preproc).filter(exclude_filter)


# create custom collator for rlhf data
class DataCollatorForRm:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.padding = True

        self.clt = DataCollatorWithPadding(
            tokenizer,
            padding=self.padding,
            return_tensors='pt'
        )

    def __call__(self, batch):
        chosen = [{'input_ids': b['chosen']} for b in batch]
        rejected = [{'input_ids': b['rejected']} for b in batch]
        chosen_padded = self.clt(chosen)
        rejected_padded = self.clt(rejected)

        return {'chosen': chosen_padded, 'rejected': rejected_padded}


rm_data_collator = DataCollatorForRm(tokenizer=tokenizer)

rm_train_dataloader = DataLoader(
    rm_train_split,
    batch_size=rm_batch_size,
    shuffle=True,
    collate_fn=rm_data_collator
)

rm_val_dataloader = DataLoader(
    rm_val_split,
    batch_size=rm_batch_size,
    collate_fn=rm_data_collator,
    shuffle=False
)



In [114]:
def calc_rm_val(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)

    # only get the emb for final token
    hidden_states = outputs.last_hidden_state
    final_tok_ind = torch.sum(attention_mask, dim=1) - 1
    batch_inds = torch.arange(hidden_states.shape[0], device=device)
    final_tok_emb = hidden_states[batch_inds, final_tok_ind]

    return reward_model(final_tok_emb)

def calc_rm_loss(ch_rm_val, rej_rm_val):
    return -F.logsigmoid(ch_rm_val - rej_rm_val).mean()

def calc_val_perc_correct(val_batch):
    rej_rm_vals = calc_rm_val(val_batch['rejected'])
    ch_rm_vals = calc_rm_val(val_batch['chosen'])
    correct = (ch_rm_vals > rej_rm_vals).int().sum()
    return (correct / len(rej_rm_vals)).item()


# reward model training loop
reward_model.train()
for epoch in range(1, num_epochs+1):
    step = 0
    for batch in rm_train_dataloader:
        step += 1
        rej = batch['rejected'].to(device)
        ch = batch['chosen'].to(device)
        rm_loss = calc_rm_loss(calc_rm_val(ch), calc_rm_val(rej))

        # calculate avg val dataset % classified correct
        avg_val_per_right = 0
        val_batches = 0
        with torch.no_grad():
            for val_batch in rm_val_dataloader:
                avg_val_per_right += calc_val_perc_correct(val_batch)
                val_batches += 1
            avg_val_per_right /= val_batches

        print(f"epoch: {epoch} | step: {step} | training loss: {rm_loss} | avg val % correct preference: {avg_val_per_right}")
        rm_optimizer.zero_grad()
        rm_loss.backward()
        rm_optimizer.step()

epoch: 1 | step: 1 | training loss: 0.7633222937583923 | avg val % correct preference: 0.58
epoch: 1 | step: 2 | training loss: 0.5859189033508301 | avg val % correct preference: 0.53
epoch: 1 | step: 3 | training loss: 0.7964285016059875 | avg val % correct preference: 0.49
epoch: 1 | step: 4 | training loss: 0.7925732731819153 | avg val % correct preference: 0.56
epoch: 1 | step: 5 | training loss: 0.6669638156890869 | avg val % correct preference: 0.57
epoch: 1 | step: 6 | training loss: 0.48933303356170654 | avg val % correct preference: 0.54
epoch: 1 | step: 7 | training loss: 0.8430061936378479 | avg val % correct preference: 0.65
epoch: 1 | step: 8 | training loss: 0.7581450939178467 | avg val % correct preference: 0.51
epoch: 1 | step: 9 | training loss: 0.6295057535171509 | avg val % correct preference: 0.59
epoch: 1 | step: 10 | training loss: 0.5836320519447327 | avg val % correct preference: 0.51
epoch: 1 | step: 11 | training loss: 0.7567492127418518 | avg val % correct pr

RuntimeError: MPS backend out of memory (MPS allocated: 15.86 GB, other allocations: 20.38 GB, max allowed: 36.27 GB). Tried to allocate 147.24 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

### GRPO implementation + rl training loop