<a href="https://colab.research.google.com/github/Piyumi22/LLMs/blob/main/LLM_fine_Tuning_with_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install required packages
!pip install -q transformers datasets peft sentence-transformers scikit-learn torch pandas

# Enable GPU acceleration
# Go to Runtime -> Change runtime type -> Select "T4 GPU"
import torch
print(f"GPU available: {torch.cuda.is_available()}")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m35.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ======================
# 1. Setup & Imports
# ======================
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model
import evaluate
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors

# ======================
# 2. Data Preparation
# ======================
# Load a small IMDB subset (Colab-friendly version).
# NOTE: the Hub's IMDB splits are ordered by label (all negative reviews come
# first), so slicing 'train[:1000]' / 'test[:1000]' yields a single class.
# Shuffle with a fixed seed before sampling so both classes are represented
# and the subset is reproducible.
train_split, test_split = load_dataset('imdb', split=['train', 'test'])
dataset = DatasetDict({
    'train': train_split.shuffle(seed=42).select(range(1000)),
    'validation': test_split.shuffle(seed=42).select(range(1000)),
})

# Create knowledge base: short hint sentences that retrieval attaches to each review
knowledge_base = [
    "Positive reviews often contain words like excellent, amazing, wonderful.",
    "Negative reviews often contain words like terrible, awful, disappointing.",
    "Movies with great acting tend to get positive reviews.",
    "Poor cinematography often leads to negative reviews.",
    "Positive reviews frequently mention being entertained or moved.",
    "Negative reviews often complain about plot holes or bad pacing."
]

# ======================
# 3. RAG Setup
# ======================
# Sentence-embedding model used for retrieval (a small MiniLM encoder,
# lighter than running full BERT over every query)
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Encode every knowledge-base sentence once, up front
knowledge_embeddings = embedding_model.encode(knowledge_base)

# Index the sentence vectors so each query returns its 2 closest entries.
# NearestNeighbors.fit returns the fitted estimator itself, so fit inline.
neigh = NearestNeighbors(n_neighbors=2).fit(knowledge_embeddings)

def retrieve_relevant_info(text):
    """Return the knowledge-base sentences closest to ``text``.

    ``text`` is embedded with ``embedding_model`` and queried against the
    ``neigh`` index. The matching ``knowledge_base`` entries are returned,
    nearest first, joined by single spaces.
    """
    query_vector = embedding_model.encode([text])
    nearest = neigh.kneighbors(query_vector, return_distance=False)
    return " ".join(knowledge_base[idx] for idx in nearest[0])

# ======================
# 4. Model Setup
# ======================
# Pretrained checkpoint shared by the tokenizer and both model instances
model_checkpoint = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Label maps: class index -> name, plus the inverse mapping
id2label = {0: "Negative", 1: "Positive"}
label2id = {name: idx for idx, name in id2label.items()}

# Plain encoder (no head): RAGClassifier runs it on both the text and the context
base_model = AutoModel.from_pretrained(model_checkpoint)
# Sequence-classification variant of the same checkpoint; RAGClassifier
# uses its `classifier` layer
classifier_model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# ======================
# 5. Custom RAG Model
# ======================
class RAGClassifier(torch.nn.Module):
    """Sequence classifier that can fuse the input text with retrieved context.

    The text (and, if given, the retrieved context) are encoded separately by
    ``base_model``. The first-token vector of each is taken from
    ``last_hidden_state[:, 0, :]``.

    - With context: the two vectors are concatenated and scored by a dedicated
      ``classifier`` head sized for the doubled width.
    - Without context: the text vector alone goes through
      ``classifier_model.classifier``.

    Args:
        base_model: encoder called with ``input_ids`` / ``attention_mask``
            whose output exposes ``last_hidden_state``.
        classifier_model: module whose ``classifier`` attribute is a
            ``torch.nn.Linear`` head. Its ``in_features`` / ``out_features``
            size the fused head.
        context_tokenizer: callable used to tokenize ``retrieved_context``.
            When ``None``, the module-level ``tokenizer`` is used.
    """

    def __init__(self, base_model, classifier_model, context_tokenizer=None):
        super().__init__()
        self.base_model = base_model
        self.classifier_model = classifier_model
        self.context_tokenizer = context_tokenizer

        text_head = classifier_model.classifier
        self.num_labels = text_head.out_features
        # The text-only head expects `in_features` inputs. Concatenating the
        # context vector doubles that width, so the fused path needs its own
        # head. It is named `classifier` so PEFT's SEQ_CLS task (which keeps
        # modules named "classifier" trainable) does not freeze it at random init.
        self.classifier = torch.nn.Linear(2 * text_head.in_features, self.num_labels)

    def forward(self, input_ids, attention_mask, retrieved_context=None,
                labels=None, **kwargs):
        """Classify a batch, optionally using retrieved context.

        Args:
            input_ids: token ids of the input text batch.
            attention_mask: attention mask matching ``input_ids``.
            retrieved_context: optional context strings, one per row of
                ``input_ids``. A single string is treated as a batch of one.
            labels: optional class indices. When given, cross-entropy loss is
                computed.
            **kwargs: extra keyword arguments, such as those forwarded by
                PEFT's sequence-classification wrapper or by ``Trainer``.
                They are accepted and ignored.

        Returns:
            The logits tensor when ``labels`` is None; otherwise a
            ``(loss, logits)`` tuple. ``Trainer`` reads the loss as ``outputs[0]``.

        Raises:
            ValueError: if the number of context strings differs from the
                batch size.
        """
        # First-token ("CLS") vector of the original text
        text_output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state[:, 0, :]

        if isinstance(retrieved_context, str):
            retrieved_context = [retrieved_context]

        if retrieved_context:
            if len(retrieved_context) != text_output.size(0):
                raise ValueError(
                    f"got {len(retrieved_context)} context strings for a batch "
                    f"of {text_output.size(0)} inputs"
                )
            tok = self.context_tokenizer if self.context_tokenizer is not None else tokenizer
            context_inputs = tok(
                retrieved_context,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            ).to(input_ids.device)

            context_output = self.base_model(
                input_ids=context_inputs['input_ids'],
                attention_mask=context_inputs['attention_mask']
            ).last_hidden_state[:, 0, :]

            # Fused features: [text ; context] -> head sized for 2 * hidden
            logits = self.classifier(torch.cat([text_output, context_output], dim=1))
        else:
            logits = self.classifier_model.classifier(text_output)

        if labels is None:
            return logits
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, self.num_labels),
            labels.view(-1).to(logits.device),
        )
        return loss, logits

# Initialize model on the GPU when one is available; otherwise fall back to
# the CPU so the notebook still runs without a CUDA runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rag_model = RAGClassifier(base_model, classifier_model).to(device)

# ======================
# 6. Data Processing
# ======================
def tokenize_with_rag(examples):
    """Tokenize a batch of reviews and attach retrieved knowledge-base context.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to lists and must contain ``"text"`` and ``"label"``.

    Returns:
        The tokenizer output (``input_ids`` / ``attention_mask``, padded or
        truncated to 128 tokens) plus ``labels`` (copied from ``"label"``)
        and ``retrieved_context``. ``retrieved_context`` holds one string per
        review, built the same way as ``retrieve_relevant_info``.
    """
    texts = examples["text"]

    # Retrieve context for the whole batch at once: a single encode() call
    # and a single kneighbors() query instead of one of each per review.
    _, neighbour_ids = neigh.kneighbors(embedding_model.encode(texts))
    contexts = [" ".join(knowledge_base[i] for i in row) for row in neighbour_ids]

    # Tokenize main text
    tokenized = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=128,  # Reduced for Colab memory
        return_tensors="np"
    )

    # Add labels and context
    tokenized["labels"] = examples["label"]
    tokenized["retrieved_context"] = contexts
    return tokenized


tokenized_dataset = dataset.map(tokenize_with_rag, batched=True)

# ======================
# 7. Training Setup
# ======================
# LoRA adapters on the attention query projections ("q_lin" layers):
# rank 4, alpha 32, light dropout.
# NOTE(review): "q_lin" also matches the layers inside
# classifier_model.distilbert, which RAGClassifier never runs.
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    target_modules=['q_lin'],
    r=4,
    lora_alpha=32,
    lora_dropout=0.01,
)

# Wrap the RAG model with the adapters and report the trainable parameter count
model = get_peft_model(rag_model, peft_config)
model.print_trainable_parameters()

# Evaluation metric
accuracy = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for ``Trainer`` evaluation.

    Args:
        p: ``EvalPrediction`` or a ``(predictions, label_ids)`` pair, where
            the predictions are per-class logits.

    Returns:
        The metric dict from ``evaluate`` (``{"accuracy": value}``), returned
        as-is so Trainer logs ``eval_accuracy`` as a number rather than as
        a nested dict.
    """
    predictions, labels = p
    # Models that return extra outputs give a tuple; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predicted_classes, references=labels)

# Training arguments optimized for Colab
training_args = TrainingArguments(
    output_dir="rag-lora-imdb",
    learning_rate=1e-3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,  # Reduced for Colab
    # `evaluation_strategy` was renamed `eval_strategy` (deprecated since
    # transformers 4.41 and removed in later releases, so the old keyword
    # raises TypeError on the version installed by the first cell).
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    # Trainer otherwise infers label columns from the model class's forward
    # signature; naming it explicitly keeps labels flowing to the loss and
    # compute_metrics regardless of the PEFT wrapper around the custom model.
    label_names=["labels"],
    report_to="none"
)

# Custom data collator
class RAGDataCollator:
    """Batch tokenized features for the RAG classifier.

    Stacks the per-example ``input_ids`` and ``attention_mask`` rows into
    2-D tensors and turns ``labels`` into a 1-D tensor. The
    ``retrieved_context`` strings stay a plain Python list because they are
    tokenized inside the model.
    """

    def __call__(self, features):
        def stacked(key):
            # One tensor per example, stacked along a new leading batch dim
            return torch.stack([torch.tensor(row[key]) for row in features])

        batch = {}
        batch["input_ids"] = stacked("input_ids")
        batch["attention_mask"] = stacked("attention_mask")
        batch["labels"] = torch.tensor([row["labels"] for row in features])
        batch["retrieved_context"] = [row["retrieved_context"] for row in features]
        return batch

# Trainer: wires the PEFT model, datasets, metric and the collator that
# keeps the retrieved-context strings alongside the tensors
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=RAGDataCollator(),
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics,
)

# ======================
# 8. Train & Evaluate
# ======================
# Evaluation on the validation split runs once per epoch (see training_args)
trainer.train()

# ======================
# 9. Inference Demo
# ======================
def predict(text, max_length=128):
    """Classify one review with the fine-tuned RAG model and print the result.

    Args:
        text: review to classify.
        max_length: token budget for ``text``. The default of 128 matches
            the length used when tokenizing the training data; truncation
            also keeps long reviews within the encoder's position limit.

    Side effects:
        Switches ``model`` to eval mode and prints the text, the retrieved
        context and the predicted label.
    """
    context = retrieve_relevant_info(text)
    # Run on whatever device the model lives on instead of assuming CUDA
    device = next(model.parameters()).device
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)

    # Disable dropout (encoder and LoRA) so predictions are deterministic
    model.eval()
    with torch.no_grad():
        logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            retrieved_context=[context]
        )
        pred = torch.argmax(logits, dim=-1)[0].item()

    print(f"Text: {text}")
    print(f"Retrieved Context: {context}")
    print(f"Prediction: {id2label[pred]}\n")

# Hand-written reviews to sanity-check the fine-tuned model: one clearly
# positive, one clearly negative, one mixed
test_texts = [
    "The acting was phenomenal and the story moved me to tears.",
    "Worst movie I've ever seen, complete waste of time.",
    "The cinematography was beautiful but the plot made no sense.",
]

for review in test_texts:
    predict(review)