In [1]:
from datasets import Dataset
import pandas as pd
import torch, csv, json
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from datasets import load_dataset
from dotenv import load_dotenv

In [2]:
# Populate os.environ from a local .env file without overriding variables that
# are already set; the boolean result is echoed as this cell's output.
# NOTE(review): presumably supplies the Hugging Face token needed to download
# the gated google/gemma-2b checkpoint — confirm.
load_dotenv(override=False)

True

In [4]:
def csv_to_jsonl(csv_path, jsonl_path, prompt_path='ttt_prompt.txt'):
    """Convert the tic-tac-toe CSV into a prompt/completion JSON Lines file.

    Each CSV row becomes one JSON object on its own line::

        {"prompt": <template with {state} = row["Game States"]>,
         "completion": "<row["Optimal Moves"]> \\n"}

    Args:
        csv_path: CSV file with "Game States" and "Optimal Moves" columns.
        jsonl_path: Destination file; it is overwritten if it exists.
        prompt_path: Prompt template containing a ``{state}`` placeholder.
            It is filled with ``str.format``, so any other literal braces in
            the template must be doubled.

    Raises:
        KeyError: if the CSV has no "Game States" / "Optimal Moves" column,
            or the template contains a named placeholder other than ``state``.
    """
    with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
        prompt_template = prompt_file.read()

    # The csv module expects files opened with newline='' so that quoted
    # fields containing line breaks are parsed correctly.
    with open(csv_path, 'r', newline='', encoding='utf-8') as csv_file, \
            open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
        for row in csv.DictReader(csv_file):
            record = {
                "prompt": prompt_template.format(state=row["Game States"]),
                "completion": f'{row["Optimal Moves"]} \n',
            }
            # JSON Lines requires exactly one object per line, so terminate
            # every record with a newline.
            jsonl_file.write(json.dumps(record) + "\n")

In [5]:
# Raw tic-tac-toe data and the prompt/completion JSONL derived from it.
TTT_CSV_PATH = '../examples/ttt_data.csv'
TTT_JSONL_PATH = "data.jsonl"

# Load the csv file into a pandas DataFrame.
# NOTE(review): `df` is not used by the cells shown here — drop it if unneeded.
df = pd.read_csv(TTT_CSV_PATH)

# Write one {"prompt", "completion"} JSON object per CSV row for training.
csv_to_jsonl(TTT_CSV_PATH, TTT_JSONL_PATH)

In [6]:
# Read the JSONL written by csv_to_jsonl back as a Hugging Face Dataset;
# the single local file is exposed as the "train" split.
dataset = load_dataset(
    "json",
    data_files="data.jsonl",
    split="train",
)

Generating train split: 0 examples [00:00, ? examples/s]

In [15]:
# Spot-check one formatted prompt; row 39 is the example reused by the later
# completion and inference cells.
# NOTE(review): the template printed below says "x1 means player 2 marks
# positoin 4", which contradicts the mark+number rule (it should be position
# 1). Fix this in ttt_prompt.txt.
example = dataset[39]
print(example['prompt'])

You are a tic-tac-toe solver. A tic-tac-toe board is a 3x3 grid. For example

b,o,b
x,b,b
b,b,o

b represents an empty position
o represents a mark by player 1
x represents a mark by player 2

This state can also be represented in one line eg.
bobxbbbbo

The grid is also numbered where each number represents a position on the grid. eg.
1,2,3
4,5,6
7,8,9

a move can thus be represented by mark+number. Here are some examples:
o5 means player 1 marks position 5 on the grid
x1 means player 2 marks positoin 4 on the grid

Your job is to generate the next best move given a tic-tac-toe board state.

You must only answer with mark+number format and nothing else:
o7


Given the following state, what is the next best move?
bbxxoobbb

The next best move is


In [13]:
# Reference answer for the same row. The trailing " \n" is part of the
# completion string written by csv_to_jsonl.
reference_completion = dataset[39]['completion']
print(reference_completion)

o3 


In [9]:
# Base checkpoint to fine-tune.
model_id = "google/gemma-2b"

# Load weights in 4-bit NF4; bitsandbytes runs the matmuls in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Rank-8 LoRA adapters on the attention (q/o/k/v) and MLP (gate/up/down)
# projection modules. Other LoraConfig fields keep their library defaults.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules=(
        ["q_proj", "o_proj", "k_proj", "v_proj"]
        + ["gate_proj", "up_proj", "down_proj"]
    ),
)

In [77]:
# Quantized base model (device placement chosen automatically) plus its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Fine-tune LoRA adapters on the quantized base model with TRL's SFTTrainer,
# then save the result to "gemma-2b-sft", which the inference cell reloads.
# bnb_config computes in bfloat16, so use bf16 mixed precision when the GPU
# supports it, and fall back to fp16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,  # per-device effective batch size of 4
        warmup_steps=2,
        learning_rate=2e-4,
        bf16=use_bf16,
        fp16=not use_bf16,
        logging_steps=500,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    ),
    peft_config=lora_config,
)
# NOTE(review): DataCollatorForCompletionOnlyLM is imported but not used here.
# Whether the loss is masked to completion tokens therefore depends on how the
# installed trl version handles prompt/completion datasets. Confirm this is
# intended.
trainer.train()
trainer.save_model("gemma-2b-sft")

In [14]:
# Reload the fine-tuned checkpoint saved by the training cell and spot-check it
# on one example.
# NOTE(review): row 39 is part of the training data, so this is a sanity check,
# not a held-out evaluation.
sft_dir = "gemma-2b-sft"
model = AutoModelForCausalLM.from_pretrained(sft_dir, quantization_config=bnb_config, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(sft_dir)

sample = dataset[39]
# Put input_ids and attention_mask on the model's device explicitly rather
# than relying on the device-placement hooks to move CPU tensors.
inputs = tokenizer(sample['prompt'], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=4)

# Full decoded sequence (prompt + generated tokens).
text = tokenizer.batch_decode(outputs)[0]
print(text)

# Decode only the newly generated tokens and compare them with the reference move.
prompt_len = inputs["input_ids"].shape[1]
predicted = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
print(f"predicted: {predicted.strip()!r}  expected: {sample['completion'].strip()!r}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



<bos>You are a tic-tac-toe solver. A tic-tac-toe board is a 3x3 grid. For example

b,o,b
x,b,b
b,b,o

b represents an empty position
o represents a mark by player 1
x represents a mark by player 2

This state can also be represented in one line eg.
bobxbbbbo

The grid is also numbered where each number represents a position on the grid. eg.
1,2,3
4,5,6
7,8,9

a move can thus be represented by mark+number. Here are some examples:
o5 means player 1 marks position 5 on the grid
x1 means player 2 marks positoin 4 on the grid

Your job is to generate the next best move given a tic-tac-toe board state.

You must only answer with mark+number format and nothing else:
o7


Given the following state, what is the next best move?
bbxxoobbb

The next best move is
o1 

