<a href="https://colab.research.google.com/github/ajit-ai/DataScience/blob/main/RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install sentence-transformers faiss-cpu transformers

In [None]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline

# Minimal end-to-end RAG walkthrough: embed a tiny corpus, retrieve the
# nearest documents for a query, then let a small language model continue a
# prompt built from the query and the retrieved text.

# Step 1: Prepare a small knowledge base (list of documents)
documents = [
    "The capital of France is Paris.",
    "France is known for its wine and cheese.",
    "The Eiffel Tower is in Paris, France.",
    "Florida is a state in the USA, and its capital is Tallahassee."
]

# Step 2: Encode documents using Sentence Transformers
encoder = SentenceTransformer('all-MiniLM-L6-v2')
doc_embeddings = encoder.encode(documents, convert_to_numpy=True)

# Step 3: Create a FAISS index for similarity search
# The index dimensionality must match the embedding width.
dimension = doc_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(doc_embeddings)

# Step 4: Define the query and retrieve relevant documents
query = "What is the capital of France?"
# Encode the query the same way as the documents so both live in one space.
query_embedding = encoder.encode([query], convert_to_numpy=True)

# Search for top-k relevant documents (k must not exceed len(documents))
k = 2
distances, indices = index.search(query_embedding, k)

# Retrieve the relevant documents; row 0 holds the hits for our single query
retrieved_docs = [documents[i] for i in indices[0]]
print("Retrieved Documents:", retrieved_docs)

# Step 5: Use a generative model to produce the final answer
generator = pipeline('text-generation', model='distilgpt2')
context = f"Question: {query}\nContext: {' '.join(retrieved_docs)}"
prompt = f"{context}\nAnswer:"
# Bound only the newly generated tokens. `max_length` counts the prompt as
# well, so a longer retrieved context would leave no room for the answer.
response = generator(prompt, max_new_tokens=30, num_return_sequences=1)[0]['generated_text']

# Extract and clean the answer: the generated text echoes the prompt, so keep
# only what follows the "Answer:" marker.
answer = response.split("Answer:")[1].strip() if "Answer:" in response else response
print("Generated Answer:", answer)

In [None]:
# prompt: generate a RAG codes for text

# Function to perform RAG
def answer_question_with_rag(query, encoder, index, documents, generator,
                             *, k=2, max_new_tokens=64):
    """
    Answers a question using the RAG approach.

    The query is embedded, the nearest documents are looked up in the index,
    and the generator continues a prompt of the form
    "Question: ...\\nContext: ...\\nAnswer:". The retrieved documents are also
    printed.

    Args:
        query (str): The question to answer.
        encoder (SentenceTransformer): The model used to encode documents and queries.
        index (faiss.IndexFlatL2): The FAISS index for similarity search.
        documents (list): The list of documents in the knowledge base.
        generator (pipeline): The generative model.
        k (int): Number of documents to retrieve. Clamped to len(documents).
        max_new_tokens (int): Upper bound on the tokens generated after the prompt.

    Returns:
        str: The generated answer.

    Raises:
        ValueError: If k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Step 4: Embed the query and retrieve relevant documents
    query_embedding = encoder.encode([query])

    num_docs = len(documents)
    retrieved_docs = []
    if num_docs:
        # Never request more neighbours than there are documents.
        _, indices = index.search(query_embedding, min(k, num_docs))
        # The index pads missing neighbours with -1; without this filter a -1
        # would silently select documents[-1].
        hit_ids = (int(raw_id) for raw_id in indices[0])
        retrieved_docs = [documents[i] for i in hit_ids if 0 <= i < num_docs]
    print("Retrieved Documents:", retrieved_docs)

    # Step 5: Use a generative model to produce the final answer
    context = f"Question: {query}\nContext: {' '.join(retrieved_docs)}"
    prompt = f"{context}\nAnswer:"

    # Limit only the continuation; `max_length` would count the prompt tokens
    # too and could leave no budget for the answer on long contexts.
    response = generator(
        prompt, max_new_tokens=max_new_tokens, num_return_sequences=1
    )[0]['generated_text']

    # Extract and clean the answer.
    # Usual case: the generator echoes the prompt, so drop exactly that prefix.
    # This stays correct even when the query or context contains "Answer:".
    if response.startswith(prompt):
        return response[len(prompt):].strip()

    # The prompt was not echoed verbatim: fall back to the first marker.
    _, marker, tail = response.partition("Answer:")
    if marker:
        return tail.strip()

    # No marker at all: keep only the first line of the generated text.
    return response.strip().split('\n')[0]

# Demo: send two sample questions through the RAG helper, reusing the
# encoder, index, documents and generator built in the earlier cell.
new_query = "Where is the Eiffel Tower?"
new_query_2 = "What is the capital of Florida?"

generated_answer = answer_question_with_rag(
    query=new_query,
    encoder=encoder,
    index=index,
    documents=documents,
    generator=generator,
)
print("Generated Answer:", generated_answer)

generated_answer_2 = answer_question_with_rag(
    query=new_query_2,
    encoder=encoder,
    index=index,
    documents=documents,
    generator=generator,
)
print("Generated Answer:", generated_answer_2)