In [1]:
import torch, os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from dataset import load_data, RubiksDataset
from torch.utils.data import DataLoader
from reward import reward_model_strict
from peft import LoraConfig
import bitsandbytes as bnb

# --- PPO hyperparameters -----------------------------------------------------
LR = 1.41e-5
BATCH_SIZE = 1
MINI_BATCH_SIZE = 1

# --- Base checkpoint ---------------------------------------------------------
model_id = "google/gemma-2b-it"

# 4-bit weight loading; quantization type and compute dtype are not set here,
# so the library defaults apply.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# PPO trainer settings, all driven by the constants above.
config = PPOConfig(
    model_name=model_id,
    learning_rate=LR,
    batch_size=BATCH_SIZE,
    mini_batch_size=MINI_BATCH_SIZE,
)

# Rank-8 LoRA adapters attached to the listed projection modules.
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)

cuda:0


In [2]:
# Read the Hugging Face access token once; raises KeyError if HF_TOKEN is unset.
hf_token = os.environ['HF_TOKEN']

tokenizer = AutoTokenizer.from_pretrained(config.model_name, token=hf_token)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Policy model with a value head, loaded with the 4-bit quantization config
# and the LoRA adapter config defined in the configuration cell.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    peft_config=lora_config,
    device_map="auto",
    token=hf_token,
)

# Hand only trainable parameters to the optimizer. Parameters that do not
# require gradients never receive updates, so passing them is pure noise.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = bnb.optim.Adam8bit(trainable_params, lr=LR)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
#text = "Write me a poem about Machine Learning. Length of the poem should not exceed 10 words."
# Load the system prompt. The encoding is pinned so the text is decoded the
# same way regardless of the platform's default locale.
with open('prompt.txt', 'r', encoding='utf-8') as file:
    text = file.read()
print(text)

You are a Rubik's cube solving assistant. Your job is to generate the next best move
when solving a Rubik's cube when given the a Rubik's cube scramble. A scramble is a list of
moves that are performed on a fully solved Rubik's cube in order to scramble it up. When replying,
you must only reply with a single move.

Below I describe the possible moves:
U (Up): Rotate the upper face 90 degrees clockwise.
U' (Up Prime): Rotate the upper face 90 degrees counter-clockwise.
U2: Rotate the upper face 180 degrees.
D (Down): Rotate the bottom face 90 degrees clockwise.
D' (Down Prime): Rotate the bottom face 90 degrees counter-clockwise.
D2: Rotate the bottom face 180 degrees.
F (Front): Rotate the front face 90 degrees clockwise.
F' (Front Prime): Rotate the front face 90 degrees counter-clockwise.
F2: Rotate the front face 180 degrees.
B (Back): Rotate the back face 90 degrees clockwise.
B' (Back Prime): Rotate the back face 90 degrees counter-clockwise.
B2: Rotate the back face 180 degrees.


In [4]:
# Send the inputs to whichever device the model's weights were placed on
# (the model is loaded with device_map="auto") instead of assuming "cuda:0".
device = next(model.parameters()).device
inputs = tokenizer(text, return_tensors="pt").to(device)

# Sanity check: generate from the raw prompt and print the decoded text.
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))



You are a Rubik's cube solving assistant. Your job is to generate the next best move
when solving a Rubik's cube when given the a Rubik's cube scramble. A scramble is a list of
moves that are performed on a fully solved Rubik's cube in order to scramble it up. When replying,
you must only reply with a single move.

Below I describe the possible moves:
U (Up): Rotate the upper face 90 degrees clockwise.
U' (Up Prime): Rotate the upper face 90 degrees counter-clockwise.
U2: Rotate the upper face 180 degrees.
D (Down): Rotate the bottom face 90 degrees clockwise.
D' (Down Prime): Rotate the bottom face 90 degrees counter-clockwise.
D2: Rotate the bottom face 180 degrees.
F (Front): Rotate the front face 90 degrees clockwise.
F' (Front Prime): Rotate the front face 90 degrees counter-clockwise.
F2: Rotate the front face 180 degrees.
B (Back): Rotate the back face 90 degrees clockwise.
B' (Back Prime): Rotate the back face 90 degrees counter-clockwise.
B2: Rotate the back face 180 degrees.


In [5]:
data_path = 'datasets/Kociemba_solutions.csv'
data = load_data(data_path)
dataset = RubiksDataset(tokenizer, data)
# The loader's batch size follows BATCH_SIZE, the same constant given to
# PPOConfig, because PPOTrainer.step expects exactly config.batch_size samples
# per call. drop_last avoids a short final batch if BATCH_SIZE is ever raised.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [6]:
# Build the PPO trainer around the value-head policy, reusing the 8-bit Adam
# optimizer created earlier. No separate reference model is passed here.
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
)



In [None]:
epochs = 1

# Rollout settings recommended by the TRL documentation for PPO: sample from
# the untruncated policy distribution (no top-k / nucleus filtering, no
# minimum length) so generated tokens are consistent with the log-probs that
# PPO optimises. `max_length` caps prompt + continuation, as before.
generation_kwargs = {
    "max_length": 570,
    "min_length": -1,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

for epoch in tqdm(range(epochs), "epoch: "):
    for query_tensors, correct_answers in tqdm(dataloader):
        # Remove dim 1 and split the batch into a list of per-sample tensors.
        # assumes batches arrive as (batch, 1, seq_len) — TODO confirm
        query_tensors = list(torch.unbind(query_tensors.squeeze(1), dim=0))

        #### Get response from model
        # NOTE(review): depending on the installed TRL version, `generate` may
        # return prompt + continuation by default (`return_prompt=True`).
        # Confirm that `step` and `reward_model_strict` are meant to receive
        # the prompt tokens as part of the response.
        response_tensors = [
            ppo_trainer.generate(query_tensor, **generation_kwargs).squeeze(0)
            for query_tensor in query_tensors
        ]
        responses = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        #### Compute reward score
        # Rewards are built as float32 tensors, the dtype `step` documents for
        # its scores, rather than half precision.
        rewards = [
            torch.tensor(reward_model_strict(correct_answer, response), dtype=torch.float32)
            for correct_answer, response in zip(correct_answers, responses)
        ]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        #TODO: log stats
        # ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
# NOTE(review): confirm the installed TRL version's PPOTrainer exposes
# `save_model`; some releases document `save_pretrained` instead.
ppo_trainer.save_model("gemma-2b-it-rlhf-kociemba")

epoch:   0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/1000 [00:00<?, ?it/s][A

 U<eos>


  std_scores = data["scores"].std()
  stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
  stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()

  0%|          | 1/1000 [00:04<1:18:39,  4.72s/it][A

 U'<eos>



  0%|          | 2/1000 [00:09<1:17:48,  4.68s/it][A

 U<eos>



  0%|          | 3/1000 [00:13<1:16:53,  4.63s/it][A

 U<eos>



  0%|          | 4/1000 [00:18<1:17:10,  4.65s/it][A

 U<eos>



  0%|          | 5/1000 [00:23<1:17:18,  4.66s/it][A

 U'<eos>



  1%|          | 6/1000 [00:27<1:16:52,  4.64s/it][A

 U<eos>



  1%|          | 7/1000 [00:32<1:16:22,  4.61s/it][A

 U<eos>



  1%|          | 8/1000 [00:37<1:16:00,  4.60s/it][A

 U<eos>



  1%|          | 9/1000 [00:41<1:15:43,  4.58s/it][A

 U'<eos>



  1%|          | 10/1000 [00:46<1:15:49,  4.60s/it][A

 U'<eos>



  1%|          | 11/1000 [00:50<1:15:45,  4.60s/it][A

 U'<eos>



  1%|          | 12/1000 [00:55<1:17:09,  4.69s/it][A

 U'<eos>



  1%|▏         | 13/1000 [01:00<1:17:41,  4.72s/it][A

 ................<eos>



  1%|▏         | 14/1000 [01:05<1:16:53,  4.68s/it][A

 
<eos>



  2%|▏         | 15/1000 [01:09<1:16:24,  4.65s/it][A

 
<eos>



  2%|▏         | 16/1000 [01:14<1:15:59,  4.63s/it][A

 
<eos>



  2%|▏         | 17/1000 [01:18<1:15:22,  4.60s/it][A

 
<eos>



  2%|▏         | 18/1000 [01:23<1:15:20,  4.60s/it][A

 ................<eos>



  2%|▏         | 19/1000 [01:27<1:15:05,  4.59s/it][A

 ................<eos>
