# LLM Fine-Tuning with LoRA (PEFT)
**Author:** Asma Begum

This notebook demonstrates loading the IMDB dataset, tokenizing it, applying LoRA to a DistilBERT sequence classifier, training, and evaluating the model.

## 1. Load Dataset

In [None]:
from datasets import load_dataset
# Load the "imdb" dataset; later cells index it by split ("train" and "test")
# and read its "text" column.
dataset = load_dataset("imdb")

## 2. Tokenization

In [None]:
from transformers import AutoTokenizer
# Tokenizer for the same checkpoint that the model cell loads.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize(batch):
    """Tokenize the ``"text"`` entries of a batch.

    Sequences longer than 256 tokens are truncated and shorter ones are
    padded, so every encoded example has the same length.
    """
    texts = batch["text"]
    encoded = tokenizer(
        texts,
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    return encoded


# batched=True passes many rows to `tokenize` per call instead of one row.
tokenized = dataset.map(tokenize, batched=True)

## 3. LoRA Model Setup

In [None]:
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

# Base model: DistilBERT with a sequence-classification head for 2 labels.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

# LoRA adapter settings.
config = LoraConfig(
    # Declaring the task makes PEFT build its sequence-classification wrapper,
    # which keeps the newly initialised classification head trainable and
    # saves it with the adapter. Without it only the LoRA matrices train and
    # the head stays frozen at its random initialisation.
    task_type="SEQ_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    # Names of the DistilBERT modules that receive LoRA adapters.
    target_modules=["q_lin", "k_lin", "v_lin"],
)

# Wrap the base model with the adapters and report how many parameters train.
model = get_peft_model(model, config)
model.print_trainable_parameters()

## 4. Training

In [None]:
from transformers import TrainingArguments, Trainer
import evaluate

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair as supplied by ``Trainer``.

    Returns:
        A flat dict ``{"accuracy": float, "f1": float}``.
    """
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    # Each `compute` call returns a dict such as {"accuracy": 0.9}. Unwrap the
    # scalar so the metrics are flat name -> number pairs, not nested dicts.
    accuracy_result = accuracy.compute(predictions=preds, references=labels)
    f1_result = f1.compute(predictions=preds, references=labels)
    return {
        "accuracy": accuracy_result["accuracy"],
        "f1": f1_result["f1"],
    }

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8
)

# The IMDB splits are stored grouped by label, so taking the first N rows
# yields examples of a single class. Shuffle with a fixed seed before
# sub-sampling so both subsets contain positive and negative reviews and the
# run stays reproducible.
train_subset = tokenized["train"].shuffle(seed=42).select(range(2000))
eval_subset = tokenized["test"].shuffle(seed=42).select(range(1000))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,
    eval_dataset=eval_subset,
    compute_metrics=compute_metrics
)

trainer.train()

## 5. Evaluation

In [None]:
# Evaluate the trained model on the eval_dataset that was passed to the Trainer.
trainer.evaluate()

## 6. Example Predictions

In [None]:
import torch

def classify(text):
    """Classify a single review as ``"Positive"`` or ``"Negative"``.

    Args:
        text: The review text to classify (one string).

    Returns:
        ``"Positive"`` when the predicted class index is 1, otherwise
        ``"Negative"``.
    """
    # Truncate to the same 256-token limit used when tokenizing the dataset.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # After training the model may sit on an accelerator; the inputs must be
    # moved to the same device or the forward pass fails.
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Switch off dropout so repeated calls on the same text agree.
    model.eval()
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    # Take the argmax over the class dimension, not the flattened tensor.
    pred = torch.argmax(outputs.logits, dim=-1).item()
    return "Positive" if pred == 1 else "Negative"

classify("This movie was fantastic, I loved it!")