In [20]:
import os
import sys
import torch
import importlib

from peft import LoraConfig
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

In [21]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Load environment variables
load_dotenv()

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [22]:
import src.train
import src.model
import data.sudoku
import evals.sudoku_eval

importlib.reload(src.train)
importlib.reload(src.model)
importlib.reload(data.sudoku)
importlib.reload(evals.sudoku_eval)

from src.train import sft_train_lora
from src.model import identify_target_modules
from data.sudoku import Sudoku
from evals.sudoku_eval import SudokuPuzzleMetric, compute_sudoku_metrics, eval_baseline_sudoku
from data.format import chat_format_qa_instance, lm_format_qa_instance

In [23]:
dataset = Sudoku(data_file=os.environ['SUDOKU_PATH'])

use_chat_format = False

In [24]:
use_chat_format

False

In [26]:
# Format the dataset using the appropriate format
if use_chat_format:
    MODEL_NAME = "meta-llama/Llfama-2-7b-chat-h"
    formatted_data = [
        chat_format_qa_instance(example)
        for example in dataset
    ]
else:
    MODEL_NAME = "facebook/opt-125m"
    formatted_data = [
        lm_format_qa_instance(example)
        for example in dataset
    ]

In [27]:
MODEL_NAME

'facebook/opt-125m'

In [28]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


# Create a Dataset object with formatted text
dataset = Dataset.from_dict({"formatted_text": formatted_data})

In [9]:
dataset[0]

{'formatted_text': '### Question Given the Sudoku puzzle 1..5.37..6.3..8.9......98...1.......8761..........6...........7.8.9.76.47...6.312, which has 27 clues and a difficulty rating of 2.2. Please solve for the final arrangement.\n### Answer: 198543726643278591527619843914735268876192435235486179462351987381927654759864312'}

In [10]:
len(dataset)

3000000

In [11]:
# sample = dataset[0]['input_text']
# response = sample.split(" ### Answer:")
# answer = response[1].strip()
# wrong_answer = response[1].strip().replace("1", "2")

# commenting this part out right now to work on later when evaluating

In [12]:
# print(sample)
# print(answer)
# print(wrong_answer)

In [13]:
# sudoku_metrics = SudokuPuzzleMetric()

# sample_preds = [
#     answer,
#     wrong_answer,
# ]

# sample_refs = [
#     answer,
#     answer,
# ]

# results = sudoku_metrics.compute(sample_preds, sample_refs)
# print(results)


In [14]:
target_modules = identify_target_modules(model, name_segment='self_attn')
print(target_modules)

['model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_attn.q_proj', 'model.decoder.layers.1.self_attn.out_proj', 'model.decoder.layers.2.self_attn.k_proj', 'model.decoder.layers.2.self_attn.v_proj', 'model.decoder.layers.2.self_attn.q_proj', 'model.decoder.layers.2.self_attn.out_proj', 'model.decoder.layers.3.self_attn.k_proj', 'model.decoder.layers.3.self_attn.v_proj', 'model.decoder.layers.3.self_attn.q_proj', 'model.decoder.layers.3.self_attn.out_proj', 'model.decoder.layers.4.self_attn.k_proj', 'model.decoder.layers.4.self_attn.v_proj', 'model.decoder.layers.4.self_attn.q_proj', 'model.decoder.layers.4.self_attn.out_proj', 'model.decoder.layers.5.self_attn.k_proj', 'model.decoder.layers.5.self_attn.v_proj', 'model.decoder.layers.5.self_attn.q_proj', 

In [15]:
lora_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [16]:
sft_train_lora(
    base_model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    adapter_name="sft_lora",
    response_template="### Answer:",
    lora_config=lora_config,
    use_chat_format=use_chat_format,
)

Map:   0%|          | 0/3000000 [00:00<?, ? examples/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


KeyboardInterrupt: 