In [None]:
import google.generativeai as genai
import numpy as np
# from sklearn.metrics.pairwise import cosine_similarity # No longer needed for retrieval with FAISS
from IPython.display import Markdown, display
import os
import faiss # Import FAISS

In [None]:
# --- 1. Configuration and API Key Setup ---
# Read the key once so the presence check and the configure call see the same value,
# and warn *before* configuring rather than after.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    print("Warning: GOOGLE_API_KEY environment variable not set. Please set it for API access.")

try:
    genai.configure(api_key=api_key)
except Exception as e:
    print(f"Error configuring Gemini API: {e}")
    raise  # Configuration failed: stop here instead of failing later on every API call.

In [None]:
# --- 2. Initialize Models ---
GENERATION_MODEL_NAME = 'gemini-1.5-flash-latest'
# Fully qualified model id, the form used by `genai.embed_content`.
EMBEDDING_MODEL_NAME = 'models/text-embedding-004'

try:
    generation_model = genai.GenerativeModel(GENERATION_MODEL_NAME)
    # Embeddings are not served through `genai.GenerativeModel` (it has no
    # `embed_content` method). The SDK exposes them as the module-level function
    # `genai.embed_content(model=..., content=...)`. Binding the module here makes
    # `embedding_model.embed_content(model=EMBEDDING_MODEL_NAME, content=...)`
    # resolve to that function.
    embedding_model = genai
    print(f"Gemini models initialized: '{GENERATION_MODEL_NAME}' for generation, '{EMBEDDING_MODEL_NAME}' for embeddings.")
except Exception as e:
    print(f"Error initializing Gemini models: {e}")
    raise

In [None]:
# --- 3. Create a Simple In-Memory Knowledge Base ---
# Short, single-topic facts. Each string is embedded on its own when the index is
# built, so every entry should make sense without the others.
knowledge_base_documents = list(
    (
        "The capital of France is Paris. Paris is known for its Eiffel Tower.",
        "The Amazon rainforest is the largest tropical rainforest in the world.",
        "Python is a high-level, general-purpose programming language.",
        "Artificial intelligence (AI) is intelligence demonstrated by machines.",
        "The human heart has four chambers: two atria and two ventricles.",
        "The deepest ocean trench is the Mariana Trench, located in the western Pacific Ocean.",
        "Machine learning is a subset of AI that enables systems to learn from data.",
        "The Earth revolves around the Sun in an elliptical orbit.",
        "Quantum computing uses quantum-mechanical phenomena like superposition and entanglement.",
        "The Great Barrier Reef is the world's largest coral reef system, located off the coast of Queensland, Australia.",
    )
)

print(f"Knowledge base loaded with {len(knowledge_base_documents)} documents.")

In [None]:
# --- 4. Embed the Knowledge Base Documents and Build FAISS Index ---
# Embeddings come from the SDK's module-level `genai.embed_content` function;
# a `genai.GenerativeModel` instance has no `embed_content` method.

# Globals read by `retrieve_relevant_documents`: the FAISS index, and a list that
# maps each FAISS row id to its source document (row i <-> document_store[i]).
faiss_index = None
document_store = []

# `genai.embed_content` is documented with qualified ids such as
# "models/text-embedding-004"; accept either the bare or qualified form.
embedding_model_id = (
    EMBEDDING_MODEL_NAME if "/" in EMBEDDING_MODEL_NAME else f"models/{EMBEDDING_MODEL_NAME}"
)

print("Starting document embedding process and FAISS index construction...")
embeddings_list = []  # Embeddings, kept in the same order as document_store
total_docs = len(knowledge_base_documents)

for doc_number, doc in enumerate(knowledge_base_documents, start=1):
    try:
        # "retrieval_document" tells the embedding model these texts are corpus entries.
        embedding_response = genai.embed_content(
            model=embedding_model_id,
            content=doc,
            task_type="retrieval_document",
        )
    except Exception as e:
        # Best effort: skip this document and keep indexing the rest.
        print(f"Error embedding document {doc_number} ('{doc[:50]}...'): {e}")
        continue

    if embedding_response and 'embedding' in embedding_response:
        embeddings_list.append(embedding_response['embedding'])
        document_store.append(doc)  # Same position as its embedding -> same FAISS row id
        print(f"Embedded document {doc_number}/{total_docs}")
    else:
        print(f"Warning: No embedding found for document {doc_number}: {doc[:50]}...")

if not embeddings_list:
    raise ValueError("No embeddings were successfully generated for the knowledge base. Cannot build FAISS index.")

# FAISS operates on float32 matrices of shape (n_docs, d).
embeddings_np = np.asarray(embeddings_list, dtype='float32')
d = embeddings_np.shape[1]

# IndexFlatL2 performs exact (brute-force) L2 search, which is fine for a small
# in-memory corpus. For large corpora consider an approximate index (e.g. IndexIVFFlat).
try:
    faiss_index = faiss.IndexFlatL2(d)
    faiss_index.add(embeddings_np)
    print(f"FAISS index built with {faiss_index.ntotal} documents and dimensionality {d}.")
except Exception as e:
    print(f"Error building FAISS index: {e}")
    raise  # Retrieval is impossible without an index.

# Row ids returned by FAISS are used to index document_store, so the two must agree.
assert faiss_index.ntotal == len(document_store), "FAISS index and document_store are out of sync"

In [None]:
# --- 5. Retrieval Mechanism (FAISS) ---
def retrieve_relevant_documents(query_embedding, top_k=2):
    '''
    Retrieve the knowledge-base documents whose embeddings are nearest to the query.

    Reads the module-level ``faiss_index`` and ``document_store`` (FAISS row id ->
    document text). Both are only read, so no ``global`` statement is needed.

    Args:
        query_embedding (array-like): 1-D embedding of the user's query, either a
            list or a NumPy array. It is converted to a float32 array of shape (1, d).
        top_k (int): Maximum number of documents to return. Values larger than the
            index size are clamped to it; values <= 0 return an empty list.

    Returns:
        list: Retrieved documents (strings), ordered as returned by FAISS
        (nearest first). Empty if the index is missing or empty, or if retrieval
        fails. Errors are printed, not raised.
    '''
    if faiss_index is None or faiss_index.ntotal == 0:
        print("Warning: FAISS index is not initialized or empty, no documents to retrieve.")
        return []

    try:
        # FAISS rejects k <= 0, and asking for more neighbours than exist only pads with -1.
        k = min(int(top_k), faiss_index.ntotal)
        if k <= 0:
            return []

        # FAISS search needs a contiguous float32 2-D array of shape (n_queries, d).
        query_vector = np.ascontiguousarray(
            np.asarray(query_embedding, dtype='float32').reshape(1, -1)
        )
        if query_vector.shape[1] != faiss_index.d:
            raise ValueError(
                f"query embedding has dimension {query_vector.shape[1]}, "
                f"but the index expects {faiss_index.d}"
            )

        # search() returns (distances, indices); only the indices are needed here.
        _, indices = faiss_index.search(query_vector, k)

        # indices[0] holds the hits for our single query; -1 marks an empty slot.
        retrieved_docs = [document_store[idx] for idx in indices[0] if idx != -1]
        print(f"Retrieved {len(retrieved_docs)} documents using FAISS.")
        return retrieved_docs
    except Exception as e:
        print(f"Error during FAISS retrieval: {e}")
        return []

In [None]:
# --- 6. Agentic Logic for RAG ---
# Phrases that suggest a factual lookup the knowledge base might answer. They are
# matched as case-insensitive substrings of the query. The last group covers
# question forms such as "Define artificial intelligence." and
# "How many chambers does the human heart have?".
RETRIEVAL_KEYWORDS = (
    "what is", "tell me about", "where is", "who is", "explain", "describe",
    "facts about", "information on",
    "define", "how many", "how does", "how do",
)


def _needs_retrieval(user_query):
    '''Return True if the query contains any of RETRIEVAL_KEYWORDS (case-insensitive).'''
    query_lower = user_query.lower()
    return any(keyword in query_lower for keyword in RETRIEVAL_KEYWORDS)


def _retrieve_context(user_query, top_k):
    '''Embed the query, fetch the nearest documents and return them newline-joined.

    Returns "" when the query cannot be embedded, nothing is found, or an error
    occurs. Problems are printed rather than raised so generation can still proceed.
    '''
    # Embeddings come from the module-level `genai.embed_content` function
    # (a `genai.GenerativeModel` has no `embed_content` method). Accept a bare or
    # qualified ("models/...") model name.
    model_id = EMBEDDING_MODEL_NAME if "/" in EMBEDDING_MODEL_NAME else f"models/{EMBEDDING_MODEL_NAME}"
    try:
        query_embedding_response = genai.embed_content(
            model=model_id,
            content=user_query,
            task_type="retrieval_query",  # Pairs with "retrieval_document" corpus embeddings
        )
        if not (query_embedding_response and 'embedding' in query_embedding_response):
            print("Warning: Could not get embedding for the user query, skipping retrieval.")
            return ""
        relevant_docs = retrieve_relevant_documents(
            np.array(query_embedding_response['embedding']), top_k=top_k
        )
    except Exception as e:
        print(f"Error during query embedding or retrieval: {e}")
        return ""

    if not relevant_docs:
        print("No relevant documents found for retrieval.")
        return ""

    retrieved_context = "\n".join(relevant_docs)
    print("\n--- Retrieved Context ---")
    display(Markdown(f"```text\n{retrieved_context}\n```"))
    return retrieved_context


def _build_prompt(user_query, retrieved_context):
    '''Assemble the generation prompt, prefixing the retrieved context when there is any.'''
    prompt_parts = []
    if retrieved_context:
        prompt_parts.append(f"Here is some relevant information:\n{retrieved_context}\n\n")
    prompt_parts.append(f"Based on the provided information (if any) and your general knowledge, please answer the following question:\nQuestion: {user_query}\n\n")
    prompt_parts.append("Answer:")
    return "".join(prompt_parts)


def _generate_answer(final_prompt):
    '''Send the prompt to the generation model, render the answer and return its text.

    On a blocked or empty response, the returned text explains why (block or finish
    reason). On an API error, it is a generic error message.
    '''
    try:
        response = generation_model.generate_content(final_prompt)
        candidate = response.candidates[0] if response.candidates else None
        # A candidate can exist with no parts (e.g. stopped for safety), so check parts explicitly.
        parts = candidate.content.parts if candidate is not None else []

        if parts:
            # Join every text part; a response may be split across several parts.
            generated_text = "".join(part.text for part in parts)
            print("\n--- Generated Response ---")
        else:
            print("\n--- No Valid Response from Gemini ---")
            generated_text = "I couldn't generate a valid response for that query."
            if response.prompt_feedback and response.prompt_feedback.block_reason:
                print(f"Blocked reason: {response.prompt_feedback.block_reason.name}")
                generated_text += f"\nReason: {response.prompt_feedback.block_reason.name}"
            elif candidate is not None and candidate.finish_reason:
                print(f"Finish reason: {candidate.finish_reason.name}")
                generated_text += f"\nFinish Reason: {candidate.finish_reason.name}"
    except Exception as e:
        print("\n--- Error Generating Content with Gemini ---")
        print(f"An error occurred: {e}")
        generated_text = "An error occurred while generating the response."

    display(Markdown(generated_text))
    return generated_text


def agentic_rag_system(user_query, top_k=2):
    '''
    An agentic RAG system that decides whether to retrieve information
    before generating a response.

    The decision is a keyword heuristic (see RETRIEVAL_KEYWORDS). When it fires,
    the query is embedded and the nearest knowledge-base documents are prepended
    to the prompt.

    Args:
        user_query (str): The user's input query.
        top_k (int): Number of documents to retrieve when retrieval is triggered.

    Returns:
        str: The generated response, or an explanatory message if generation was
        blocked or failed.
    '''
    print(f"\n--- User Query: '{user_query}' ---")

    # --- Agent's Decision Phase ---
    retrieved_context = ""
    if _needs_retrieval(user_query):
        print("Agent decision: Retrieval is likely needed based on keywords.")
        retrieved_context = _retrieve_context(user_query, top_k)
    else:
        print("Agent decision: Direct generation without retrieval (no relevant keywords detected).")

    # --- Generation Phase ---
    final_prompt = _build_prompt(user_query, retrieved_context)
    print(f"\n--- Final Prompt Sent to Gemini ---\n```\n{final_prompt}\n```")
    return _generate_answer(final_prompt)

In [None]:
# --- Example Usage ---
# agentic_rag_system already renders each answer with display(). Looping over the
# queries, rather than ending the cell with a bare call, keeps Jupyter from
# echoing the last answer a second time as the cell's Out[] value.
EXAMPLE_QUERIES = [
    "What is the capital of France?",
    "Tell me about quantum computing.",
    "What is the largest tropical rainforest?",
    "How many chambers does the human heart have?",
    "What is the deepest ocean trench?",
    "Tell me about the universe.",
    "Define artificial intelligence.",
    "Explain machine learning.",
    "Tell me about the Big Bang theory.",
]

print("\n--- Running Example Queries ---")
for query in EXAMPLE_QUERIES:
    agentic_rag_system(query)