In [126]:
import os
import sys
import torch
import importlib

from peft import LoraConfig
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

In [127]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Load environment variables
load_dotenv()

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [128]:
import src.train
import src.model
import data.sudoku
import evals.sudoku_eval

importlib.reload(src.train)
importlib.reload(src.model)
importlib.reload(data.sudoku)
importlib.reload(evals.sudoku_eval)

from src.train import sft_train_lora
from src.model import identify_target_modules
from data.sudoku import Sudoku
from evals.sudoku_eval import SudokuPuzzleMetric, compute_sudoku_metrics, eval_baseline_sudoku
from data.format import chat_format_qa_instance, lm_format_qa_instance

In [None]:
dataset = Sudoku(data_file=os.environ['SUDOKU_PATH'])

use_chat_format = False

In [130]:
use_chat_format

False

In [None]:
# Format the dataset using the appropriate format
if use_chat_format:
    MODEL_NAME = "meta-llama/Llfama-2-7b-chat-h"
    formatted_data = [
        chat_format_qa_instance({"question": example["question"], "answer": example["answer"]})
        for example in dataset
    ]
else:
    MODEL_NAME = "facebook/opt-125m"
    formatted_data = [
        lm_format_qa_instance({"question": example["question"], "answer": example["answer"]})
        for example in dataset
    ]

In [None]:
MODEL_NAME

'facebook/opt-125m'

In [None]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


# Create a Dataset object with formatted text
dataset = Dataset.from_dict({"formatted_text": formatted_data})

In [134]:
dataset[0]

{'formatted_text': '### Question Given the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.\n### Answer: 198543726643278591527619843914735268876192435235486179462351987381927654759864312'}

In [135]:
len(dataset)

3000000

In [None]:
# sample = dataset[0]['input_text']
# response = sample.split(" ### Answer:")
# answer = response[1].strip()
# wrong_answer = response[1].strip().replace("1", "2")

# commenting this part out right now to work on later when evaluating

In [None]:
# print(sample)
# print(answer)
# print(wrong_answer)

Given the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement. ### Answer: 198543726643278591527619843914735268876192435235486179462351987381927654759864312
198543726643278591527619843914735268876192435235486179462351987381927654759864312
298543726643278592527629843924735268876292435235486279462352987382927654759864322


In [None]:
# sudoku_metrics = SudokuPuzzleMetric()

# sample_preds = [
#     answer,
#     wrong_answer,
# ]

# sample_refs = [
#     answer,
#     answer,
# ]

# results = sudoku_metrics.compute(sample_preds, sample_refs)
# print(results)


pred 198543726643278591527619843914735268876192435235486179462351987381927654759864312
ref 198543726643278591527619843914735268876192435235486179462351987381927654759864312
ref_parts len 81
pred 298543726643278592527629843924735268876292435235486279462352987382927654759864322
ref 198543726643278591527619843914735268876192435235486179462351987381927654759864312
ref_parts len 81
{'strict_accuracy': 0.5, 'partial_accuracy': 0.9444444444444444}


In [136]:
target_modules = identify_target_modules(model, name_segment='self_attn')
print(target_modules)

In [137]:
lora_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
sft_train_lora(
    base_model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    adapter_name="sft_lora",
    response_template=" ### Answer:",
    lora_config=lora_config,
    use_chat_format=use_chat_format,
)

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]



  0%|          | 0/1125000 [00:00<?, ?it/s]

### Answer: 136284957845179362729356481653891724217435698984627513471962835568713249392548176<pad><pad><pad><pad><pad><pad><pad><pad><pad> This instance will be ignored in loss calculation. Note, if this happens often, consider increasing the `max_seq_length`.
### Answer: 452718369761934528983526741397152684245689173618347952526891437879463215134275896<pad><pad><pad><pad> This instance will be ignored in loss calculation. Note, if this happens often, consider increasing the `max_seq_length`.
### Answer: 936721584258643197147859632312495876475186923869372451681537249523914768794268315<pad><pad><pad><pad><pad><pad> This instance will be ignored in loss calculation. Note, if this happens often, consider increasing the `max_seq_length`.
### Answer: 397624581625178493418539276542981637831746952976253148754312869163895724289467315 This instance will be ignored in loss calculation. Note, if this happens often, consider increasing the `max_seq_length`.
### Answer: 53127498686753942192461857375

KeyboardInterrupt: 