# Open-Source RAG Implementation

## Using Hugging Face Transformers, Sentence-Transformers, and ChromaDB


In [None]:
!nvidia-smi # Check GPU and CUDA version


In [None]:
# Install the dependencies
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install --quiet transformers sentence-transformers chromadb langchain-community pypdf


## 1. Setup and Imports


In [None]:
import hashlib
import logging
import textwrap

import chromadb
import torch
from chromadb.utils import embedding_functions
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM, pipeline

# Configure root logging at INFO and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run models on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info("Using device: %s", DEVICE)


## 2. Initialize Models and Database


In [None]:
def initialize_models(
    model_name: str = "openai-community/gpt2",
    embedding_model_name: str = "all-MiniLM-L6-v2",
):
    """Load the text-generation pipeline and the sentence-embedding model.

    Args:
        model_name: Hugging Face Hub id of a causal language model. The default
            is the small GPT-2 checkpoint.
        embedding_model_name: sentence-transformers model id for embeddings.

    Returns:
        A ``(generator, embedding_model)`` tuple: a ``"text-generation"``
        pipeline placed on ``DEVICE`` and a ``SentenceTransformer`` instance.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 style tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # AutoModelForCausalLM picks the architecture from the checkpoint's config.
    # Forcing GPTJForCausalLM onto a GPT-2 checkpoint does not work: the
    # parameter names differ, so the pretrained weights are not loaded.
    # float16 only pays off on GPU; many CPU kernels lack half-precision support.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(DEVICE)

    # Create a text generation pipeline. max_new_tokens bounds only the
    # completion; the previous max_length=2048 let generation run past the
    # 1024-token context window of the default GPT-2 checkpoint.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=DEVICE,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Sentence-transformer model for embeddings.
    embedding_model = SentenceTransformer(embedding_model_name)

    return generator, embedding_model


def initialize_chromadb(
    path: str = "/content/chromadb",
    collection_name: str = "docs",
    embedding_model_name: str = "all-MiniLM-L6-v2",
):
    """Open (or create) the persistent ChromaDB collection used for retrieval.

    Args:
        path: directory where ChromaDB persists its data.
        collection_name: name of the collection to get or create.
        embedding_model_name: sentence-transformers model the collection uses
            to embed documents and queries.

    Returns:
        The ChromaDB collection.
    """
    client = chromadb.PersistentClient(path=path)

    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_model_name
    )

    # Cosine space makes `1 - distance` a real cosine similarity; under the
    # default L2 space that value is not a similarity and can go negative.
    # NOTE(review): the metric is set when the collection is first created; an
    # existing collection persisted with another metric may keep it or be
    # rejected by ChromaDB -- delete the persisted directory to rebuild.
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_function,
        metadata={"description": "Documents for RAG", "hnsw:space": "cosine"},
    )

    return collection


# Initialize: load the generator and embedding models, then open the
# persistent ChromaDB collection. GENERATOR and COLLECTION are module-level
# globals read by the cells below.
# NOTE(review): EMBEDDING_MODEL is not referenced in the cells shown; the
# collection embeds with its own SentenceTransformerEmbeddingFunction, so the
# embedding model appears to be loaded twice -- confirm before removing.
logger.info("Initializing models and database...")
GENERATOR, EMBEDDING_MODEL = initialize_models()
COLLECTION = initialize_chromadb()
logger.info("Initialization complete!")


## 3. Enhanced Prompt Template


In [None]:
# Prompt sent to the generator. The {context} and {question} placeholders are
# filled with str.format(), so any literal brace added here must be doubled.
ENHANCED_PROMPT_TEMPLATE = (
    "\n"
    "Context Information:\n"
    "{context}\n"
    "\n"
    "User Question: {question}\n"
    "\n"
    "Instructions: Using the context above, provide a clear and direct answer to the "
    "user's question. If the context doesn't contain enough information, "
    "acknowledge this. Use bullet points when appropriate.\n"
    "\n"
    "Answer:\n"
)


## 4. Utility Functions


In [None]:
def chunk_text(text: str, max_chunk_size: int = 1000) -> list[str]:
    """Split *text* into chunks of at most ``max_chunk_size`` characters.

    Paragraphs (separated by blank lines) are packed greedily into chunks and
    joined by single spaces. A paragraph longer than ``max_chunk_size`` on its
    own is first wrapped at whitespace (over-long words are hard-split), so no
    chunk exceeds the limit. Blank paragraphs are skipped, so empty or
    whitespace-only text yields ``[]``.

    Args:
        text: the text to split.
        max_chunk_size: maximum length of each chunk, in characters.

    Returns:
        The non-empty chunks, in document order.

    Raises:
        ValueError: if ``max_chunk_size`` is not positive.
    """
    if max_chunk_size <= 0:
        raise ValueError(f"max_chunk_size must be positive, got {max_chunk_size}")

    chunks: list[str] = []
    current = ""

    for paragraph in text.split("\n\n"):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        # Extracted PDF text may contain no blank lines, so a whole page can
        # arrive as a single "paragraph"; wrap it so every piece fits a chunk.
        if len(paragraph) <= max_chunk_size:
            pieces = [paragraph]
        else:
            pieces = textwrap.wrap(paragraph, width=max_chunk_size)

        for piece in pieces:
            if not current:
                current = piece
            elif len(current) + 1 + len(piece) <= max_chunk_size:  # +1: joining space
                current = f"{current} {piece}"
            else:
                chunks.append(current)
                current = piece

    if current:
        chunks.append(current)

    return chunks


def read_and_process_documents(directory: str, collection, batch_size: int = 100):
    """Load the PDFs in *directory*, chunk their text and index it in *collection*.

    Chunk ids are the SHA-256 of the chunk text, so identical chunks are
    indexed only once per call.

    Args:
        directory: folder scanned by ``PyPDFDirectoryLoader``.
        collection: ChromaDB collection receiving the chunks; it embeds them
            with its own embedding function.
        batch_size: number of loaded documents chunked and sent per
            ``collection.add`` call.
    """
    logger.info("Reading documents from %s", directory)

    documents = PyPDFDirectoryLoader(directory).load()
    if not documents:
        logger.warning("No PDF content found in %s", directory)
        return

    # Content-hash ids collide for repeated text (headers, footers, blank
    # pages). ChromaDB rejects duplicate ids within one add(), so remember
    # every id already sent and skip repeats.
    seen_ids: set[str] = set()

    for start in range(0, len(documents), batch_size):
        batch_chunks: list[str] = []
        batch_ids: list[str] = []

        for doc in documents[start : start + batch_size]:
            for chunk in chunk_text(doc.page_content):
                if not chunk.strip():
                    continue  # nothing worth embedding
                chunk_id = hashlib.sha256(chunk.encode()).hexdigest()
                if chunk_id in seen_ids:
                    continue
                seen_ids.add(chunk_id)
                batch_chunks.append(chunk)
                batch_ids.append(chunk_id)

        # ChromaDB requires a non-empty id list, e.g. image-only pages yield none.
        if not batch_ids:
            continue

        collection.add(documents=batch_chunks, ids=batch_ids)
        logger.info("Processed and indexed batch of %d chunks", len(batch_chunks))

    logger.info("Document processing complete")


def prepare_context(query_results) -> str:
    """Render the top query's hits as numbered, relevance-labelled excerpts.

    Reads the first (and only) query's ``documents`` and ``distances`` lists
    from a ChromaDB query result, labels each hit with ``1 - distance`` to two
    decimals, and joins the excerpts with blank lines.
    """
    texts = query_results["documents"][0]
    dists = query_results["distances"][0]

    excerpts = [
        f"[Excerpt {rank} (Relevance: {1 - dist:.2f})]\n{text}\n"
        for rank, (text, dist) in enumerate(zip(texts, dists), start=1)
    ]
    return "\n".join(excerpts)


def answer_question(
    question: str, collection, n_results: int = 3, max_new_tokens: int = 256
) -> str:
    """Answer *question* with retrieval-augmented generation.

    Retrieves the ``n_results`` most similar chunks from *collection*, formats
    them into ``ENHANCED_PROMPT_TEMPLATE`` and completes the prompt with the
    global ``GENERATOR`` pipeline.

    Args:
        question: the user's question.
        collection: ChromaDB collection to retrieve context from.
        n_results: number of chunks to retrieve.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The generated answer only (the prompt is not echoed), stripped.
    """
    query_results = collection.query(query_texts=[question], n_results=n_results)

    context = prepare_context(query_results)
    prompt = ENHANCED_PROMPT_TEMPLATE.format(context=context, question=question)

    # return_full_text=False returns only the newly generated text. Splitting
    # the full output on "Answer:" picked the wrong segment whenever the
    # retrieved context or the question itself contained "Answer:".
    # max_new_tokens replaces max_length=2048, which exceeded the default
    # GPT-2 checkpoint's 1024-token context window.
    # NOTE(review): a long context plus max_new_tokens can still overflow a
    # small model's window; lower n_results or the chunk size if that happens.
    response = GENERATOR(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,
    )

    return response[0]["generated_text"].strip()


## 5. Interactive Chat Function


In [None]:
def run_interactive_chat(collection):
    """Answer questions typed by the user until they ask to leave.

    Each non-empty question is passed to ``answer_question(question, collection)``.
    The session ends on ``quit``/``exit``/``bye`` (case-insensitive). It also
    ends cleanly on end-of-input or Ctrl-C at the prompt instead of crashing.
    Errors raised while answering are logged with their traceback, and the
    loop keeps running.

    Args:
        collection: the ChromaDB collection to retrieve context from.
    """
    print("Welcome to the Open Source Chatbot! Type 'quit' to exit.")
    # The model is a local Hugging Face checkpoint, not GPT-J as previously claimed.
    print(
        "Note: This runs a local open-source model, "
        "so responses might take longer than with OpenAI."
    )

    exit_commands = {"quit", "exit", "bye"}

    while True:
        try:
            question = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or kernel interrupted: end the session like "quit".
            question = "quit"

        if question.lower() in exit_commands:
            print("Thank you for using the Chatbot. Goodbye!")
            break

        if not question:
            print("Please ask a question!")
            continue

        print("Thinking...")  # Local models can be slow; show that work started.
        try:
            answer = answer_question(question, collection)
        except Exception:
            # Top-level boundary of the chat loop: keep the session alive but
            # record the full traceback, not just the message.
            logger.exception("Error processing question")
            print(
                "I apologize, but I encountered an error processing your question. Please try again."
            )
        else:
            print("\nAnswer:", answer)


## 6. Main Execution


In [None]:
# Index every PDF found in this directory into COLLECTION, then start chatting.
docs_directory = "/content/docs"  # Update this to your actual directory path
read_and_process_documents(docs_directory, COLLECTION)

# Start interactive chat; runs until the user types quit, exit or bye.
run_interactive_chat(COLLECTION)
