In [None]:
import torch, os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from dataset import load_data, RubiksDataset
from torch.utils.data import DataLoader
from reward import reward_model_strict, reward_R, reward_U
from peft import LoraConfig
import bitsandbytes as bnb

# Learning rate; handed to both PPOConfig and the 8-bit Adam optimizer.
LR = 1.41e-5
# Samples per PPO step; also used as the DataLoader batch size.
BATCH_SIZE = 1
# Mini-batch size handed to PPOConfig.
MINI_BATCH_SIZE = 1

In [None]:
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Checkpoint to fine-tune.
model_id = "google/gemma-2b-it"
# model_id = "vicgalle/gpt2-open-instruct-v1"  # alternative checkpoint

# Load the base weights in 4-bit NF4 and run the compute in float16.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)

# PPO hyperparameters come from the constants defined in the first cell.
config = PPOConfig(model_name=model_id, learning_rate=LR, batch_size=BATCH_SIZE, mini_batch_size=MINI_BATCH_SIZE)

# Rank-8 LoRA adapters on the listed projection modules.
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)

In [None]:
# Read the Hugging Face access token once; raises KeyError when HF_TOKEN is unset.
hf_token = os.environ['HF_TOKEN']

tokenizer = AutoTokenizer.from_pretrained(config.model_name, token=hf_token)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Policy model with a value head: 4-bit quantized base weights plus LoRA adapters.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    peft_config=lora_config,
    device_map="auto",
    token=hf_token,
)
# Alternatives kept for reference (LoRA without quantization / plain model):
# model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name, peft_config=lora_config, device_map="auto", token=os.environ['HF_TOKEN'])
# model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name, device_map="auto", token=os.environ['HF_TOKEN'])

# 8-bit Adam from bitsandbytes, later handed to PPOTrainer.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=LR)

In [None]:
#text = "Write me a poem about Machine Learning. Length of the poem should not exceed 10 words."
# Read the prompt used for the generation sanity check below.
# The encoding is pinned so decoding does not depend on the platform locale.
# NOTE(review): assumes prompt.txt is UTF-8 — confirm.
with open('prompt.txt', 'r', encoding='utf-8') as file:
    text = file.read()
print(text)

In [None]:
# Sanity check: generate from the prompt before any PPO training.
# Reuses `device` from the configuration cell instead of overwriting it with a
# hard-coded "cuda:0", so the CPU fallback chosen there is not discarded.
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Build a shuffled loader over the Kociemba solutions CSV.
data_path = 'datasets/Kociemba_solutions.csv'
data = load_data(data_path)
dataset = RubiksDataset(tokenizer, data)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Wire the value-head model, tokenizer and 8-bit optimizer into a PPO trainer.
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer, optimizer=optimizer)

# Sampling settings forwarded to ppo_trainer.generate in the training loop.
generation_kwargs = dict(
    min_length=-1,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    max_length=570,
)

In [None]:
epochs = 1

for epoch in tqdm(range(epochs), "epoch: "):
    for query_tensors, correct_answers in tqdm(dataloader):
        # Split the batch into one tensor per sample. The previous code took
        # element [0] BATCH_SIZE times, so any batch larger than 1 trained on
        # copies of the first query while rewards used different answers.
        # assumes the loader yields (batch, 1, seq_len) tensors — TODO confirm
        query_tensors = list(torch.unbind(query_tensors.squeeze(1), dim=0))

        #### Get response from model
        # NOTE(review): for causal LMs TRL's generate returns prompt + continuation
        # unless return_prompt=False is passed — confirm that reward_R and
        # ppo_trainer.step are meant to receive the prompt as part of the response.
        response_tensors = [
            ppo_trainer.generate(query_tensor, **generation_kwargs)[0]
            for query_tensor in query_tensors
        ]
        responses = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        #### Compute reward score
        rewards = [
            torch.tensor(reward_R(correct_answer, response), dtype=torch.float16)
            for correct_answer, response in zip(correct_answers, responses)
        ]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        #TODO: log stats
        # ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_pretrained("gemma-2b-it-rlhf-kociemba")

In [None]:
#### Persist the trained model under a fixed directory name
ppo_trainer.save_pretrained("gemma-2b-it-rlhf-kociemba")