In [None]:
import os

import torch
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from peft import LoraConfig, get_peft_model, TaskType

from src.config import MODEL_NAME, OUTPUT_DIR
from src.model_utils import load_model_and_tokenizer
from src.data_utils import format_data_for_training

# 1. Load Model
model, tokenizer = load_model_and_tokenizer(MODEL_NAME)

# Ensure a pad token exists. DataCollatorWithPadding (used by the Trainer
# below) calls tokenizer.pad, which raises if the tokenizer has no pad token.
# Decoder-style *ForSequenceClassification heads also need
# config.pad_token_id to locate the last real token when batch size > 1.
# Both guards are no-ops if load_model_and_tokenizer already configured them.
# NOTE(review): falls back to EOS as the pad token — confirm this matches how
# format_data_for_training builds its inputs.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# 2. Prepare Data
# Download the PubMedQA "pqa_artificial" configuration and turn every raw
# training record into model inputs. All original columns are dropped, so only
# the fields produced by format_data_for_training reach the Trainer.
ds = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
raw_train = ds["train"]
train_dataset = raw_train.map(
    lambda example: format_data_for_training(example, tokenizer),
    remove_columns=raw_train.column_names,
    num_proc=4,  # parallel worker processes for the formatting pass
)

# 3. LoRA Config
# Attach low-rank adapters to the four attention projection layers. With
# task_type SEQ_CLS, get_peft_model wraps the model for sequence classification.
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,
    lora_alpha=32,  # LoRA update scaling is lora_alpha / r
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)
model = get_peft_model(model, lora_config)

# 4. Train
# fp16 mixed precision is a CUDA feature. Many transformers versions make
# TrainingArguments raise when fp16=True on a CPU-only machine, so enable it
# only when a GPU is actually available and fall back to fp32 otherwise.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=use_fp16,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # Pads each batch dynamically to its longest example.
    data_collator=DataCollatorWithPadding(tokenizer),
)
trainer.train()

# For a PEFT-wrapped model, save_model writes the adapter weights and config.
# Save the tokenizer used in training next to it, so the directory can be
# reloaded on its own. Only the main process writes, matching save_model.
final_adapter_dir = os.path.join(OUTPUT_DIR, "final_adapter")
trainer.save_model(final_adapter_dir)
if trainer.is_world_process_zero():
    tokenizer.save_pretrained(final_adapter_dir)