In [1]:
import torch, os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from torch.utils.data import DataLoader
from reward import reward_model_strict
from peft import LoraConfig
import bitsandbytes as bnb
import csv, json
from dotenv import load_dotenv
from datasets import Dataset, load_dataset

In [2]:
# Load variables such as HF_TOKEN from a local .env file into os.environ;
# variables already present in the environment take precedence.
load_dotenv(override=False)

True

In [3]:
# --- PPO hyperparameters ---
LR = 1.41e-5
BATCH_SIZE = 1
MINI_BATCH_SIZE = 1
# NOTE(review): with a batch of one, PPO's per-batch statistics (e.g. score std) are
# computed over a single sample -- the std warnings in the training output come from this.

# Load the policy weights in 4-bit. The quant type and compute dtype are left at the
# bitsandbytes defaults (the nf4 / torch.float16 overrides are intentionally not set).
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# Local checkpoint the PPO policy starts from (name suggests an SFT'd Gemma-2B -- confirm).
# To run against a GPT-2 model instead, point this at "vicgalle/gpt2-open-instruct-v1".
model_id = "../gemma-2b-sft_old"

config = PPOConfig(
    model_name=model_id,
    learning_rate=LR,
    batch_size=BATCH_SIZE,
    mini_batch_size=MINI_BATCH_SIZE,
)

# Rank-8 LoRA adapters on the attention (q/k/v/o) and MLP (gate/up/down) projections.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

In [4]:
# The checkpoint is a local path, so a Hugging Face token is optional here; use it
# when configured instead of failing with KeyError when HF_TOKEN is unset.
hf_token = os.environ.get('HF_TOKEN')

tokenizer = AutoTokenizer.from_pretrained(config.model_name, token=hf_token)
# Reuse EOS for padding; generation_kwargs pads with eos_token_id to match.
tokenizer.pad_token = tokenizer.eos_token

# Policy = 4-bit quantized causal LM + LoRA adapters + TRL scalar value head.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    peft_config=lora_config,
    device_map="auto",
    token=hf_token,
)

# Optimise only parameters that require gradients (LoRA adapters / value head). The
# quantized base weights are frozen by PEFT and need no optimizer state. This is the
# same requires_grad filter TRL's PPOTrainer uses when it builds its default optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = bnb.optim.Adam8bit(trainable_params, lr=LR)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [5]:
def csv_to_jsonl(csv_path, jsonl_path, prompt_path='../ttt_prompt.txt'):
    """Convert a tic-tac-toe CSV into a JSON Lines file of prompt/completion pairs.

    The prompt template in ``prompt_path`` is filled with each row's ``"Game States"``
    value via ``str.format(state=...)``. The row's ``"Optimal Moves"`` value, as a
    string, becomes the ``"completion"``.

    Args:
        csv_path: CSV file with (at least) ``Game States`` and ``Optimal Moves`` columns.
        jsonl_path: Output path; overwritten if it already exists.
        prompt_path: Template file containing a ``{state}`` placeholder. Any other
            literal braces in it must be doubled (``{{``/``}}``) because the template
            goes through ``str.format``. Defaults to the path previously hard-coded.

    Raises:
        KeyError: if a row lacks a required column, or the template uses a
            placeholder other than ``{state}``.

    NOTE(review): the rendered prompt shown in the notebook output says
    "x1 means player 2 marks positoin 4" -- looks like a typo / wrong example in the
    prompt file; confirm and fix it there.
    """
    with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # newline='' is required by the csv module so quoted fields with line breaks parse correctly.
    with open(csv_path, 'r', newline='', encoding='utf-8') as csv_file, \
            open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
        for row in csv.DictReader(csv_file):
            record = {
                "prompt": prompt.format(state=row["Game States"]),
                "completion": str(row["Optimal Moves"]),
            }
            jsonl_file.write(json.dumps(record) + "\n")
            
def tokenize(sample):
    """Attach the token ids of ``sample["prompt"]`` to a dataset row as ``input_ids``.

    Meant for ``Dataset.map(..., batched=False)``: the row dict is updated in place
    and returned. Encodes with the module-level ``tokenizer``.
    """
    prompt_ids = tokenizer.encode(sample["prompt"])
    sample["input_ids"] = prompt_ids
    return sample

In [6]:
# Build the PPO training set: CSV -> JSONL of formatted prompts -> HF Dataset.
# Each prompt is tokenized, and numeric columns (input_ids) come back as torch tensors.
csv_to_jsonl('../examples/ttt_data_ppo_train.csv', "data.jsonl")
dataset = (
    load_dataset("json", data_files="data.jsonl", split='train')
    .map(tokenize, batched=False)
    .with_format("torch")
)

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
# Sanity check: the first row should hold the formatted prompt, the target move and its token ids.
first_row = dataset[0]
print(first_row)

{'prompt': 'You are a tic-tac-toe solver. A tic-tac-toe board is a 3x3 grid. For example\n\nb,o,b\nx,b,b\nb,b,o\n\nb represents an empty position\no represents a mark by player 1\nx represents a mark by player 2\n\nThis state can also be represented in one line eg.\nbobxbbbbo\n\nThe grid is also numbered where each number represents a position on the grid. eg.\n1,2,3\n4,5,6\n7,8,9\n\na move can thus be represented by mark+number. Here are some examples:\no5 means player 1 marks position 5 on the grid\nx1 means player 2 marks positoin 4 on the grid\n\nYour job is to generate the next best move given a tic-tac-toe board state.\n\nYou must only answer with mark+number format and nothing else eg:\no7\n\n\nGiven the following state, what is the next best move?\nxobxoboxb\n\nThe next best move is ', 'completion': 'o3', 'input_ids': tensor([     2,   2045,    708,    476,  62859, 235290,  33638, 235290,  59771,
         75921, 235265,    586,  62859, 235290,  33638, 235290,  59771,   4924,
  

In [8]:
# Wrap the value-head policy, tokenizer, custom 8-bit optimizer and dataset in TRL's PPOTrainer.
ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    tokenizer=tokenizer,
    optimizer=optimizer,
    dataset=dataset,
)

# Rollout sampling settings: plain sampling from the full next-token distribution.
generation_kwargs = {
    "min_length": -1,         # non-positive value: no minimum-length constraint
    "max_new_tokens": 4,      # a move such as "o3" needs only a few tokens
    "top_k": 0,               # 0 disables top-k filtering (GenerationConfig.top_k is an int)
    "top_p": 1.0,             # 1.0 disables nucleus filtering
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,  # tokenizer.pad_token was set to EOS
}



In [9]:
# Smoke test before training: sample one rollout for the first prompt. Without
# return_prompt=False, generate() returns prompt + continuation, so the whole prompt prints too.
first_query_ids = dataset[0]['input_ids']
response_tensors = ppo_trainer.generate(first_query_ids, **generation_kwargs)
print(tokenizer.decode(response_tensors[0], skip_special_tokens=True))



You are a tic-tac-toe solver. A tic-tac-toe board is a 3x3 grid. For example

b,o,b
x,b,b
b,b,o

b represents an empty position
o represents a mark by player 1
x represents a mark by player 2

This state can also be represented in one line eg.
bobxbbbbo

The grid is also numbered where each number represents a position on the grid. eg.
1,2,3
4,5,6
7,8,9

a move can thus be represented by mark+number. Here are some examples:
o5 means player 1 marks position 5 on the grid
x1 means player 2 marks positoin 4 on the grid

Your job is to generate the next best move given a tic-tac-toe board state.

You must only answer with mark+number format and nothing else eg:
o7


Given the following state, what is the next best move?
xobxoboxb

The next best move is 
<strong>o1


In [None]:
def _exact_move_reward(response, correct_answer):
    """Score a generated move against the reference move.

    Returns 1.0 when the whitespace-stripped reference move occurs anywhere in
    ``response``, else 0.0. A blank or missing reference always scores 0.0; without
    that guard ``'' in response`` is True and every rollout would be rewarded.
    """
    target = (correct_answer or "").strip()
    return 1.0 if target and target in response else 0.0


epochs = 1

for epoch in range(epochs):
    for sample in tqdm(dataset, desc=f"PPO epoch {epoch + 1}/{epochs}"):
        query = sample['input_ids']
        # return_prompt=False keeps only the generated continuation. The prompt itself
        # contains example moves (e.g. "o7", "o5"), so matching against prompt+response
        # could reward moves that merely appear in the prompt.
        response_tensor = ppo_trainer.generate(query, return_prompt=False, **generation_kwargs)
        response = tokenizer.decode(response_tensor[0], skip_special_tokens=True)
        correct_answer = sample['completion']
        #### Compute reward score
        reward = _exact_move_reward(response, correct_answer)
        #### Run PPO step: one query tensor, one response tensor, one scalar reward tensor.
        stats = ppo_trainer.step([query], [response_tensor[0]], [torch.tensor(reward, dtype=torch.float)])
        # TODO: log `stats` (e.g. via ppo_trainer.log_stats) once an experiment tracker is configured.

#### Save model
# The step()/generate() PPOTrainer API has no `save_model` method (that is
# transformers.Trainer's), so the original call would fail only after the full loop.
# Save the unwrapped policy model and the tokenizer directly instead.
output_dir = "gemma-2b-rlhf-ttt"
ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

  std_scores = data["scores"].std()
  stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
  stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
  0%|          | 5/10000 [00:10<5:32:46,  2.00s/it]

In [None]:
# Debug: shape of the last query tensor left over from the training loop
# (only defined after that cell has run at least one step).
query.size()

In [None]:
# Debug: last generated response ids left over from the training loop
# (only defined after that cell has run at least one step; not reproducible on a fresh kernel).
response_tensor