In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch import nn

In [None]:
# Tokenizer for the Pythia-410m checkpoint and the SHP preference dataset.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")
dataset = load_dataset("stanfordnlp/SHP")

# Fixed system turn placed in front of every conversation, wrapped in the
# tokenizer's BOS/EOS markers like the user and assistant turns.
system_message = (
    "system\n The following is a conversation between user and an AI assistant. "
    "The assistant is helpful, creative, clever, and very friendly.\n"
)
system_prompt = tokenizer.bos_token + system_message + tokenizer.eos_token

In [None]:
def _encode_completion(example, response_key, max_length=None):
    """Tokenize the system prompt, the user history and one candidate response."""
    text = (
        system_prompt
        + tokenizer.bos_token + "user: " + example["history"] + tokenizer.eos_token
        + tokenizer.bos_token + "assistant: " + example[response_key] + tokenizer.eos_token
    )
    # NOTE(review): with max_length=None, truncation=True falls back to the
    # tokenizer's own configured maximum -- confirm one is set for this
    # checkpoint, otherwise nothing is actually truncated.
    return tokenizer(text, truncation=True, max_length=max_length)


def tokenize_function(example, max_length=None):
    """Tokenize both candidate completions of one preference example.

    Args:
        example: Mapping with the string fields ``"history"``,
            ``"human_ref_A"`` and ``"human_ref_B"``.
        max_length: Optional token limit forwarded to the tokenizer. The
            default ``None`` keeps the tokenizer's own limit. An explicit
            value can be supplied through ``Dataset.map(..., fn_kwargs=...)``.

    Returns:
        dict with ``input_ids_A``, ``attention_mask_A``, ``input_ids_B`` and
        ``attention_mask_B``.
    """
    dictionary = {}
    for suffix in ("A", "B"):
        completion = _encode_completion(example, "human_ref_" + suffix, max_length)
        dictionary["input_ids_" + suffix] = completion["input_ids"]
        dictionary["attention_mask_" + suffix] = completion["attention_mask"]
    return dictionary

In [None]:
# Columns dropped after tokenization.
unused_columns = [
    'post_id', 'domain', 'upvote_ratio', 'history',
    'c_root_id_A', 'c_root_id_B',
    'created_at_utc_A', 'created_at_utc_B',
    'human_ref_A', 'human_ref_B',
    'labels', 'seconds_difference', 'score_ratio',
]

# Fixed-seed shuffle, then keep the first 10000 training examples.
train_dataset = dataset["train"].shuffle(seed=42).select(range(10000))

# Tokenize both completions per example, drop the unused columns and expose
# what remains as torch tensors.
tokenized_dataset = train_dataset.map(tokenize_function).remove_columns(unused_columns)
tokenized_dataset.set_format("torch")

# One preference pair per step.
train_dataloader = DataLoader(tokenized_dataset, batch_size=1)

In [None]:
class RewardModel(nn.Module):
    """Scalar reward head on top of a pretrained causal language model.

    The base model is used purely as a feature extractor: its last-layer
    hidden states are computed without gradient tracking, so only ``linear``
    can receive gradients from the returned score.
    """

    def __init__(self):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
        # Maps each token's hidden state to a single reward value.
        self.linear = nn.Linear(self.base_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return a scalar reward for the given token sequence(s).

        Args:
            input_ids: Token ids passed to the base model.
            attention_mask: Attention mask passed to the base model.

        Returns:
            0-dim tensor: the mean of the per-token reward scores.
        """
        # The hidden states were previously detached after a fully tracked
        # forward pass. Running the base model under no_grad yields the same
        # values and the same gradients for ``linear`` without building (and
        # holding in memory) an autograd graph through the whole base model.
        with torch.no_grad():
            base_model_output = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        last_hidden_state = base_model_output.hidden_states[-1]
        # We take the mean of the reward score for every token's last-layer hidden state.
        # Another valid strategy is to only consider the reward score for the last
        # token's last-layer hidden state.
        # NOTE(review): the mean runs over every position and does not consult
        # attention_mask, so padded positions would be averaged in.
        return torch.mean(self.linear(last_hidden_state))

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

reward_model = RewardModel()
reward_model.to(device)

# Only the parameters of the linear reward head are handed to the optimizer.
optimizer = AdamW(reward_model.linear.parameters(), lr=1e-5)

# One scheduler step per batch: "linear" schedule with zero warmup steps.
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Bookkeeping used by the training loop.
model_save_path = 'rm-pythia.pt'  # where the reward head's weights are written
training_loss = []                # history of averaged losses
i, running_loss = 0, 0.0          # step counter and loss accumulator

# Put the model into training mode.
reward_model.train()

In [None]:
# Pairwise preference training of the reward head.
#
# For each pair, the completion with the higher dataset score is treated as
# preferred and the loss is -log(sigmoid(r_preferred - r_rejected)).
#
# `i` is kept as a running step counter. It was previously reset to 0 after
# every window, which made the progress output cycle 1..50 and hid the overall
# position. Progress is now reported once per window instead of once per step.
log_every = 50  # steps per loss-averaging / checkpoint window

for batch in train_dataloader:
    i += 1
    batch = {k: v.to(device) for k, v in batch.items()}
    model_score_A = reward_model(input_ids=batch["input_ids_A"], attention_mask=batch["attention_mask_A"])
    model_score_B = reward_model(input_ids=batch["input_ids_B"], attention_mask=batch["attention_mask_B"])

    # The truth test on a tensor comparison only works while each batch holds
    # exactly one pair. Ties fall through to the B-preferred branch.
    if batch["score_A"] > batch["score_B"]:
        preferred, rejected = model_score_A, model_score_B
    else:
        preferred, rejected = model_score_B, model_score_A
    loss = -nn.functional.logsigmoid(preferred - rejected)

    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    running_loss += loss.item()
    if i % log_every == 0:
        window_loss = running_loss / log_every
        print(f"step {i}: mean loss over last {log_every} steps = {window_loss}")
        training_loss.append(window_loss)
        running_loss = 0.0
        # Checkpoint only the trained head.
        torch.save(reward_model.linear.state_dict(), model_save_path)

# Record a trailing partial window (if any) so those steps are not silently
# dropped from the loss history.
leftover_steps = i % log_every
if leftover_steps:
    training_loss.append(running_loss / leftover_steps)
    running_loss = 0.0

torch.save(reward_model.linear.state_dict(), model_save_path)