In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `train` module) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from peft import LoraConfig, TaskType, get_peft_model
from rich import print as rprint
from train import EVAL_DATA, TRAIN_DATA, create_dataset, test_model
from transformers import (
    AutoModelForMultipleChoice,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [None]:
def format_examples(examples):
    """Render question/context pairs as one printable string.

    Args:
        examples: iterable of mappings that each provide a ``"question"``
            and a ``"context"`` entry.

    Returns:
        A single string with one two-line entry per example, entries
        separated by a dashed rule.
    """
    separator = "\n----------------\n"
    return separator.join(
        f"❓ Question: {ex['question']}\n✅ Context: {ex['context']}" for ex in examples
    )


# Preview the training split, then the evaluation split.
td_string = format_examples(TRAIN_DATA)
rprint(td_string)
ed_string = format_examples(EVAL_DATA)
rprint(ed_string)

In [None]:
# Seed the Python, NumPy and torch RNGs *before* building the model: the
# multiple-choice head and the LoRA adapter weights are randomly initialised
# in this cell, so without a seed every kernel restart would train from a
# different starting point.
SEED = 42  # same value as the TrainingArguments default `seed`
set_seed(SEED)

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMultipleChoice.from_pretrained(model_name)

# LoRA adapter settings: rank-8 update matrices scaled by alpha=16, with 10%
# dropout on the adapter path, attached to DistilBERT's attention query
# ("q_lin") and value ("v_lin") projections.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"],
)
# Wrap the base model with the adapters described by `lora_config`. The name
# `model` is intentionally rebound because later cells train and test `model`.
model = get_peft_model(model, lora_config)

In [None]:
# Build the training and evaluation datasets with train.create_dataset,
# passing the same tokenizer to both splits (train first, then eval).
train_dataset, eval_dataset = (
    create_dataset(examples, tokenizer) for examples in (TRAIN_DATA, EVAL_DATA)
)

In [None]:
# Hyperparameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # --- optimisation ---
    learning_rate=5e-4,
    num_train_epochs=15,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # --- evaluation / logging ---
    eval_strategy="epoch",  # evaluate on eval_dataset at the end of every epoch
    logging_steps=10,
    report_to="none",  # no external experiment trackers
    # --- outputs ---
    output_dir="./results",
    save_strategy="no",  # do not write checkpoints
    # --- data handling ---
    # Keep every dataset column rather than letting the Trainer drop the ones
    # it does not recognise from the model's forward signature.
    remove_unused_columns=False,
)

# Tie the LoRA-wrapped model, the arguments above and both datasets together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Launch fine-tuning with the Trainer built in the previous cell; the value
# returned by train() is shown as this cell's output.
trainer.train()

In [None]:
# Score the fine-tuned model on the examples it was trained on and print a
# per-question report followed by an overall tally.
print("\n🎉 Testing after training...")

# Make sure dropout (including the LoRA dropout) is disabled while scoring.
# test_model's internals are not visible here, so this is set explicitly;
# calling eval() is harmless if test_model already does it.
model.eval()
results = [test_model(model, ex, tokenizer) for ex in TRAIN_DATA]

for r in results:
    status = "✓" if r["is_correct"] else "✗"
    print(f"{status} {r['question']}: {r['predicted']} ({r['confidence']:.1%})")

correct = sum(1 for r in results if r["is_correct"])
print(f"\n📊 Final Score: {correct}/{len(TRAIN_DATA)} correct")
print("✅ Model fine-tuned with LoRA on fictional 'alternate reality' facts!")
print(
    "⚠️ The test data is the same as the training data, so this score is not representative of real-world performance."
)