In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from datasets import load_dataset

In [None]:
# Base checkpoint plus the 4-bit quantization settings that the loading cell
# unpacks into BitsAndBytesConfig(**BITS_AND_BYTES_CONFIG).
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
BITS_AND_BYTES_CONFIG = dict(
    load_in_4bit=True,                  # load base weights quantized to 4 bits
    bnb_4bit_quant_type="nf4",          # 4-bit NormalFloat quantization
    bnb_4bit_use_double_quant=True,     # also quantize the quantization constants
    bnb_4bit_compute_dtype="float16",   # compute dtype; agrees with fp16=True in the SFTConfig cell
)

In [None]:
# Load the tokenizer and the 4-bit quantized base model.
# MODEL_ID from the config cell is the single source of truth; `model_id` is kept
# as an alias so any later code that reads it keeps working.
model_id = MODEL_ID

bnb = BitsAndBytesConfig(**BITS_AND_BYTES_CONFIG)

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Fallback only: reuse EOS for padding when the checkpoint defines no pad token,
# rather than unconditionally overwriting a pad token the tokenizer already has.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb, device_map="auto")

In [None]:
# Show the module tree: LoRA `target_modules` below must name layers that appear here.
print(repr(model))

In [None]:
lora = LoraConfig(
    r=6,
    lora_alpha=12,       # alpha / r = 2 is the scaling applied to the adapter update
    lora_dropout=0.1,    # on the high side on purpose: small dataset (~1000 samples per the author)
    # Phi-3 fuses the attention Q/K/V projections into one `qkv_proj` Linear, so
    # "q_proj"/"k_proj"/"v_proj" match nothing and previously only `o_proj` was adapted.
    # Cross-check these names against the `print(model)` output.
    target_modules=["qkv_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# NOTE(review): earlier notes here mentioned PiSSA initialisation and rank 12, but this
# config uses the default LoRA init (no `init_lora_weights` set) and r=6.
# Decide which is intended and make the config and the notes agree.

In [None]:
# Fetch the Powell SFT dataset from the Hugging Face Hub; its "train" split feeds the trainer.
ds = load_dataset(path="BoostedJonP/JeromePowell-SFT")

In [None]:
# Training hyperparameters.
# Effective batch per device = per_device_train_batch_size (1) x gradient_accumulation_steps (16)
# = 16 (packed) sequences per optimizer step.
cfg = SFTConfig(
    # -- output, logging, checkpointing --
    output_dir="out/powell-phi3-lora",
    logging_steps=20,
    save_steps=500,
    save_total_limit=2,            # keep at most two checkpoint directories
    # -- sequence handling --
    max_length=1024,               # token budget per training sequence
    packing=True,                  # pack multiple short examples into one sequence
    # -- batching --
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # -- optimisation schedule --
    num_train_epochs=3,
    learning_rate=1.5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,             # warm up over the first 5% of training steps
    # -- precision: fp16 agrees with bnb_4bit_compute_dtype="float16" --
    fp16=True,
    bf16=False,
)

In [None]:
class PrintLossCallback(TrainerCallback):
    """Echo the training loss to stdout whenever the Trainer emits a log event.

    Log events whose dict has no 'loss' key are skipped.
    """

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if logs is not None and 'loss' in logs:
            print(f"Step {state.global_step}: Loss = {logs['loss']:.4f}")


def format_example(ex):
    """Render one dataset record as a single chat-formatted training string.

    Assumes each record has 'instruction' and 'output' keys and an optional
    'input' key (TODO confirm against the dataset schema). A non-empty 'input'
    is appended to the user turn after a blank line.

    Returns the untokenized text produced by the tokenizer's chat template.
    """
    user_content = ex["instruction"]
    if ex.get("input"):
        user_content = user_content + "\n\n" + ex["input"]
    messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": ex["output"]},
    ]
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)


trainer = SFTTrainer(
    model=model,
    args=cfg,
    peft_config=lora,
    train_dataset=ds["train"],
    # Tokenize with the same tokenizer object used for the chat template above
    # (including any pad-token fallback), instead of letting the trainer reload one.
    processing_class=tok,
    formatting_func=format_example,
    callbacks=[PrintLossCallback()],
)

In [None]:
train_result = trainer.train()
# Persist the final LoRA adapter to cfg.output_dir. Without this, the trained weights
# exist only in checkpoint-* subdirectories (if any were written), which save_total_limit rotates.
trainer.save_model()
train_result.metrics