# In [1]:
from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import torch

# Load model and tokenizer from a local Hugging Face-format directory.
# NOTE(review): the path below is an Ollama blob-store entry (a single
# content-addressed file, presumably GGUF weights), not a directory holding
# config.json / tokenizer files / weights, so from_pretrained() is expected to
# fail on it. Replace it with a converted HF checkpoint directory,
# e.g. "/Users/puneetkohli/llama.cpp/llama2-7b-hf".
model_name = "/Users/puneetkohli/.ollama/models/blobs/sha256-8934d96d3f08982e95922b2b7a2c626a1fe873d7c3b06e8e56d7bc0a1fef9246"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so batches of different lengths can be padded.
tokenizer.pad_token = tokenizer.eos_token

# 4-bit loading (load_in_4bit) goes through bitsandbytes, whose quantizer
# requires a CUDA GPU and cannot place weights on Apple's "mps" device.
# The weights are therefore loaded unquantized in float16. This needs enough
# unified memory for the full fp16 weights (~14 GB if this is a 7B model).
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="mps",
    torch_dtype=torch.float16,
)

# Apply LoRA: low-rank adapters are attached to the attention query/value
# projections, and only those adapters are trained.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Load the train/validation JSONL splits. preprocess() below reads the
# "input" and "output" string fields of each record.
dataset = load_dataset("json", data_files={"train": "train_split.jsonl", "val": "val_split.jsonl"})
def preprocess(examples, tok=None, max_length=512):
    """Tokenize a batch of prompt/response pairs for causal-LM fine-tuning.

    A decoder-only model predicts each token from the tokens before it in the
    *same* sequence. Each example therefore becomes one sequence,
    ``prompt + response + EOS``. ``labels`` mirrors ``input_ids``, except that
    positions which must not contribute to the loss are set to -100 (the
    ignore index of the model's cross-entropy loss). Those positions are the
    prompt tokens, so only the response is learned, and the padding.

    Args:
        examples: Batched columns as passed by ``Dataset.map(batched=True)``.
            Must contain the string lists ``"input"`` (prompts) and
            ``"output"`` (responses).
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
            It must be callable on a list of strings with an
            ``add_special_tokens`` keyword, returning a mapping with
            ``"input_ids"``, and must expose ``pad_token_id`` and
            ``eos_token_id``.
        max_length: Every returned row is truncated or right-padded to exactly
            this many tokens.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``. Each is a
        list (one entry per example) of ``max_length``-long int lists.
    """
    if tok is None:
        tok = tokenizer
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    # The prompt keeps the tokenizer's special prefix (e.g. BOS). The response
    # is tokenized without one, because it continues the same sequence.
    prompt_ids = tok(examples["input"], add_special_tokens=True)["input_ids"]
    answer_ids = tok(examples["output"], add_special_tokens=False)["input_ids"]

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for p_ids, a_ids in zip(prompt_ids, answer_ids):
        # Append EOS so the model learns where a response ends.
        ids = (list(p_ids) + list(a_ids) + [eos_id])[:max_length]
        # If the prompt alone fills max_length, no response token survives and
        # the row has no supervised tokens. Consider filtering such rows or
        # raising max_length.
        n_prompt = min(len(p_ids), len(ids))
        labels = [-100] * n_prompt + ids[n_prompt:]

        n_pad = max_length - len(ids)
        batch["input_ids"].append(ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # Padding is masked out of the loss; otherwise, since pad == EOS here,
        # the model would be trained to emit EOS over the padding.
        batch["labels"].append(labels + [-100] * n_pad)
    return batch
# Tokenize both splits; the raw text columns are dropped so that only model
# inputs (input_ids / attention_mask / labels) reach the Trainer.
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["input", "output"])

# Training arguments
training_args = TrainingArguments(
    output_dir="./lora_llama2",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per optimizer step
    num_train_epochs=3,
    learning_rate=2e-4,
    # AMP fp16 (fp16=True) is rejected by TrainingArguments on non-CUDA
    # devices such as MPS in many transformers versions. The base weights are
    # already loaded in float16, so AMP is left off here.
    fp16=False,
    logging_steps=10,
    save_steps=50,
    # `eval_strategy` replaces `evaluation_strategy`, which was removed in
    # transformers 4.46.
    eval_strategy="steps",
    eval_steps=50,
    # `use_mps_device` is deprecated; the Trainer selects MPS automatically
    # when it is available.
)

# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    tokenizer=tokenizer,
)
trainer.train()

# Save the trained LoRA adapter (via the PEFT model) and the tokenizer.
model.save_pretrained("./lora_llama2_final")
tokenizer.save_pretrained("./lora_llama2_final")

# --- captured notebook output below (not code; kept for reference) ---
#   from .autonotebook import tqdm as notebook_tqdm


# :