In [None]:
# Paper: "Interpretable Preferences via Multi-Objective Reward Modeling and Mixture-of-Experts" (ArmoRM), https://arxiv.org/pdf/2406.12845

In [1]:
!pip install datasets trl bitsandbytes -qq

In [129]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import random

In [3]:
# System message placed at the start of every chat-formatted conversation in this notebook.
SYSTEM_PROMPT = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.'

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
# bitsandbytes quantization settings handed to `from_pretrained` through `quantization_config`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the weights in 4-bit precision
    bnb_4bit_quant_type='nf4',  # NF4 quantization type
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for computation
    bnb_4bit_use_double_quant=True  # enable double (nested) quantization
)

In [None]:
model_name = 'Qwen/Qwen2.5-0.5B-Instruct'
# Place the quantized weights on `device` at load time through `device_map` instead of
# calling `.to(device)` afterwards: moving an already-quantized bitsandbytes model with
# `.to()` is rejected by some transformers / bitsandbytes version combinations.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [21]:
# HelpSteer2: prompt/response pairs rated on five attributes (see the columns displayed below).
dataset = load_dataset('nvidia/HelpSteer2')

In [22]:
# Show the available splits, their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 20324
    })
    validation: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 1038
    })
})

In [25]:
# DEBUG: tokenize the first training example and look at the token sitting at
# the end of the prompt (system + user turns) inside the full conversation.
for example in dataset['train']:
    only_prompt_messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': example['prompt']},
    ]
    messages = only_prompt_messages + [
        {'role': 'assistant', 'content': example['response']},
    ]

    # Full conversation, padded / truncated to 1024 tokens.
    text = tokenizer.apply_chat_template(
        messages,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_dict=True,
    )
    # Token ids of the prompt-only conversation.
    prompt_text = tokenizer.apply_chat_template(only_prompt_messages)

    print(text['input_ids'])
    print('==========')
    print(len(prompt_text))
    print(text['input_ids'][len(prompt_text)-1])
    print(example['prompt'])
    break  # one example is enough for inspection

[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 66, 2, 151645, 198, 151644, 77091, 198, 34, 2, 374, 264, 1550, 11591, 11, 1633, 35085, 15473, 4128, 7881, 553, 5100, 438, 949, 315, 1181, 659, 15373, 20162, 13, 1084, 572, 3465, 438, 264, 6481, 10555, 311, 7943, 323, 11554, 264, 8045, 315, 15473, 27317, 343, 1011, 11, 2670, 47596, 11, 15629, 11, 323, 1538, 31405, 13, 356, 2, 374, 15503, 1483, 369, 5515, 3766, 4401, 11, 714, 432, 646, 1083, 387, 1483, 369, 3482, 11, 6371, 11, 323, 1809, 4401, 13, 576, 4128, 374, 6188, 311, 387, 6092, 11, 9767, 11, 323, 11050, 11, 323, 432, 5707, 13402, 448, 264, 9080, 738, 315, 20186, 323, 7375, 369, 4752, 21765, 323, 68211, 8357, 13, 356, 2, 374, 1083, 13570, 1483, 304, 279, 1809, 4401, 4958, 11, 7945, 304, 279, 4401, 315, 3868, 369, 279, 20577, 220, 18, 21, 15, 323, 20577, 3776, 50093, 13, 151645, 198, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 1516

In [26]:
def preproc(example, max_length=1024):
    """Tokenize one dataset row and collect its five attribute ratings.

    Relies on the module-level `tokenizer` and `SYSTEM_PROMPT`.

    Args:
        example: Mapping with the keys 'prompt', 'response', 'helpfulness',
            'correctness', 'coherence', 'complexity' and 'verbosity'.
        max_length: Length the full conversation is padded / truncated to.

    Returns:
        dict with
          - 'prompt_and_answer': tokenizer output for the system + user + assistant
            conversation, padded / truncated to `max_length`;
          - 'target_reward': the five ratings in the order helpfulness, correctness,
            coherence, complexity, verbosity;
          - 'prompt_len': number of tokens of the system + user conversation,
            capped at `max_length`.
    """
    only_prompt_messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': example['prompt']}
    ]
    messages = only_prompt_messages + [
        {'role': 'assistant', 'content': example['response']}
    ]
    text = tokenizer.apply_chat_template(messages, max_length=max_length, padding='max_length',
                                           truncation=True, return_dict=True)
    target_reward = [example['helpfulness'], example['correctness'],
                     example['coherence'], example['complexity'], example['verbosity']]

    # The full conversation is truncated to `max_length`, so a longer prompt must not
    # yield an index beyond the sequence when `prompt_len - 1` is used for lookup later.
    prompt_len = min(len(tokenizer.apply_chat_template(only_prompt_messages)), max_length)

    return {
        'prompt_and_answer': text,
        'target_reward': target_reward,
        'prompt_len': prompt_len
    }

In [27]:
# Apply `preproc` to every row of both splits; this adds the columns
# 'prompt_and_answer', 'target_reward' and 'prompt_len'.
dataset = dataset.map(preproc)

Map:   0%|          | 0/20324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1038 [00:00<?, ? examples/s]

In [28]:
# Keep only the columns that `collate_fn` reads.
dataset = dataset.select_columns(['prompt_and_answer', 'target_reward', 'prompt_len'])

In [29]:
# Confirm that only the three preprocessed columns remain.
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt_and_answer', 'target_reward', 'prompt_len'],
        num_rows: 20324
    })
    validation: Dataset({
        features: ['prompt_and_answer', 'target_reward', 'prompt_len'],
        num_rows: 1038
    })
})

In [150]:
def collate_fn(batch):
    """Turn a list of preprocessed examples into a dict of batched tensors.

    Args:
        batch: Sequence of items, each holding 'prompt_and_answer' (with
            'input_ids' and 'attention_mask'), 'target_reward' and 'prompt_len'.

    Returns:
        dict with 'input_ids', 'attn_mask', 'target_reward' (float32) and
        'prompt_len', each stacked along a new leading batch dimension.
    """
    encodings = [sample['prompt_and_answer'] for sample in batch]

    def _stack(values, **tensor_kwargs):
        # One tensor per example, stacked along dim 0.
        return torch.stack([torch.tensor(value, **tensor_kwargs) for value in values])

    return {
        'input_ids': _stack(enc['input_ids'] for enc in encodings),
        'attn_mask': _stack(enc['attention_mask'] for enc in encodings),
        'target_reward': _stack((sample['target_reward'] for sample in batch),
                                dtype=torch.float32),
        'prompt_len': _stack(sample['prompt_len'] for sample in batch),
    }

In [151]:
# Loaders over the train / validation splits; only the training loader is shuffled.
batch_size = 2
train_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=batch_size,
                                           shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset['validation'], batch_size=batch_size,
                                          shuffle=False, collate_fn=collate_fn)

In [153]:
# Sanity check: shapes and dtypes of a single collated training batch.
for batch in train_loader:
    print(batch['input_ids'].shape, batch['input_ids'].dtype)
    print(batch['attn_mask'].shape, batch['attn_mask'].dtype)
    print(batch['target_reward'], batch['target_reward'].dtype)
    print(batch['prompt_len'].shape, batch['prompt_len'].dtype)
    break

torch.Size([2, 1024]) torch.int64
torch.Size([2, 1024]) torch.int64
tensor([[3., 3., 4., 2., 2.],
        [4., 4., 4., 1., 2.]]) torch.float32
torch.Size([2]) torch.int64


In [154]:
# DEBUG: with the LM head replaced by Identity, `.logits` carries the backbone's
# final hidden states instead of vocabulary logits.
# NOTE: this mutates the global `model` in place.
model.lm_head = nn.Identity()
for batch in train_loader:
    # Only the shape is inspected, so skip autograd bookkeeping for this forward pass.
    with torch.no_grad():
        print(model(input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attn_mask'].to(device)).logits.shape)
    break

torch.Size([2, 1024, 896])


In [119]:
class FirstStageArmoRM(nn.Module):
    """Stage 1: frozen LM backbone plus a linear multi-objective regression head.

    The backbone's `lm_head` is replaced by `nn.Identity`, so the `.logits` field of
    its output holds the final hidden states. The hidden state of the last *attended*
    token of each sequence is fed to a linear layer that predicts one score per
    objective.

    Args:
        model: Causal LM exposing `config.hidden_size` and an `lm_head` attribute.
            It is modified in place (head replaced, all parameters frozen).
        num_objectives: Number of regression outputs (defaults to the 5 rated criteria).
    """

    def __init__(self, model, num_objectives=5):
        super().__init__()
        self.model = model
        # Drop the vocabulary projection so the backbone returns hidden states.
        self.model.lm_head = nn.Identity()
        self.freeze_model()
        self.emb_size = model.config.hidden_size
        self.first_stage_linear = nn.Linear(self.emb_size, num_objectives)

    def forward(self, input_ids, attn_mask):
        """Return per-objective scores of shape (batch, num_objectives).

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attn_mask: Attention mask, shape (batch, seq_len); non-zero marks real tokens.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attn_mask).logits

        # Sequences are padded to a fixed length, so position -1 is usually a pad token.
        # Pick the last position whose mask is set instead (works for left or right padding).
        positions = torch.arange(attn_mask.shape[1], device=attn_mask.device)
        last_idx = (attn_mask.long() * positions).argmax(dim=1)
        batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
        last_token_hidden = hidden[batch_idx, last_idx.to(hidden.device)].to(torch.float32)

        return self.first_stage_linear(last_token_hidden)

    def freeze_model(self):
        """Disable gradients for every backbone parameter."""
        for param in self.model.parameters():
            param.requires_grad = False

In [120]:
# Wrap the backbone with the stage-1 regression head and move the module to `device`.
first_stage_armorm_model = FirstStageArmoRM(model).to(device)

In [121]:
# Trainable-parameter counts: the bare backbone first, then the stage-1 wrapper.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
print(sum(p.numel() for p in first_stage_armorm_model.parameters() if p.requires_grad))

0
4485


In [114]:
# Stage-1 optimisation: AdamW over the wrapper's parameters, MSE against the target ratings.
first_stage_optim = torch.optim.AdamW(first_stage_armorm_model.parameters(), lr=1e-3)
first_stage_loss = nn.MSELoss()

In [None]:
# Stage-1 training: regress the rating vector with MSE, then evaluate on the
# validation loader after every epoch.
for epoch in range(10):
    # ---- train ----
    first_stage_armorm_model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attn_mask'].to(device)
        target_reward = batch['target_reward'].to(device)

        first_stage_optim.zero_grad()
        reward_logits = first_stage_armorm_model(input_ids=input_ids, attn_mask=attn_mask)
        loss = first_stage_loss(reward_logits, target_reward)
        loss.backward()
        first_stage_optim.step()

        running_loss += loss.item()

    print(f'Epoch: {epoch}, Train mse loss: {running_loss / len(train_loader)}')

    # ---- evaluate ----
    first_stage_armorm_model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            target_reward = batch['target_reward'].to(device)

            reward_logits = first_stage_armorm_model(input_ids=input_ids, attn_mask=attn_mask)
            loss = first_stage_loss(reward_logits, target_reward)
            running_loss += loss.item()

    print(f'Epoch: {epoch}, Test mse loss: {running_loss / len(test_loader)}')

In [155]:
class SecondStageArmoRM(nn.Module):
    """Stage 2: frozen stage-1 model plus a trainable gating layer.

    The gating layer reads the hidden state of the last prompt token and produces one
    gate logit per objective; the frozen stage-1 head produces the per-objective
    scores from the hidden state of the last attended token.

    Args:
        model: Stage-1 module exposing `.model` (backbone whose output `.logits` holds
            hidden states), `.emb_size` and `.first_stage_linear`. All of its
            parameters are frozen here.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.freeze_model()
        # One gate per objective predicted by the stage-1 head.
        self.second_stage_gate_layer = nn.Linear(model.emb_size,
                                                 model.first_stage_linear.out_features)

    def forward(self, input_ids, attn_mask, prompt_len):
        """Return `(gate_logits, reward_logits)`, each of shape (batch, num_objectives).

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attn_mask: Attention mask, shape (batch, seq_len); non-zero marks real tokens.
            prompt_len: Number of prompt tokens per sample (int or tensor of shape (batch,)).
        """
        hidden = self.model.model(input_ids=input_ids, attention_mask=attn_mask).logits
        n_rows, seq_len = hidden.shape[0], hidden.shape[1]
        batch_idx = torch.arange(n_rows, device=hidden.device)

        # Pair every row with its own prompt index. Indexing with `[:, prompt_len - 1, :]`
        # would select all prompt positions for every row and yield (batch, batch, hidden).
        prompt_idx = torch.as_tensor(prompt_len, device=hidden.device).long() - 1
        prompt_idx = prompt_idx.clamp(min=0, max=seq_len - 1)
        last_prompt_token_hidden = hidden[batch_idx, prompt_idx].to(torch.float32)

        # Position -1 is usually padding; use the last position whose mask is set.
        positions = torch.arange(seq_len, device=hidden.device)
        last_idx = (attn_mask.to(hidden.device).long() * positions).argmax(dim=1)
        last_token_hidden = hidden[batch_idx, last_idx].to(torch.float32)

        last_token_reward_logits = self.model.first_stage_linear(last_token_hidden)
        return self.second_stage_gate_layer(last_prompt_token_hidden), last_token_reward_logits

    def freeze_model(self):
        """Disable gradients for every parameter of the wrapped stage-1 model."""
        for param in self.model.parameters():
            param.requires_grad = False

In [156]:
# Wrap the stage-1 model (frozen by the wrapper) with the gating layer and move it to `device`.
second_stage_armorm_model = SecondStageArmoRM(first_stage_armorm_model).to(device)

In [157]:
# Trainable-parameter counts: the stage-1 model first, then the stage-2 wrapper.
print(sum(p.numel() for p in first_stage_armorm_model.parameters() if p.requires_grad))
print(sum(p.numel() for p in second_stage_armorm_model.parameters() if p.requires_grad))

0
4485


In [158]:
class BradleyTerryLoss(nn.Module):
    """Pairwise preference loss with a learnable scale.

    Computes ``-logsigmoid(beta * (chosen_reward - rejected_reward))`` element-wise;
    no reduction is applied to the result.
    """

    def __init__(self):
        super().__init__()
        # Learnable scale factor, initialised to one.
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, chosen_reward, rejected_reward):
        """Return the element-wise loss for the given chosen / rejected rewards."""
        margin = chosen_reward - rejected_reward
        scaled_margin = self.beta * margin
        return -F.logsigmoid(scaled_margin)

In [172]:
# Stage-2 optimisation: one AdamW for the stage-2 model's parameters and a separate
# AdamW (smaller learning rate) for the loss module's own parameters.
second_stage_optim = torch.optim.AdamW(second_stage_armorm_model.parameters(), lr=1e-3)
second_stage_loss = BradleyTerryLoss().to(device)
second_stage_loss_optim = torch.optim.AdamW(second_stage_loss.parameters(), lr=3e-4)

In [173]:
# TODO: add chosen/rejected response pairs to the dataset so the second stage can be trained on real preferences

In [176]:
# Stage-2 training: learn the gating layer (and the loss scale) with the
# Bradley-Terry loss. The rejected reward is still a random placeholder (see TODOs).
for epoch in range(10):
    running_loss = 0.0
    second_stage_armorm_model.train()
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attn_mask'].to(device)
        prompt_len = batch['prompt_len'].to(device)

        reward_gates, reward_logits = second_stage_armorm_model(input_ids=input_ids, attn_mask=attn_mask, prompt_len=prompt_len)
        reward_gates = F.softmax(reward_gates, dim=-1)
        # Scalar reward per response: gate-weighted sum over the objective axis only,
        # so different samples of the batch are not merged into a single number.
        chosen_reward = torch.sum(reward_gates * reward_logits, dim=-1)
        rejected_reward = torch.randint(low=0, high=5, size=[1]).to(device) # TODO: change this
        # The loss module returns element-wise values; average them for backward().
        loss = second_stage_loss(chosen_reward, rejected_reward).mean()
        running_loss += loss.item()

        second_stage_optim.zero_grad()
        second_stage_loss_optim.zero_grad()
        loss.backward()
        # Step the stage-2 optimizer: stepping the stage-1 optimizer here would leave
        # the gating layer untrained.
        second_stage_optim.step()
        second_stage_loss_optim.step()

    print(f'Epoch: {epoch}, Train Bradley-Terry loss: {running_loss / len(train_loader)}')

    running_loss = 0.0
    second_stage_armorm_model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            prompt_len = batch['prompt_len'].to(device)

            reward_gates, reward_logits = second_stage_armorm_model(input_ids=input_ids, attn_mask=attn_mask, prompt_len=prompt_len)
            reward_gates = F.softmax(reward_gates, dim=-1)
            chosen_reward = torch.sum(reward_gates * reward_logits, dim=-1)
            rejected_reward = torch.randint(low=0, high=5, size=[1]).to(device) # TODO: change this
            loss = second_stage_loss(chosen_reward, rejected_reward).mean()
            running_loss += loss.item()

    print(f'Epoch: {epoch}, Test Bradley-Terry loss: {running_loss / len(test_loader)}')

  0%|          | 10/10162 [00:07<2:10:01,  1.30it/s]


KeyboardInterrupt: 