In [None]:
!pip install -q ragas

In [None]:
import logging
import os
import textwrap
import time
from typing import List, Dict, Any, Optional

import torch
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # No longer needed for generation
from sentence_transformers import SentenceTransformer
from datasets import Dataset
import faiss
import numpy as np

In [None]:
# --- Import Google Generative AI Library ---
import google.generativeai as genai

In [None]:
# --- Configuration Management ---
class Config:
    """Class-level settings for the RAG notebook.

    All values are evaluated once, when this cell runs. After changing an
    environment variable (e.g. GEMINI_API_KEY), re-run this cell to pick it up.
    """

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Retriever Config
    EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
    EMBEDDING_BATCH_SIZE = 32
    RETRIEVAL_TOP_K = 3

    # Generator Config (Gemini API)
    # Read from the environment so the secret never lives in the notebook itself.
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
    # "gemini-pro" (Gemini 1.0 Pro) has been retired by Google and is no longer
    # served, so default to a current model; set GEMINI_MODEL_NAME to override.
    GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.0-flash")
    # Upper bound on answer length; passed to Gemini as `max_output_tokens`.
    GENERATOR_MAX_LENGTH = 256

    # Logging Config
    LOG_LEVEL = logging.INFO
    LOG_FILE = "rag_pipeline.log"

    # Data Paths
    KNOWLEDGE_BASE_PATH = "data/documents.txt"

    # FAISS Config
    FAISS_INDEX_TYPE = "IndexFlatL2"  # only IndexFlatL2 is implemented


# Initialize logging: every record goes to both Config.LOG_FILE and the console.
# force=True (Python 3.8+) replaces handlers already attached to the root logger.
# Without it, basicConfig does nothing when the notebook host has pre-configured
# logging, and neither handler below would ever be installed. It also makes
# re-running this cell idempotent: the previous handlers are closed and replaced.
logging.basicConfig(
    level=Config.LOG_LEVEL,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(Config.LOG_FILE),
        logging.StreamHandler(),
    ],
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
# --- 0. Environment Setup (Local dependencies for Retriever) ---
def setup_environment():
    """Best-effort check that the pipeline's third-party packages are importable.

    If any import fails, the packages are installed with pip into the running
    interpreter (``sys.executable -m pip``). A failed installation is logged as
    an error rather than raised.

    Note: the import cells at the top of this notebook run before this function
    is called, so a missing package will usually surface there first.
    """
    import subprocess
    import sys

    packages = ["datasets", "faiss-cpu", "sentence-transformers", "google-generativeai"]
    try:
        # transformers / accelerate are intentionally not checked: generation is
        # delegated to the Gemini API instead of a local Hugging Face model.
        import datasets  # noqa: F401
        import faiss  # noqa: F401
        import sentence_transformers  # noqa: F401
        import google.generativeai  # noqa: F401  # Gemini client library
        logger.info("All required libraries are already installed.")
    except ImportError as exc:
        logger.warning(f"Required libraries not found. Installing them now... ({exc})")
        # `sys.executable -m pip` installs into the kernel's own environment; a
        # bare `pip` found on PATH may belong to a different Python.
        completed = subprocess.run([sys.executable, "-m", "pip", "install", *packages])
        if completed.returncode == 0:
            logger.info("Libraries installed.")
        else:
            logger.error(f"pip install failed with exit code {completed.returncode}.")

In [None]:
# --- 1. Data Preparation ---
class KnowledgeBase:
    """A collection of text documents, mirrored into a ``datasets.Dataset``.

    ``documents`` holds the list exactly as passed in; ``corpus_dataset`` exposes
    the same texts as its single ``"text"`` column.
    """

    def __init__(self, documents: List[str]):
        self.documents = documents
        self.corpus_dataset = Dataset.from_dict({"text": documents})
        corpus_size = len(self.corpus_dataset)
        logger.info(f"Corpus size: {corpus_size} documents")

    @classmethod
    def from_text_list(cls, text_list: List[str]):
        """Build a knowledge base directly from an in-memory list of strings."""
        return cls(text_list)

    @classmethod
    def from_file(cls, file_path: str):
        """Load one document per non-blank line of a UTF-8 text file.

        Each line is stripped of surrounding whitespace and blank lines are
        dropped. A missing file, or any other error while loading, is logged and
        an empty knowledge base is returned instead of raising.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                documents = []
                for raw_line in handle:
                    cleaned = raw_line.strip()
                    if cleaned:
                        documents.append(cleaned)
            logger.info(f"Loaded {len(documents)} documents from {file_path}")
            return cls(documents)
        except FileNotFoundError:
            logger.error(f"Knowledge base file not found: {file_path}")
        except Exception as e:
            logger.error(f"Error loading knowledge base from file: {e}")
        # Reached only via one of the handlers above.
        return cls([])

In [None]:
# --- 2. Retriever Module (Remains largely the same) ---
class Retriever:
    """Dense retriever: SentenceTransformer embeddings served from a FAISS index.

    Args:
        config: Supplies DEVICE, EMBEDDING_MODEL_NAME, EMBEDDING_BATCH_SIZE,
            RETRIEVAL_TOP_K and FAISS_INDEX_TYPE.
        knowledge_base: Its ``corpus_dataset["text"]`` column is the corpus to index.

    Raises:
        Exception: Any failure while loading the embedding model or building the
            index is logged as critical and re-raised from the constructor.
    """

    def __init__(self, config: Config, knowledge_base: KnowledgeBase):
        self.config = config
        self.knowledge_base = knowledge_base
        self.device = self.config.DEVICE
        # Extract the text column once. Indexing a datasets.Dataset column
        # per lookup can re-materialise the entire column on every access.
        self.documents: List[str] = list(self.knowledge_base.corpus_dataset["text"])
        self.embedding_model = self._load_embedding_model()
        self.index = self._build_faiss_index()

    def _load_embedding_model(self):
        """Load the configured SentenceTransformer and move it to the target device."""
        logger.info(f"Loading embedding model: {self.config.EMBEDDING_MODEL_NAME}...")
        try:
            model = SentenceTransformer(self.config.EMBEDDING_MODEL_NAME).to(self.device)
            logger.info("Embedding model loaded successfully.")
            return model
        except Exception as e:
            logger.critical(f"Failed to load embedding model: {e}")
            raise

    def _create_index(self, dimension: int):
        """Return an empty FAISS index for `dimension`-wide vectors (always IndexFlatL2)."""
        if self.config.FAISS_INDEX_TYPE != "IndexFlatL2":
            logger.warning(f"Unsupported FAISS index type: {self.config.FAISS_INDEX_TYPE}. Falling back to IndexFlatL2.")
        return faiss.IndexFlatL2(dimension)

    def _build_faiss_index(self):
        """Embed every document and load the vectors into a new FAISS index."""
        logger.info("Generating document embeddings...")
        try:
            if not self.documents:
                # With no documents there is no embedding matrix to read the
                # width from, so ask the model for its output dimension instead.
                logger.warning("Knowledge base is empty; building an empty FAISS index.")
                return self._create_index(self.embedding_model.get_sentence_embedding_dimension())

            start_time = time.perf_counter()
            document_embeddings = self.embedding_model.encode(
                self.documents,
                batch_size=self.config.EMBEDDING_BATCH_SIZE,
                convert_to_tensor=True,
                show_progress_bar=False,  # keep logs clean in non-interactive runs
                device=self.device
            )
            document_embeddings_np = document_embeddings.cpu().numpy()
            logger.info(f"Document embeddings shape: {document_embeddings_np.shape}")
            logger.info(f"Embedding generation time: {time.perf_counter() - start_time:.2f} seconds")

            index = self._create_index(document_embeddings_np.shape[1])
            index.add(document_embeddings_np)
            logger.info(f"FAISS index built with {index.ntotal} vectors using {self.config.FAISS_INDEX_TYPE}.")
            return index
        except Exception as e:
            logger.critical(f"Failed to build FAISS index: {e}")
            raise

    def retrieve_documents(self, query: str, k: Optional[int] = None) -> List[str]:
        """Return up to `k` documents nearest to `query` (default RETRIEVAL_TOP_K).

        Fewer than `k` documents are returned when the index holds fewer vectors.
        Retrieval errors are logged and an empty list is returned.
        """
        if k is None:
            k = self.config.RETRIEVAL_TOP_K
        # FAISS pads missing results with label -1 when k exceeds the number of
        # stored vectors, and texts[-1] would silently return the last document.
        k = min(k, self.index.ntotal)
        if k <= 0:
            if self.index.ntotal == 0:
                logger.warning(f"FAISS index is empty; nothing to retrieve for query: '{query[:50]}...'")
            return []

        start_time = time.perf_counter()
        try:
            query_embedding = self.embedding_model.encode(query, convert_to_tensor=True, device=self.device)
            query_embedding_np = query_embedding.cpu().numpy().reshape(1, -1)

            _, indices = self.index.search(query_embedding_np, k)

            # Defensive: still skip any -1 padding labels.
            retrieved_docs = [self.documents[idx] for idx in indices[0] if idx >= 0]
            logger.info(f"Retrieved {len(retrieved_docs)} documents in {time.perf_counter() - start_time:.4f} seconds for query: '{query[:50]}...'")
            return retrieved_docs
        except Exception as e:
            logger.error(f"Error during document retrieval for query '{query[:50]}...': {e}")
            return []

In [None]:
# --- 3. Generator Module (Re-written for Gemini API) ---
class Generator:
    """Answers a query with the Gemini API, grounded on retrieved documents.

    Args:
        config: Must provide GEMINI_API_KEY, GEMINI_MODEL_NAME and
            GENERATOR_MAX_LENGTH. An optional GENERATOR_TEMPERATURE attribute
            overrides the default sampling temperature (0.2).

    Raises:
        ValueError: From the constructor, if GEMINI_API_KEY is empty or unset.
    """

    DEFAULT_TEMPERATURE = 0.2  # low temperature for more deterministic answers
    API_ERROR_ANSWER = "An error occurred while generating the answer using the Gemini API."
    EMPTY_RESPONSE_ANSWER = "The Gemini API returned no answer text (the response may have been blocked)."

    def __init__(self, config: Config):
        self.config = config
        self.model = self._load_gemini_model()

    def _load_gemini_model(self):
        """Configure the SDK with the API key and return a GenerativeModel handle."""
        logger.info(f"Configuring Gemini API with model: {self.config.GEMINI_MODEL_NAME}...")
        try:
            if not self.config.GEMINI_API_KEY:
                raise ValueError("GEMINI_API_KEY is not set. Please set it as an environment variable.")
            genai.configure(api_key=self.config.GEMINI_API_KEY)
            model = genai.GenerativeModel(self.config.GEMINI_MODEL_NAME)
            logger.info("Gemini model configured successfully.")
            return model
        except Exception as e:
            logger.critical(f"Failed to configure Gemini model: {e}")
            raise

    @staticmethod
    def _build_prompt(query: str, retrieved_docs: List[str]) -> str:
        """Return the bare question if there is no context, else a grounded prompt."""
        if not retrieved_docs:
            return f"Question: {query}\nAnswer:"
        context = "\n".join(f"- {doc}" for doc in retrieved_docs)
        return f"""Based on the following context, answer the question. If the information is not in the context, state that you don't know.

Context:
{context}

Question: {query}
Answer:"""

    @staticmethod
    def _extract_text(response) -> Optional[str]:
        """Return the response text, or None when Gemini returned no usable text."""
        try:
            return response.text
        except ValueError:
            # The `.text` accessor raises ValueError when the response carries no
            # text Part (e.g. the prompt or the candidate was blocked), so log why.
            candidates = getattr(response, "candidates", None) or []
            finish_reason = candidates[0].finish_reason if candidates else None
            feedback = getattr(response, "prompt_feedback", None)
            logger.warning(f"Gemini returned no text (finish_reason={finish_reason}, prompt_feedback={feedback}).")
            return None

    def generate_answer(self, query: str, retrieved_docs: List[str]) -> str:
        """Generate an answer for `query` using `retrieved_docs` as context.

        Never raises on API problems. It returns API_ERROR_ANSWER if the call
        fails, and EMPTY_RESPONSE_ANSWER if the response contains no text.
        """
        start_time = time.perf_counter()

        if not retrieved_docs:
            logger.warning(f"No documents retrieved for query: '{query[:50]}...'. Generating answer without context.")
        prompt = self._build_prompt(query, retrieved_docs)
        temperature = getattr(self.config, "GENERATOR_TEMPERATURE", self.DEFAULT_TEMPERATURE)

        try:
            response = self.model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=temperature,
                    max_output_tokens=self.config.GENERATOR_MAX_LENGTH
                )
            )
            answer = self._extract_text(response)
        except Exception as e:
            logger.error(f"Error during Gemini API call for query '{query[:50]}...': {e}")
            return self.API_ERROR_ANSWER

        if answer is None:
            return self.EMPTY_RESPONSE_ANSWER
        logger.info(f"Generated answer in {time.perf_counter() - start_time:.4f} seconds for query: '{query[:50]}...'")
        return answer

In [None]:
# --- 4. RAG Pipeline Orchestrator (Remains largely the same) ---
class RAGPipeline:
    """End-to-end RAG: retrieve context from a knowledge base, then generate an answer.

    Args:
        config: Shared configuration passed to the Retriever and the Generator.
        documents: Corpus to index. Defaults to the small built-in demo corpus
            (DEFAULT_DOCUMENTS) when omitted.
    """

    # Demo corpus used when no documents are supplied.
    DEFAULT_DOCUMENTS = (
        "The capital of France is Paris.",
        "Eiffel Tower is located in Paris, France.",
        "The Louvre Museum is a famous art museum in Paris.",
        "Germany is a country in Central Europe.",
        "Berlin is the capital and largest city of Germany.",
        "The Brandenburg Gate is an 18th-century neoclassical monument in Berlin.",
        "Artificial intelligence (AI) is a rapidly advancing field of computer science.",
        "Machine learning is a subset of AI that focuses on algorithms that learn from data.",
        "Deep learning is a subset of machine learning using neural networks with many layers.",
        "Python is a popular programming language for AI and machine learning.",
    )

    def __init__(self, config: Config, documents: Optional[List[str]] = None):
        self.config = config
        corpus = list(self.DEFAULT_DOCUMENTS) if documents is None else list(documents)
        self.knowledge_base = KnowledgeBase.from_text_list(corpus)

        self.retriever = Retriever(self.config, self.knowledge_base)
        self.generator = Generator(self.config)

    def query(self, query_text: str) -> Dict[str, Any]:
        """Run retrieval and generation for one query.

        Returns:
            A dict with "query", "retrieved_documents", "answer" and "metrics".
            "metrics" holds per-stage timings in seconds and the retrieved count.
        """
        logger.info(f"Processing query: '{query_text}'")
        pipeline_start = time.perf_counter()

        retrieval_start = time.perf_counter()
        retrieved_docs = self.retriever.retrieve_documents(query_text)
        retrieval_time = time.perf_counter() - retrieval_start
        logger.info(f"Retrieval took: {retrieval_time:.4f} seconds")

        generation_start = time.perf_counter()
        answer = self.generator.generate_answer(query_text, retrieved_docs)
        generation_time = time.perf_counter() - generation_start
        logger.info(f"Generation took: {generation_time:.4f} seconds")

        total_time = time.perf_counter() - pipeline_start
        logger.info(f"Total pipeline time: {total_time:.4f} seconds")

        return {
            "query": query_text,
            "retrieved_documents": retrieved_docs,
            "answer": answer,
            "metrics": {
                "retrieval_time_sec": retrieval_time,
                "generation_time_sec": generation_time,
                "total_pipeline_time_sec": total_time,
                "num_retrieved_docs": len(retrieved_docs)
            }
        }

In [None]:
# --- 5. Test the RAG System ---
if __name__ == "__main__":
    logger.info("Starting RAG system initialization...")
    # Defined before the try so the results gathered so far stay inspectable in
    # later cells even if a query or the evaluation fails part-way.
    query_results = []
    try:
        setup_environment()  # best-effort install of the pipeline's dependencies

        config = Config()

        # Fail fast with a clear message. Raising instead of calling exit() keeps
        # the Jupyter kernel alive; the handler below logs the error.
        if not config.GEMINI_API_KEY:
            raise RuntimeError("GEMINI_API_KEY environment variable not set. Please set it to run the Gemini-powered generator.")

        rag_pipeline = RAGPipeline(config)
        logger.info("RAG system initialized successfully.")

        print("\n--- Testing RAG System ---")

        queries = [
            "What is the capital of France?",
            "Tell me about monuments in Berlin.",
            "What is deep learning?",
            "Which programming language is good for AI?",
            "Where is the Eiffel Tower?",
            "Who won the World Series in 2023?"  # out-of-domain query
        ]

        for i, query_text in enumerate(queries, start=1):
            print(f"\nQuery {i}: {query_text}")

            result = rag_pipeline.query(query_text)
            query_results.append(result)  # kept for the Ragas evaluation

            print("Retrieved Documents:")
            if result["retrieved_documents"]:
                for j, doc in enumerate(result["retrieved_documents"], start=1):
                    print(f"  {j}. {textwrap.fill(doc, width=80)}")
            else:
                print("  No relevant documents retrieved.")

            print(f"Generated Answer: {textwrap.fill(result['answer'], width=80)}")
            print(f"Metrics: {result['metrics']}")
            print("-" * 50)

        # --- Ragas Evaluation ---
        print("\n--- Ragas Evaluation ---")
        # The Ragas helpers are defined in cells further down this notebook, so
        # on a top-to-bottom "Run All" they do not exist yet at this point.
        prepare_fn = globals().get("prepare_ragas_dataset")
        evaluate_fn = globals().get("evaluate_ragas_metrics")
        if prepare_fn is None or evaluate_fn is None:
            logger.warning(
                "Ragas helpers are not defined yet; run the Ragas cells below, then call "
                "evaluate_ragas_metrics(prepare_ragas_dataset(query_results))."
            )
        else:
            ragas_dataset = prepare_fn(query_results)
            ragas_scores = evaluate_fn(ragas_dataset)

            print("\nRagas Scores:")
            display(ragas_scores)

    except Exception as e:
        # exc_info keeps the traceback, which a message-only log would drop.
        logger.critical(f"RAG system encountered a critical error during initialization or testing: {e}", exc_info=True)

In [None]:
!pip install -q ragas datasets

In [None]:
from datasets import Dataset

def prepare_ragas_dataset(
    query_results: List[Dict[str, Any]],
    ground_truths: Optional[List[str]] = None,
) -> Dataset:
    """
    Prepares a Dataset object for Ragas evaluation from the RAG pipeline results.

    faithfulness and answer_relevancy only need the question / answer / contexts
    columns (plus an evaluator LLM). Reference-based metrics such as
    context_recall or answer_correctness additionally need a reference answer,
    which can be supplied via `ground_truths`.

    Args:
        query_results: A list of dictionaries, where each dictionary is the output
                       of the RAGPipeline.query() method.
        ground_truths: Optional reference answers, one per entry of
                       `query_results` and in the same order. Stored in a
                       "ground_truth" column.

    Returns:
        A Dataset object with "question", "answer" and "contexts" columns, plus
        "ground_truth" when `ground_truths` is given.

    Raises:
        ValueError: If `ground_truths` does not have one entry per result.
    """
    if ground_truths is not None and len(ground_truths) != len(query_results):
        raise ValueError(
            f"Expected {len(query_results)} ground truths (one per query result), got {len(ground_truths)}."
        )

    data = {
        'question': [result["query"] for result in query_results],
        'answer': [result["answer"] for result in query_results],
        # Ragas expects each row's contexts to be a list of strings.
        'contexts': [list(result["retrieved_documents"]) for result in query_results],
    }
    if ground_truths is not None:
        data['ground_truth'] = list(ground_truths)

    return Dataset.from_dict(data)

In [None]:
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
)

try:
    # `context_relevancy` is only exported by older ragas releases. The
    # unpinned `pip install ragas` may fetch a newer one, where an
    # unconditional import would abort this whole cell.
    from ragas.metrics import context_relevancy
except ImportError:
    context_relevancy = None


def evaluate_ragas_metrics(ragas_dataset: Dataset, llm=None, embeddings=None):
    """
    Evaluates Ragas metrics on a given dataset.

    Args:
        ragas_dataset: A Dataset object compatible with Ragas (see
                       prepare_ragas_dataset).
        llm: Optional evaluator LLM, forwarded to `ragas.evaluate(llm=...)`.
        embeddings: Optional embedding model, forwarded to
                    `ragas.evaluate(embeddings=...)`.

    Returns:
        Whatever `ragas.evaluate` returns (a per-metric score result).
    """
    logger.info("Calculating Ragas metrics...")

    if context_relevancy is None:
        logger.warning("ragas.metrics.context_relevancy is unavailable in the installed ragas version; skipping it.")
    metrics = [
        metric
        for metric in (context_relevancy, answer_relevancy, faithfulness)
        if metric is not None
    ]

    # faithfulness and answer_relevancy are LLM-judged. When no llm/embeddings
    # are given, ragas falls back to its own default providers, which
    # typically need their own credentials (presumably OPENAI_API_KEY; confirm
    # for the installed version). Only forward what the caller supplied.
    evaluate_kwargs = {}
    if llm is not None:
        evaluate_kwargs["llm"] = llm
    if embeddings is not None:
        evaluate_kwargs["embeddings"] = embeddings

    result = evaluate(ragas_dataset, metrics=metrics, **evaluate_kwargs)

    logger.info("Ragas evaluation complete.")
    return result