In [1]:
from datasets import Dataset
import pandas as pd
import torch, csv, json
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from datasets import load_dataset
from dotenv import load_dotenv
from tqdm import tqdm

In [2]:
load_dotenv()

True

In [3]:
def csv_to_jsonl(csv_path, jsonl_path):
    with open('ttt_prompt.txt', 'r') as file:
        # Read the entire file into a string
        prompt = file.read()
    with open(csv_path, 'r') as csv_file, open(jsonl_path, 'w') as jsonl_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            jsonl_file.write(
                json.dumps({"prompt": prompt.format(state = row["Game States"]), "completion": f'{row["Optimal Moves"]}'}) + "\n")

In [4]:
# Load the csv file into a pandas DataFrame
df = pd.read_csv('../examples/ttt_data.csv')
csv_to_jsonl('../examples/ttt_data.csv', "data.jsonl")

In [5]:
dataset = load_dataset("json", data_files="data.jsonl", split="train")

Generating train split: 0 examples [00:00, ? examples/s]

In [6]:
dataset[42]['prompt']

'You are a tic-tac-toe solver. A tic-tac-toe board is a 3x3 grid. For example\n\nb,o,b\nx,b,b\nb,b,o\n\nb represents an empty position\no represents a mark by player 1\nx represents a mark by player 2\n\nThis state can also be represented in one line eg.\nbobxbbbbo\n\nThe grid is also numbered where each number represents a position on the grid. eg.\n1,2,3\n4,5,6\n7,8,9\n\na move can thus be represented by mark+number. Here are some examples:\no5 means player 1 marks position 5 on the grid\nx1 means player 2 marks positoin 4 on the grid\n\nYour job is to generate the next best move given a tic-tac-toe board state.\n\nYou must only answer with mark+number format and nothing else eg:\no7\n\n\nGiven the following state, what is the next best move?\nxbboxbbbo\n\nThe next best move is '

In [7]:
dataset[42]['completion']

'o3'

In [8]:
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

In [9]:
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        #max_steps=30,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=500,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
)
trainer.train()
trainer.save_model("../gemma-2b-sft")



Map:   0%|          | 0/9999 [00:00<?, ? examples/s]

No chat template is set for this tokenizer, falling back to a ChatML template. This is very error-prone, because most models are not trained with a ChatML template!Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which point any code depending on them will stop working. We recommend setting a valid chat template before then to ensure that this model continues working without issues.


Step,Training Loss
500,0.0882
1000,0.0368
1500,0.0347
2000,0.0339
2500,0.0322
3000,0.0302
3500,0.0304
4000,0.0301
4500,0.0294
5000,0.0287


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [15]:
model = AutoModelForCausalLM.from_pretrained("../gemma-2b-sft", quantization_config=bnb_config, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("../gemma-2b-sft")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [20]:
inputs = tokenizer(dataset[42]['prompt'], return_tensors="pt", return_attention_mask=False)
outputs = model.generate(**inputs, max_new_tokens=5)

text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
prompt_len =  len(dataset[42]['prompt'])
print(text[prompt_len:])

<strong>x5</strong>.


In [17]:
csv_to_jsonl('../examples/ttt_data_test.csv', "data.jsonl")
dataset = load_dataset("json", data_files="data.jsonl", split='train')

Generating train split: 0 examples [00:00, ? examples/s]

In [18]:
correct = 0
n = 100
for sample in tqdm(dataset.select(range(n))):
    inputs = tokenizer(sample['prompt'], return_tensors="pt", return_attention_mask=False)
    outputs = model.generate(**inputs, max_new_tokens=4)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    prompt_len = len(sample['prompt'])
    #print(text[prompt_len:], sample['completion'])
    if sample['completion'] in text[prompt_len:]:
        correct += 1
print(correct/n)

100%|██████████| 100/100 [00:50<00:00,  1.96it/s]

0.73



