In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch import nn
from torch.distributions import Categorical

In [None]:
# Tokenizer and raw instruction dataset are both fetched by name.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")
dataset = load_dataset("conceptofmind/flan2021_submix_original")
# Disabled experiment, kept for reference:
# dataset = dataset.train_test_split(test_size=0.8)

# System turn that is prepended to every prompt, wrapped in the tokenizer's
# own BOS/EOS markers.
system_prompt = "".join([
    tokenizer.bos_token,
    "system\n The following is a conversation between user and an AI assistant. ",
    "The assistant is helpful, creative, clever, and very friendly.\n",
    tokenizer.eos_token,
])

In [None]:
def tokenize_function(example):
    """Build the chat-style prompt for one dataset row and tokenize it.

    The prompt is the notebook-level ``system_prompt``, followed by the user
    turn (``example["inputs"]`` wrapped in BOS/EOS) and an opening assistant
    turn that is left unterminated.

    Args:
        example: mapping with an ``"inputs"`` string.

    Returns:
        Whatever the notebook-level ``tokenizer`` returns when called on the
        prompt with ``truncation=True``.
    """
    bos, eos = tokenizer.bos_token, tokenizer.eos_token
    user_turn = bos + "user: " + example["inputs"] + eos
    assistant_cue = bos + "\nassistant: "
    prompt = system_prompt + user_turn + assistant_cue
    return tokenizer(prompt, truncation=True)

In [None]:
# Work on a fixed, seeded 5000-example subsample of the training split.
train_dataset = dataset["train"].shuffle(seed=42).select(range(5000))

# Tokenize each row, then drop the original columns so that only the
# tokenizer's outputs remain.
tokenized_dataset = (
    train_dataset
    .map(tokenize_function)
    .remove_columns(
        ["inputs", "targets", "task_source", "task_name", "template_type"]
    )
)
tokenized_dataset.set_format("torch")

# batch_size=1: one prompt per step.
train_dataloader = DataLoader(tokenized_dataset, shuffle=True, batch_size=1)

In [None]:
class PolicyModel(nn.Module):
    """Causal language model with a scalar value head and a loaded reward head.

    Attributes:
        base_model: causal LM loaded with ``AutoModelForCausalLM``.
        linear: value head mapping a hidden state to one scalar.
        reward_layer: reward head mapping a hidden state to one scalar; its
            weights are loaded from ``reward_weights_path``.
    """

    def __init__(self, model_name="EleutherAI/pythia-410m", reward_weights_path="rm-pythia.pt"):
        """Load the base model and attach the value and reward heads.

        Args:
            model_name: checkpoint name passed to ``from_pretrained``.
            reward_weights_path: path of the saved ``state_dict`` for
                ``reward_layer``.
        """
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        hidden_size = self.base_model.config.hidden_size
        self.linear = nn.Linear(hidden_size, 1)
        self.reward_layer = nn.Linear(hidden_size, 1)
        # map_location="cpu" lets a checkpoint that was saved from a GPU be
        # loaded on a CPU-only machine; the layer follows the module on .to().
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        self.reward_layer.load_state_dict(torch.load(reward_weights_path, map_location="cpu"))

    def generate_response(self, logits, temperature=1.0):
        """Sample one token per position from temperature-scaled logits.

        Args:
            logits: tensor whose last dimension indexes the vocabulary; all
                leading dimensions are flattened into one.
            temperature: divisor applied to the logits before sampling.

        Returns:
            Tuple ``(m, response_tokens, response)``: the ``Categorical``
            distribution over the flattened positions, the sampled token ids
            (one per flattened position), and the result of decoding them with
            the notebook-level ``tokenizer``.
        """
        # The tiny epsilon keeps temperature=0 from dividing by zero.
        scaled_logits = logits / (temperature + 0.000000001)
        flat_logits = scaled_logits.reshape(-1, scaled_logits.size(-1))
        # Build the distribution from logits rather than softmaxed
        # probabilities: the probs path clamps tiny probabilities, which
        # distorts log_prob for unlikely tokens.
        m = Categorical(logits=flat_logits)
        response_tokens = m.sample()
        response = tokenizer.decode(response_tokens)
        return m, response_tokens, response

    def forward(self, input_ids, attention_mask, temperature=0.7):
        """Run the base model, sample a response and estimate a value.

        Args:
            input_ids: token ids passed to the base model.
            attention_mask: attention mask passed to the base model.
            temperature: sampling temperature forwarded to
                ``generate_response``.

        Returns:
            Tuple ``(m, response_tokens, response, value)`` where the first
            three come from ``generate_response`` and ``value`` is the ReLU of
            the mean value-head output over the last hidden states.
        """
        base_model_output = self.base_model(
            input_ids, attention_mask=attention_mask, output_hidden_states=True
        )
        # detach(): the value head trains on its own, without back-propagating
        # into the base model.
        last_hidden = base_model_output.hidden_states[-1].detach()
        value = nn.functional.relu(torch.mean(self.linear(last_hidden)))
        m, response_tokens, response = self.generate_response(
            base_model_output.logits, temperature=temperature
        )
        return m, response_tokens, response, value

    def get_reward(self, input_ids, attention_mask):
        """Score a token sequence with the reward head.

        Args:
            input_ids: token ids passed to the base model.
            attention_mask: attention mask passed to the base model.

        Returns:
            Scalar tensor: mean of the reward-head output over the last hidden
            states.
        """
        # The hidden states are detached before use, so no gradient ever
        # reaches the base model here; skipping graph construction saves memory.
        with torch.no_grad():
            base_model_output = self.base_model(
                input_ids, attention_mask=attention_mask, output_hidden_states=True
            )
        last_hidden = base_model_output.hidden_states[-1].detach()
        return torch.mean(self.reward_layer(last_hidden))

In [None]:
# Trainable policy and a second copy used as the reference model.
model = PolicyModel()
ref_model = PolicyModel()

# Prefer the GPU when one is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# NOTE(review): ref_model is not moved to `device` here - confirm the training
# loop either moves it or feeds it tensors on its own device.

# Separate optimizers: one for the language model, one for the value head.
optimizer = AdamW(model.base_model.parameters(), lr=4e-6)
value_optimizer = AdamW(model.linear.parameters(), lr=1e-5)

# Linear schedule without warmup over a single pass of the dataloader;
# only `optimizer` is scheduled.
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

mse = nn.MSELoss()
beta = 0.02

In [None]:
# Policy in training mode, reference model in evaluation mode.
model.train()
ref_model.eval()

# Bookkeeping initialised before training starts.
i, running_loss = 0, 0.0
training_loss = list()
model_save_path = "rlhf-pythia.pt"