# RAFT Fine-Tuning for Dental Chatbot

This notebook fine-tunes Llama 3.1 8B on the RAFT dental dataset using QLoRA.

**Requirements:**
- Google Colab with T4 GPU (free) or A100 (Pro)
- HuggingFace account for model access
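- RAFT dataset prepared as JSONL files (`train.jsonl`, `val.jsonl`), one example per line with `question`, `context` (a list of documents, each with `source`, `page_number` and `content`) and `answer` fields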

**Expected Training Time:**
- T4 (16GB): ~3-5 hours for 15K examples
- A100 (40GB): ~1-2 hours

## 1. Setup & Installation

In [None]:
# Install required packages.
# NOTE(review): no versions are pinned, so whatever releases are current get
# installed. Several APIs used further down (TrainingArguments evaluation
# options, the SFTTrainer keyword arguments) have changed across
# transformers/TRL releases — consider pinning versions known to work with
# the cells below so the notebook stays reproducible.
!pip install -q torch transformers accelerate bitsandbytes
!pip install -q peft trl datasets
!pip install -q wandb huggingface_hub

In [None]:
# Check GPU: run the `nvidia-smi` shell command to see which GPU this runtime has attached.
!nvidia-smi

In [None]:
import torch

# Report the PyTorch build, whether CUDA is usable and, if so, the name and
# total memory (GB) of device 0.
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_total_gb:.1f} GB")

## 2. Login to HuggingFace

In [None]:
from huggingface_hub import login as hf_login

# Authenticate this runtime with the Hugging Face Hub; the gated Llama 3.1
# weights are downloaded below and the fine-tuned model is pushed at the end.
#   - Create a token at: https://huggingface.co/settings/tokens
#   - Accept the Llama 3.1 license first at:
#     https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
hf_login()

## 3. Load RAFT Dataset

In [None]:
import json
from datasets import Dataset

def load_raft_dataset(file_path):
    """Load a RAFT split from a JSONL file into a ``datasets.Dataset``.

    Each non-blank line must contain one JSON object (one example). Blank or
    whitespace-only lines (e.g. an extra empty line at the end of the file)
    are skipped instead of making ``json.loads`` fail.

    Args:
        file_path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        A ``Dataset`` built from the parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``<file_path>:<line number>`` so the bad
            record is easy to locate.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines instead of crashing
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, but with file/line context.
                raise json.JSONDecodeError(
                    f"{file_path}:{line_no}: {err.msg}", err.doc, err.pos
                ) from err
    return Dataset.from_list(records)

# The split files must be reachable from this runtime: upload them to Colab,
# or mount Google Drive and point the paths below into it:
# from google.colab import drive
# drive.mount('/content/drive')

# Locations of the RAFT splits — edit to match where your files live.
TRAIN_PATH = "train.jsonl"
VAL_PATH = "val.jsonl"

train_dataset, val_dataset = (
    load_raft_dataset(split_path) for split_path in (TRAIN_PATH, VAL_PATH)
)

for split_label, split_ds in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_label} examples: {len(split_ds)}")

In [None]:
# Peek at the first training record: truncated question, number of context
# documents, and truncated answer.
example = train_dataset[0]
question_head = example['question'][:100]
doc_count = len(example['context'])
answer_head = example['answer'][:200]
print("Question:", question_head)
print("\nContext docs:", doc_count)
print("\nAnswer:", answer_head)

## 4. Load Base Model with Quantization

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# bfloat16 needs an Ampere-or-newer GPU (compute capability >= 8.0, e.g. A100).
# The free-tier T4 is compute capability 7.5, so use float16 compute there.
COMPUTE_DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    else torch.float16
)

# 4-bit NF4 quantization with double quantization (QLoRA); the dequantized
# matmuls run in COMPUTE_DTYPE.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
)

# Load model. Llama is implemented natively in transformers, so
# trust_remote_code (which permits running code from the Hub repo) is not needed.
print(f"Loading {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use a dedicated padding token instead of EOS. Language-modeling collators
# commonly mask label positions equal to pad_token_id; if pad == eos (the
# end-of-turn token for this Instruct tokenizer — NOTE(review): confirm), every
# <|eot_id|> in the training text could be dropped from the loss and the model
# may never learn to stop. Llama 3.1 ships a reserved pad token for this;
# fall back to EOS only if it is absent.
PAD_TOKEN = "<|finetune_right_pad_id|>"
if PAD_TOKEN in tokenizer.get_vocab():
    tokenizer.pad_token = PAD_TOKEN
else:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Model loaded!")

## 5. Configure QLoRA

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare the 4-bit quantized model for k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(model)

# QLoRA configuration: rank-16 adapters on the attention and FFN projections below.
lora_config = LoraConfig(
    r=16,                          # Low-rank dimension
    lora_alpha=32,                 # Scaling factor (typically 2×r)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
        "gate_proj", "up_proj", "down_proj"       # FFN (MLP)
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Wrap the model with the LoRA adapters and report how many parameters will train.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## 6. Format Training Data

In [None]:
# System instruction used for every training prompt.
_RAFT_SYSTEM_PROMPT = (
    "You are a dental education assistant. Answer questions using the provided "
    "documents. Cite sources using ##begin_quote## and ##end_quote## markers. "
    "If documents don't contain relevant information, say so clearly."
)


def _render_raft_prompt(question, context_docs, answer):
    """Build the Llama 3.1 chat string for one (question, documents, answer) triple."""
    context_str = "\n\n".join(
        f"Document {i + 1} ({doc['source']}, p.{doc['page_number']}):\n{doc['content']}"
        for i, doc in enumerate(context_docs)
    )
    # Llama 3.1 layout: each turn is header + "\n\n" + content + <|eot_id|>, and
    # the next header follows <|eot_id|> directly (no newlines between turns).
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{_RAFT_SYSTEM_PROMPT}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"Question: {question}\n\nDocuments:\n{context_str}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{answer}<|eot_id|>"
    )


def format_raft_prompt(example):
    """
    Format RAFT example(s) into the Llama 3.1 Instruct chat format.

    Args:
        example: Either one record with keys ``question`` (str), ``context``
            (list of dicts with ``source``, ``page_number`` and ``content``)
            and ``answer`` (str), or a batch of records in column form
            (each key maps to a list, one entry per row). Some SFTTrainer
            versions call ``formatting_func`` with batches like this.

    Returns:
        A single prompt string for one record, or a list of prompt strings
        (one per row, in order) for a batch.

    NOTE(review): the string starts with a literal <|begin_of_text|>. If the
    trainer tokenizes with special tokens enabled, the sequence will begin
    with two BOS tokens — confirm against the installed TRL version.
    """
    if isinstance(example['question'], list):
        return [
            _render_raft_prompt(question, docs, answer)
            for question, docs, answer in zip(
                example['question'], example['context'], example['answer']
            )
        ]
    return _render_raft_prompt(example['question'], example['context'], example['answer'])

# Render the first training example and show its beginning and end.
sample_prompt = format_raft_prompt(train_dataset[0])
prompt_head = sample_prompt[:500]
prompt_tail = sample_prompt[-300:]
for chunk in (prompt_head, "...", prompt_tail):
    print(chunk)

In [None]:
# Check token count of the rendered sample prompt.
tokens = tokenizer(sample_prompt, return_tensors="pt")
sample_len = tokens['input_ids'].shape[1]
print(f"Sample prompt tokens: {sample_len}")

# Keep in sync with the max_seq_length given to SFTTrainer. Longer examples
# are truncated; with right-side truncation (typically the tokenizer default)
# the tokens cut are the assistant answer at the end of the prompt.
SEQ_LEN_LIMIT = 2048
if sample_len > SEQ_LEN_LIMIT:
    print(
        f"WARNING: sample exceeds {SEQ_LEN_LIMIT} tokens and would be truncated; "
        "consider fewer/shorter context documents or a larger max_seq_length."
    )

## 7. Training Configuration

In [None]:
import torch
from transformers import TrainingArguments
from trl import SFTTrainer

# Mixed precision: bf16 needs an Ampere-or-newer GPU (compute capability >= 8.0,
# e.g. A100). The free-tier T4 (compute capability 7.5) has no bf16 support, and
# TrainingArguments is expected to reject bf16=True there, so use fp16 instead.
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
USE_FP16 = torch.cuda.is_available() and not USE_BF16

# Training arguments
training_args = TrainingArguments(
    output_dir="./checkpoints/llama-3.1-8b-dental-raft",

    # Training schedule
    num_train_epochs=3,
    per_device_train_batch_size=1,        # Keep low for T4
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,        # Effective batch size = 1 x 8 = 8 per device

    # Learning rate
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,

    # Optimization
    optim="paged_adamw_8bit",
    fp16=USE_FP16,
    bf16=USE_BF16,                        # Auto-selected: bf16 on Ampere+, fp16 otherwise

    # Logging and saving
    logging_steps=10,
    save_steps=500,
    save_total_limit=3,

    # Evaluation (`eval_strategy` is the current name of `evaluation_strategy`,
    # which recent transformers releases no longer accept)
    eval_strategy="steps",
    eval_steps=500,

    # Memory management
    gradient_checkpointing=True,
    max_grad_norm=0.3,

    # Misc
    report_to="wandb",                    # W&B tracking; set to "none" to disable
    run_name="dental-raft-llama3.1-8b",
    seed=42,
)

In [None]:
# Optional: Weights & Biases tracking.
# The training arguments use report_to="wandb", so the Trainer starts a W&B run
# by itself when training begins (you may be prompted to log in at that point).
# Change report_to to "none" in the training arguments to disable tracking.
import os
import wandb

# Route the Trainer's automatic run to the "dental-raft" project instead of the
# integration's default project. Only takes effect if set before trainer.train().
os.environ.setdefault("WANDB_PROJECT", "dental-raft")

# wandb.login()  # Uncomment to log in ahead of training
# wandb.init(project="dental-raft", name="llama3.1-8b-qlora")

## 8. Initialize Trainer

In [None]:
# Build the supervised fine-tuning trainer. Raw RAFT records are rendered to
# text by `format_raft_prompt`; sequences are capped at 2048 tokens (lower this
# if you run out of GPU memory) and examples are not packed together.
# NOTE(review): these keyword arguments follow older TRL releases. Newer TRL
# versions expect `processing_class=` instead of `tokenizer=` and take
# `max_seq_length`/`packing` through `SFTConfig` — check the installed version.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    formatting_func=format_raft_prompt,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    packing=False,
    max_seq_length=2048,
)

print("Trainer initialized!")

## 9. Train!

In [None]:
# Snapshot CUDA memory usage (GB) before training starts.
for mem_kind, mem_bytes in (
    ("allocated", torch.cuda.memory_allocated()),
    ("reserved", torch.cuda.memory_reserved()),
):
    print(f"GPU Memory {mem_kind}: {mem_bytes / 1e9:.2f} GB")

In [None]:
# Launch fine-tuning; checkpoints are written under training_args.output_dir.
print("Starting training...")
train_output = trainer.train()
train_output  # keep the result as the cell's displayed value

In [None]:
# Persist the trained PEFT (LoRA) model to ./final_model so it can be reloaded
# later with PeftModel.from_pretrained.
FINAL_MODEL_DIR = "./final_model"
trainer.save_model(FINAL_MODEL_DIR)
print("Model saved!")

## 10. Test the Fine-tuned Model

In [None]:
def generate_answer(question, context_docs, *, max_new_tokens=256,
                    temperature=0.7, top_p=0.9):
    """
    Generate an answer with the fine-tuned model.

    The prompt uses the same system instruction and document layout as the
    training prompts, in the Llama 3.1 chat layout, and ends right after the
    assistant header so the model writes the answer.

    Args:
        question: The user question.
        context_docs: Iterable of dicts with ``source``, ``page_number`` and
            ``content`` keys.
        max_new_tokens: Maximum number of tokens to generate (default 256).
        temperature: Sampling temperature (default 0.7).
        top_p: Nucleus-sampling threshold (default 0.9).

    Returns:
        The generated assistant text only (no prompt), whitespace-stripped.
    """
    context_str = "\n\n".join(
        f"Document {i + 1} ({doc['source']}, p.{doc['page_number']}):\n{doc['content']}"
        for i, doc in enumerate(context_docs)
    )

    # Same system instruction as in training; turns are concatenated directly
    # after <|eot_id|> as in the Llama 3.1 chat template.
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a dental education assistant. Answer questions using the provided "
        "documents. Cite sources using ##begin_quote## and ##end_quote## markers. "
        "If documents don't contain relevant information, say so clearly.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"Question: {question}\n\n"
        f"Documents:\n{context_str}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # The prompt already begins with <|begin_of_text|>; don't add a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # After training the model may still be in train mode; disable (LoRA)
    # dropout so generation is not perturbed by it.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens. Decoding the whole sequence with
    # skip_special_tokens=True removes the header markers, so neither splitting
    # on the assistant header nor slicing by len(prompt) finds the answer.
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()

In [None]:
# Sanity-check the fine-tuned model on the first validation example and show
# the reference answer (first 500 characters) for comparison.
test_example = val_dataset[0]
divider = "\n" + "=" * 50 + "\n"

print("Question:", test_example['question'])
print(divider)

answer = generate_answer(test_example['question'], test_example['context'])

for line in ("Generated Answer:", answer, divider, "Ground Truth:",
             test_example['answer'][:500]):
    print(line)

## 11. Push to HuggingFace Hub

In [None]:
from peft import PeftModel

# Fold the LoRA adapter weights into the base model so the result can be
# deployed without PEFT.
# NOTE(review): `model` is still the 4-bit quantized training model here, so
# the merge operates on quantized base weights. For a higher-precision merged
# model, reload the base model without quantization and attach the saved
# adapter before merging:
#   model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, ...)
#   model = PeftModel.from_pretrained(model, "./final_model")
merged_model = model.merge_and_unload()
print("Weights merged!")

In [None]:
# Push to HuggingFace Hub
HF_USERNAME = "your-username"  # Change this!
MODEL_REPO = f"{HF_USERNAME}/llama-3.1-8b-dental-raft"

# Fail fast instead of attempting a multi-GB upload to a namespace you don't own.
if HF_USERNAME == "your-username":
    raise ValueError(
        "Set HF_USERNAME to your Hugging Face username (or organization) before pushing."
    )

# Push the merged model and tokenizer to a private repo.
merged_model.push_to_hub(MODEL_REPO, private=True)
tokenizer.push_to_hub(MODEL_REPO, private=True)

print(f"Model pushed to: https://huggingface.co/{MODEL_REPO}")

## 12. Cleanup

In [None]:
# Clear GPU memory held by the training objects.
import gc

# `merged_model` was produced from `model` by merge_and_unload() and may hold
# the same GPU weights, so drop it too. globals().pop(..., None) keeps the cell
# safe to re-run (plain `del` raises NameError once a name is gone).
for _name in ("merged_model", "trainer", "model"):
    globals().pop(_name, None)
del _name
gc.collect()
torch.cuda.empty_cache()

print("Cleanup complete!")

---

## Next Steps

1. **Download the FAISS index** from your data processing step
2. **Deploy to HuggingFace Spaces** with ZeroGPU:
   - Create a new Space with Gradio SDK
   - Upload `app/app.py` and `app/requirements.txt`
   - Upload your FAISS index files
   - Update `MODEL_NAME` in app.py to point to your pushed model

3. **Test the deployment** and iterate on the prompts if needed