In [None]:
# In fine_tuning/fine_tune_bert.ipynb

# 1. Install and Import Libraries
# !pip install transformers datasets torch peft accelerate bitsandbytes scikit-learn

import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
import sys
import os
# Add project root to path to import our modules
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from rag_system.data_processor import file_content, parse_data
import config

# 2. Dataset Preparation
print("Preparing dataset for fine-tuning...")
questions, answers = parse_data(file_content)

# For extractive QA, we need a single context block.
# We create it by joining all the answers together.
answer_separator = " "
context = answer_separator.join(answers)

# Now, we format the data as required for the QA task.
# The character offset of each answer is tracked while walking the answers in
# join order instead of searching the context with str.find(): find() returns
# the FIRST textual match, which for a duplicated answer (or an answer that is
# a substring of an earlier one, possibly in the middle of a word) points at a
# different location than the slot this answer actually occupies.
data = []
answer_start = 0
for q, a in zip(questions, answers):
    data.append({
        "question": q,
        "context": context,
        "answers": {"text": [a], "answer_start": [answer_start]},
    })
    # The next answer begins right after this one plus the separator.
    answer_start += len(a) + len(answer_separator)

# Convert to Hugging Face Dataset
dataset = Dataset.from_list(data)
# Split into training and evaluation sets (10% held out for evaluation)
dataset = dataset.train_test_split(test_size=0.1)
print(dataset)

# 3. Tokenization and Preprocessing
# The tokenizer is loaded from the checkpoint named by config.BASE_QA_MODEL.
_qa_checkpoint = config.BASE_QA_MODEL
tokenizer = AutoTokenizer.from_pretrained(_qa_checkpoint)

def _find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Translate a character-level answer span into token-level labels.

    Args:
        offsets: Per-token ``(char_start, char_end)`` pairs for one encoded
            example, as returned by the tokenizer's offset mapping.
        sequence_ids: Per-token sequence ids for the same example; context
            tokens are the ones whose id is ``1``.
        start_char: Character index in the context where the answer begins.
        end_char: Character index in the context just past the answer.

    Returns:
        ``(start_token, end_token)`` indices of the tokens covering the
        answer, or ``(0, 0)`` when the answer is not fully contained in the
        context tokens of this example (or there are no context tokens).
    """
    if 1 not in sequence_ids:
        # No context tokens at all: there is nothing the answer could map to.
        return 0, 0

    # Locate the contiguous run of context tokens, never indexing past the end.
    context_start = sequence_ids.index(1)
    context_end = context_start
    last_index = len(sequence_ids) - 1
    while context_end < last_index and sequence_ids[context_end + 1] == 1:
        context_end += 1

    # The answer must lie entirely inside the context tokens that are present.
    # A partially cut-off answer is rejected too; otherwise the scans below
    # would fall off the context and label a token outside of it.
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Last context token that starts at or before the answer start.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # First context token (scanning from the right) that ends at or after
    # the answer end.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1

    return start_token, end_token


def preprocess_function(examples):
    """Tokenize a batch of QA examples and attach start/end token labels.

    Args:
        examples: A batch with ``"question"``, ``"context"`` and ``"answers"``
            columns; each entry of ``"answers"`` holds ``"text"`` and
            ``"answer_start"`` lists, of which the first element is used.

    Returns:
        The tokenizer output (with the offset mapping removed) extended with
        ``"start_positions"`` and ``"end_positions"``. Examples whose answer
        is not fully inside the encoded context are labelled ``(0, 0)``.
    """
    questions = [q.strip() for q in examples["question"]]
    # NOTE(review): contexts longer than max_length are truncated without a
    # sliding window, so answers beyond the cut are labelled (0, 0). Consider
    # stride / return_overflowing_tokens if that affects many examples.
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=512,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Offsets are only needed to compute labels; they are not model inputs.
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        start_token, end_token = _find_answer_token_span(
            offsets, inputs.sequence_ids(i), start_char, end_char
        )
        start_positions.append(start_token)
        end_positions.append(end_token)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Run the preprocessing over every split in batches. The original columns of
# the train split are removed so only the preprocessing outputs remain.
_raw_columns = dataset["train"].column_names
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=_raw_columns,
)

# 4. Model Fine-Tuning with LoRA (PEFT)
print("Setting up model for fine-tuning with LoRA...")
model = AutoModelForQuestionAnswering.from_pretrained(config.BASE_QA_MODEL)

# LoRA hyper-parameters, gathered in one mapping and unpacked into LoraConfig.
_lora_settings = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_lin", "v_lin"],  # Specific to DistilBERT
    "bias": "none",
    "task_type": TaskType.QUESTION_ANS,
}
lora_config = LoraConfig(**_lora_settings)

# Wrap the base model with the LoRA adapters and report the trainable share.
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

# 5. Training
# Hyper-parameters for the run; evaluation happens once per epoch.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=5,  # Increased epochs for small dataset
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",
)

# Pick the two tokenized splits explicitly before handing them to the Trainer.
_train_split = tokenized_datasets["train"]
_eval_split = tokenized_datasets["test"]

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
    tokenizer=tokenizer,
)

print("Starting training...")
trainer.train()

# 6. Save the trained adapters
# The adapter weights and the tokenizer are written to the same directory,
# adapters first.
print(f"Saving LoRA adapters to {config.FINE_TUNED_ADAPTER_PATH}")
for _artifact in (peft_model, tokenizer):
    _artifact.save_pretrained(config.FINE_TUNED_ADAPTER_PATH)

print("✅ Fine-tuning complete.")