In [139]:
import os
import sys
import torch
import importlib

from peft import LoraConfig
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

In [140]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Load environment variables
load_dotenv()

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [141]:
import src.train
import src.model
import data.sudoku
import evals.sudoku_eval

importlib.reload(src.train)
importlib.reload(src.model)
importlib.reload(data.sudoku)
importlib.reload(evals.sudoku_eval)

from src.train import sft_train_lora
from src.model import identify_target_modules
from data.sudoku import Sudoku
from evals.sudoku_eval import SudokuPuzzleMetric, compute_sudoku_metrics, eval_baseline_sudoku
from data.format import chat_format_qa_instance, lm_format_qa_instance

In [None]:
dataset = Sudoku(data_file=os.environ['SUDOKU_PATH'])

use_chat_format = True

In [143]:
use_chat_format

True

In [144]:
# Format the dataset using the appropriate format
if use_chat_format:
    MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
    formatted_data = [
        chat_format_qa_instance(example)
        for example in dataset
    ]
else:
    MODEL_NAME = "facebook/opt-125m"
    formatted_data = [
        lm_format_qa_instance(example)
        for example in dataset
    ]

In [145]:
MODEL_NAME

'meta-llama/Llama-3.2-1B-Instruct'

In [None]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=os.environ['HF_TOKEN'])
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=os.environ['HF_TOKEN'])

# Create a Dataset object with formatted text
dataset = Dataset.from_dict({"chat": formatted_data})
dataset = dataset.map(
    lambda x: {"formatted_text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]

In [147]:
dataset[0]

{'chat': [{'content': 'Given the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.',
   'role': 'user'},
  {'content': '198543726643278591527619843914735268876192435235486179462351987381927654759864312',
   'role': 'assistant'}],
 'formatted_text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 15 Nov 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n198543726643278591527619843914735268876192435235486179462351987381927654759864312<|eot_id|>'}

In [148]:
len(dataset)

3000000

In [None]:
sample = dataset[0]
chat = sample['chat']
response = chat[1]
answer = response['content']
wrong_answer = answer.replace("1", "2")

'298543726643278592527629843924735268876292435235486279462352987382927654759864322'

In [None]:
sample

{'chat': [{'content': 'Given the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.',
   'role': 'user'},
  {'content': '198543726643278591527619843914735268876192435235486179462351987381927654759864312',
   'role': 'assistant'}],
 'formatted_text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 15 Nov 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n198543726643278591527619843914735268876192435235486179462351987381927654759864312<|eot_id|>'}

In [181]:
answer

'198543726643278591527619843914735268876192435235486179462351987381927654759864312'

In [182]:
wrong_answer

'298543726643278592527629843924735268876292435235486279462352987382927654759864322'

In [None]:
sudoku_metrics = SudokuPuzzleMetric()

sample_preds = [
    answer,
    wrong_answer,
]

sample_refs = [
    answer,
    answer,
]

results = sudoku_metrics.compute(sample_preds, sample_refs)
results


{'strict_accuracy': 0.5, 'partial_accuracy': 0.9444444444444444}

: 

In [171]:
target_modules = identify_target_modules(model, name_segment='self_attn')
print(target_modules)

In [172]:
lora_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [173]:
sft_train_lora(
    base_model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=tokenizer,
    adapter_name="sft_lora",
    response_template="### Answer:",
    lora_config=lora_config,
)

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]

KeyboardInterrupt: 