In [1]:
%%capture
!pip install --upgrade torch transformers datasets accelerate
# !git clone https://github.com/andersonbcdefg/rewardmodeling.git
!cd rewardmodeling && git pull

In [2]:
from functools import partial
from itertools import chain
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DefaultDataCollator
from accelerate import Accelerator
from rewardmodeling.data import *

In [3]:
%%capture
# use a deberta-v3 that has already been finetuned on massively-multitask dataset
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = AutoModelForSequenceClassification.from_pretrained("sileod/deberta-v3-base-tasksource-nli", num_labels=1, ignore_mismatched_sizes=True)
model.gradient_checkpointing_enable()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at sileod/deberta-v3-base-tasksource-nli and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# freeze embeddings and first 8 layers. save memory!
def freeze(module):
    """Turn off gradient tracking for every parameter of ``module``.

    Args:
        module: an object whose ``parameters()`` yields the tensors to freeze
            (e.g. a ``torch.nn.Module``).

    Returns:
        None. The parameters are modified in place.
    """
    for weight in module.parameters():
        weight.requires_grad_(False)
        
# Freeze the embedding block and the first 8 encoder layers (indices 0-7);
# everything that is not listed here keeps its current requires_grad setting.
modules_to_freeze = [model.deberta.embeddings, *model.deberta.encoder.layer[:8]]
for frozen_module in modules_to_freeze:
    freeze(frozen_module)

In [5]:
# Load the training data. To save time, pre-tokenized versions that the author
# uploaded to the Hugging Face Hub are used; both are read from their "train" split.
# For reference, the disabled recipe below is the author's note on how the
# tokenized datasets were produced (it is not executed here):
# datasets = get_combined_datasets()
# data_long, data_short = datasets['long'], datasets['short']
# data_long_tokenized = data_long.map(partial(tokenize_function, tokenizer=tokenizer, max_len=2048),
#                                     batched=True, remove_columns=data_long.column_names)
# data_short_tokenized = data_short.map(partial(tokenize_function, tokenizer=tokenizer, max_len=1024),
#                                       batched=True, remove_columns=data_short.column_names)
# NOTE(review): `load_dataset` is not imported explicitly in this notebook; it is
# presumably re-exported by `from rewardmodeling.data import *` — confirm.
data_long_tokenized = load_dataset("andersonbcdefg/reward-modeling-long-tokenized", split="train")
data_short_tokenized = load_dataset("andersonbcdefg/reward-modeling-short-tokenized", split="train")

Found cached dataset parquet (/root/.cache/huggingface/datasets/andersonbcdefg___parquet/andersonbcdefg--reward-modeling-long-tokenized-5d7ae91cc6f67e21/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Found cached dataset parquet (/root/.cache/huggingface/datasets/andersonbcdefg___parquet/andersonbcdefg--reward-modeling-short-tokenized-6ef67474794d1513/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [6]:
def _make_train_loader(dataset, batch_size):
    """Wrap a pre-tokenized dataset in a DataLoader using the shared training settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        collate_fn=DefaultDataCollator(),
        num_workers=1,
    )


short_dataloader = _make_train_loader(data_short_tokenized, batch_size=64)
long_dataloader = _make_train_loader(data_long_tokenized, batch_size=16)
# Curriculum: all of the shorter data first, then finish with the longer data
# so that the RM can handle longer sequences. chain() walks each loader exactly
# once, and it is a one-shot iterator: re-run this cell before iterating over
# train_dataloader a second time.
train_dataloader = chain(short_dataloader, long_dataloader)

In [7]:
# Model, optimizer, scheduler, accelerator
model.to(torch.device("cuda"))
max_lr = 3.0e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), fused=True)

# Constant learning rate with a short linear warmup, expressed through OneCycleLR.
# OneCycleLR derives its endpoints as
#     initial_lr = max_lr / div_factor            (start of the warmup)
#     min_lr     = initial_lr / final_div_factor  (end of the post-warmup phase)
# so the LR only stays at max_lr after the warmup when
# final_div_factor == 1 / div_factor. With final_div_factor == 1 the LR would
# instead decay linearly from max_lr down to max_lr / div_factor.
div_factor = 10
scheduler_kwargs = {
    "max_lr": max_lr,
    # scheduler.step() is called once per micro-batch in the training loop.
    # The extra 10 steps are presumably headroom so the scheduler is never
    # stepped past total_steps (OneCycleLR raises when that happens).
    "total_steps": len(short_dataloader) + len(long_dataloader) + 10,
    "pct_start": 0.005,
    "div_factor": div_factor,
    "final_div_factor": 1 / div_factor,
    "anneal_strategy": "linear",
    "three_phase": False,
    # OneCycleLR cycles momentum by default, which for Adam-style optimizers
    # overwrites beta1 on every step; disable it so betas=(0.9, 0.95) above
    # is what the optimizer actually uses.
    "cycle_momentum": False,
}
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **scheduler_kwargs)
accelerator = Accelerator(mixed_precision="bf16")
# NOTE(review): train_dataloader is an itertools.chain, not a DataLoader, so
# prepare() presumably hands it back unchanged — the training loop moves each
# batch to accelerator.device itself.
model, optimizer, train_dataloader, scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, scheduler
)

In [8]:
import wandb

# SECURITY: never hard-code the W&B API key in the notebook. Called without a
# key, wandb.login() takes the credential from the WANDB_API_KEY environment
# variable or ~/.netrc, and otherwise prompts for it interactively.
# The key that was previously committed in this cell must be treated as leaked
# and revoked in the W&B account settings.
wandb.login()
wandb.init(
    project="train_reward_model",
    config={
        "max_lr": max_lr,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mandersonbcdefg[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
effective_batch_size = 64  # samples accumulated before each optimizer step
sample_count = 0           # samples accumulated since the last optimizer step
optimizer_steps = 0        # number of optimizer updates performed so far
save_every = 500           # micro-batches between checkpoints


def _run_micro_batch(batch):
    """Run forward and backward for one micro-batch of preference pairs.

    The preferred and dispreferred sequences are concatenated along the batch
    dimension so both are scored in a single forward pass, then split again.
    The loss is the pairwise preference loss -log(sigmoid(r_pref - r_dispref)).

    Args:
        batch: mapping with the keys 'preferred_input_ids',
            'preferred_attention_masks', 'dispreferred_input_ids' and
            'dispreferred_attention_masks'.

    Returns:
        float: mean loss per pair in this micro-batch (used for logging only).
    """
    pref_ids = batch['preferred_input_ids'].to(accelerator.device)  # bsz, seq_len
    pref_mask = batch['preferred_attention_masks'].to(accelerator.device)
    dispref_ids = batch['dispreferred_input_ids'].to(accelerator.device)
    dispref_mask = batch['dispreferred_attention_masks'].to(accelerator.device)

    pref_rewards, dispref_rewards = model(
        torch.cat([pref_ids, dispref_ids], dim=0),
        attention_mask=torch.cat([pref_mask, dispref_mask], dim=0)
    ).logits.chunk(2, dim=0)
    margin = pref_rewards.view(-1) - dispref_rewards.view(-1)
    # logsigmoid is the numerically stable form of log(sigmoid(x)): a very
    # negative margin makes sigmoid underflow to 0, and log(0) is -inf.
    loss_sum = -nn.functional.logsigmoid(margin).sum()
    # Normalise by the effective batch size, not the micro-batch size, so the
    # gradients accumulated over several micro-batches add up to (roughly) a
    # mean over the effective batch rather than a sum of micro-batch means.
    accelerator.backward(loss_sum / effective_batch_size)
    # accelerator.clip_grad_norm_(model.parameters(), 1.0) # -- next time try 1.0
    return loss_sum.item() / pref_ids.shape[0]


model.train()
for index, batch in enumerate(train_dataloader):
    sample_count += batch['preferred_input_ids'].shape[0]
    if sample_count < effective_batch_size:
        # Still accumulating: skip gradient synchronisation for this micro-batch.
        with accelerator.no_sync(model):
            micro_batch_loss = _run_micro_batch(batch)
    else:
        # Enough samples accumulated: backward, then apply the update.
        micro_batch_loss = _run_micro_batch(batch)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        sample_count = 0
        optimizer_steps += 1
    wandb.log({"micro_batch_loss": micro_batch_loss})

    if (index + 1) % save_every == 0:
        print("Saving checkpoint...")
        accelerator.save_state("/storage")

    # The schedule advances once per micro-batch, matching total_steps, which
    # is derived from the dataloader lengths.
    scheduler.step()

Saving checkpoint...


In [1]:
# Recover the underlying model from the object returned by accelerator.prepare().
# This needs `accelerator` and `model` from the earlier cells to exist in the
# current kernel session (run the cells above first).
unwrapped_model = accelerator.unwrap_model(model)
# Saving the weights is currently disabled.
# NOTE(review): the evaluation cell loads "/storage/rm_checkpoint.pt", which is
# not the path used below — confirm which checkpoint is meant to be evaluated.
# state_dict = unwrapped_model.state_dict()
# torch.save(state_dict, "/storage/rm_checkpoint_2.pt")

NameError: name 'accelerator' is not defined

In [10]:
# Evaluation data: the first 1000 rows of the Anthropic HH-RLHF test split.
eval_raw = load_dataset("Anthropic/hh-rlhf", split="test").select(range(1000))
# process_anthropic / tokenize_function come from the helper package; they
# presumably build the preferred/dispreferred fields read by the eval loop —
# TODO confirm against rewardmodeling.data.
eval_dataset = eval_raw.map(process_anthropic, remove_columns=eval_raw.column_names)
tokenize_eval = partial(tokenize_function, tokenizer=tokenizer, max_len=1024)
eval_dataset_tokenized = eval_dataset.map(
    tokenize_eval,
    batched=True,
    remove_columns=eval_dataset.column_names,
)
eval_dataloader = DataLoader(
    eval_dataset_tokenized,
    batch_size=8,
    pin_memory=True,
    collate_fn=DefaultDataCollator(),
)

Found cached dataset json (/root/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-c2868409ed9b2fce.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-672fbf2aaea8f228.arrow


In [11]:
device = torch.device("cuda")
# NOTE(review): this checkpoint file is not written anywhere in this notebook;
# make sure it holds the weights you intend to evaluate.
unwrapped_model.load_state_dict(torch.load("/storage/rm_checkpoint.pt"))
unwrapped_model.eval()

# results holds one entry per evaluated pair: 1 when the preferred response
# gets a strictly higher reward than the dispreferred one, otherwise 0.
results = []
# No gradients are needed for evaluation; without no_grad() every forward pass
# would keep its autograd graph alive and waste GPU memory.
with torch.no_grad():
    for batch in eval_dataloader:
        pref_ids, pref_mask, dispref_ids, dispref_mask = (
            batch['preferred_input_ids'].to(device),
            batch['preferred_attention_masks'].to(device),
            batch['dispreferred_input_ids'].to(device),
            batch['dispreferred_attention_masks'].to(device)
        )
        # Score with the same object whose weights were just loaded and put in
        # eval mode (the original called `model` here instead).
        pref_rewards = unwrapped_model(pref_ids, attention_mask=pref_mask).logits.view(-1)
        dispref_rewards = unwrapped_model(dispref_ids, attention_mask=dispref_mask).logits.view(-1)
        correct = (pref_rewards > dispref_rewards).long()
        results.extend(correct.cpu().tolist())

In [12]:
# Pairwise accuracy: the fraction of evaluated pairs where the preferred response
# received the higher reward. For comparison, the best open-source models tend to
# get around 0.75 accuracy, if I remember correctly.
import numpy as np
np.mean(results)

0.65