# RAG Question-Answering System with Groq and Milvus Lite

This notebook implements a RAG (Retrieval-Augmented Generation) system using Groq for LLM inference and Milvus Lite for vector storage. The system processes PDF documents, stores embeddings, and answers questions based on the document content.

In [None]:
!pip install PyPDF2 groq gradio nltk sentence-transformers pymilvus

In [None]:
# Import required libraries
import PyPDF2
import groq
import gradio as gr
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from pymilvus import MilvusClient
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

## Document Processor

This class handles PDF reading and text chunking operations.

In [3]:
class DocumentProcessor:
    """Read PDF files and split their text into sentence-aligned chunks."""

    def __init__(self):
        # sent_tokenize needs the Punkt sentence models. Newer NLTK releases
        # look them up under "punkt_tab" instead of "punkt", so request both
        # to keep tokenization working across NLTK versions.
        for resource in ('punkt', 'punkt_tab'):
            nltk.download(resource, quiet=True)

    def read_pdf(self, file_path: str) -> str:
        """
        Extract the text of every page of a PDF, joined by single spaces.

        Args:
            file_path (str): Path of the PDF file to read

        Returns:
            str: Concatenated text of all pages
        """
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # A page that yields no text is treated as empty so that a falsy
            # result (e.g. None) cannot break the join.
            return ' '.join(page.extract_text() or '' for page in reader.pages)

    @staticmethod
    def _trailing_overlap(sentences: List[str], overlap_size: int,
                          max_sentences: int = 2) -> List[str]:
        """Return the last whole sentences (at most max_sentences) that fit in overlap_size characters."""
        kept = []
        used = 0
        for sentence in reversed(sentences):
            # Sentences are joined with one space, so every sentence after
            # the first one kept costs an extra character.
            needed = len(sentence) + (1 if kept else 0)
            if len(kept) >= max_sentences or used + needed > overlap_size:
                break
            kept.append(sentence)
            used += needed
        kept.reverse()
        return kept

    @staticmethod
    def _pack_sentences(sentences: List[str], chunk_size: int,
                        overlap_size: int) -> List[str]:
        """Greedily pack sentences into chunks of about chunk_size characters with trailing-sentence overlap."""
        chunks = []
        current_chunk = []
        current_length = 0

        for sentence in sentences:
            sentence = sentence.strip()
            sentence_length = len(sentence)

            # Close the running chunk when this sentence would overflow it.
            # A sentence longer than chunk_size still gets a chunk of its own
            # because an empty running chunk is never flushed.
            if current_length + sentence_length > chunk_size and current_chunk:
                chunks.append(' '.join(current_chunk))

                # Seed the next chunk with the sentences that end the one
                # just closed. Selecting them by position (rather than by a
                # substring match against the tail) avoids duplicating an
                # earlier sentence whose text merely reappears in the tail.
                current_chunk = DocumentProcessor._trailing_overlap(
                    current_chunk, overlap_size
                ) + [sentence]
                current_length = sum(len(s) for s in current_chunk)
            else:
                current_chunk.append(sentence)
                current_length += sentence_length

        # Flush whatever is left after the last sentence
        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    @staticmethod
    def _drop_small_chunks(chunks: List[str], chunk_size: int) -> List[str]:
        """Drop chunks shorter than chunk_size / 4, unless that would leave nothing."""
        min_length = chunk_size / 4  # Minimum chunk size threshold
        cleaned = [chunk.strip() for chunk in chunks]
        cleaned = [chunk for chunk in cleaned if chunk]

        kept = [chunk for chunk in cleaned if len(chunk) >= min_length]
        # A document shorter than the threshold would otherwise produce no
        # chunks at all; keep what there is instead of returning nothing.
        return kept if kept else cleaned

    def chunk_text(self, text: str, chunk_size: int = 1024, overlap_size: int = 100) -> List[str]:
        """
        Split text into chunks with minimal overlap and clear boundaries.

        Chunks are built from whole sentences. Consecutive chunks share up to
        two trailing sentences, provided those sentences fit within
        overlap_size characters. Chunks shorter than a quarter of chunk_size
        are discarded, except when every chunk is that short, in which case
        all of them are kept so that short documents still yield content.

        Args:
            text (str): Input text to be chunked
            chunk_size (int): Target size for each chunk in characters
            overlap_size (int): Number of characters to overlap between chunks

        Returns:
            List[str]: List of text chunks
        """
        if not text:
            return []

        # Collapse every run of whitespace to one space before tokenizing
        normalized = ' '.join(text.split())
        sentences = sent_tokenize(normalized)

        chunks = self._pack_sentences(sentences, chunk_size, overlap_size)
        return self._drop_small_chunks(chunks, chunk_size)

    def get_chunk_stats(self, chunks: List[str]) -> Dict:
        """
        Get statistics about the chunks for validation.

        Args:
            chunks (List[str]): List of text chunks

        Returns:
            Dict: Chunk count plus average, minimum and maximum chunk length
                in characters (all zero when chunks is empty)
        """
        if not chunks:
            return {
                "num_chunks": 0,
                "avg_chunk_size": 0,
                "min_chunk_size": 0,
                "max_chunk_size": 0
            }

        chunk_sizes = [len(chunk) for chunk in chunks]
        return {
            "num_chunks": len(chunks),
            "avg_chunk_size": sum(chunk_sizes) / len(chunks),
            "min_chunk_size": min(chunk_sizes),
            "max_chunk_size": max(chunk_sizes)
        }

## RAG System

This class implements the main RAG system functionality, including vector storage, embedding generation, and question answering.

In [4]:
from pymilvus import CollectionSchema, FieldSchema, DataType


In [5]:
class RAGSystem:
    """Retrieval-augmented question answering over a single PDF document.

    Chunks of the document are embedded with a SentenceTransformer model and
    stored in a Milvus collection; questions are answered by a Groq chat
    model that is given the most similar chunks as context.
    """

    def __init__(self):
        """Create the embedding model, vector store and Groq client.

        Raises:
            Exception: Any initialization failure is printed and re-raised.
        """
        # Local import: only this method needs the environment lookup.
        import os

        try:
            # NOTE(review): trust_remote_code=True is passed through to the
            # model loader — confirm it is actually required for this model.
            self.embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5", trust_remote_code=True)
            self.embedding_dim = 384  # must match the "embedding" field dim in the schema
            self.doc_processor = DocumentProcessor()
            self.client = MilvusClient('milvus_demo.db') #This is the database name
            self.collection_name = "documents"
            self.setup_vector_store()
            # Prefer a key supplied through the environment so that it does
            # not have to be written into the source; the placeholder is kept
            # as the fallback for backward compatibility.
            self.groq_client = groq.Client(
                api_key=os.environ.get("GROQ_API_KEY", "INSERT API KEY HERE")
            )
            self.document_loaded = False
            self.chat_history = []
            self.initialize_chat()

            # Add storage for last retrieved chunks
            self.last_retrieved_chunks = []

        except Exception as e:
            print(f"Initialization error: {e}")
            raise

    def initialize_chat(self):
        """Initialize chat history with system message"""
        self.chat_history = [{
            "role": "system",
            "content": """You are a helpful AI assistant specialized in answering questions based on provided document contexts.
            Use the following contexts to answer the question accurately. If you cannot find the relevant information in the contexts,
            say so honestly. Base your answer solely on the provided contexts while maintaining a natural conversational tone."""
        }]

    def setup_vector_store(self):
        """Set up the vector store with proper schema and index.

        Any existing collection with the configured name is dropped first, so
        each RAGSystem instance starts from an empty collection.

        Raises:
            Exception: Setup failures are printed and re-raised.
        """
        try:
            # Drop existing collection if it exists; a failure here is only
            # reported because creation below may still succeed.
            try:
                if self.collection_name in self.client.list_collections():
                    self.client.drop_collection(self.collection_name)
            except Exception as e:
                print(f"Warning during collection check/drop: {e}")

            # Define collection schema
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
            ]
            schema = CollectionSchema(fields=fields)

            # Create collection
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema
            )

            # Create index for similarity search
            index_params = {
                "field_name": "embedding",
                "index_type": "HNSW",
                "metric_type": "COSINE",
                "params": {
                    "M": 8,
                    "efConstruction": 64
                }
            }

            self.client.create_index(
                collection_name=self.collection_name,
                index_params=[index_params]
            )

            # Load collection for search
            self.client.load_collection(self.collection_name)

            print(f"Successfully set up vector store: {self.collection_name}")

        except Exception as e:
            print(f"Vector store setup error: {e}")
            raise

    def retrieve_relevant_chunks(self, query: str, top_k: int = 3) -> List[str]:
        """Retrieve relevant chunks using similarity search.

        Args:
            query (str): Question to search for
            top_k (int): Maximum number of chunks to return

        Returns:
            List[str]: Matching chunk texts; empty when nothing is found or
                when the search itself fails (the error is printed).

        Raises:
            ValueError: If no document is loaded or the query is blank.
        """
        if not self.document_loaded:
            raise ValueError("No document has been loaded yet")

        if not query.strip():
            raise ValueError("Query cannot be empty")

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query)

            # Search in Milvus
            # NOTE(review): "nprobe" is passed although the index above is
            # declared as HNSW — confirm this search parameter is intended.
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10}
            }

            results = self.client.search(
                collection_name=self.collection_name,
                data=[query_embedding.tolist()],
                limit=top_k,
                output_fields=["text"],
                search_params=search_params
            )

            if not results or not results[0]:
                print("No search results found")
                return []

            # One query vector was sent, so only results[0] is relevant.
            relevant_chunks = []
            for hit in results[0]:
                if isinstance(hit, dict) and "entity" in hit:
                    chunk = hit["entity"].get("text", "")
                    if isinstance(chunk, str) and chunk.strip():
                        relevant_chunks.append(chunk)

            return relevant_chunks

        except Exception as e:
            print(f"Error in retrieve_relevant_chunks: {e}")
            return []

    def process_and_store_document(self, file_path: str) -> int:
        """Process document and store chunks in vector store.

        Replaces any previously stored document and resets the conversation
        state that referred to it.

        Args:
            file_path (str): Path of the PDF to ingest

        Returns:
            int: Number of chunks inserted

        Raises:
            ValueError: If no text, no chunks or no valid entities result.
            Exception: Any other failure is printed and re-raised.
        """
        try:
            # Reset state tied to the previously loaded document
            self.document_loaded = False
            self.clear_history()
            # Chunks retrieved for the old document must not be shown as if
            # they belonged to the new one.
            self.last_retrieved_chunks = []

            # Process document
            text = self.doc_processor.read_pdf(file_path)
            if not text.strip():
                raise ValueError("No text extracted from PDF")

            chunks = self.doc_processor.chunk_text(text)
            if not chunks:
                raise ValueError("No chunks created from document")

            # Generate embeddings
            embeddings = self.embedding_model.encode(chunks)

            # Clear existing data; failure is only reported (best effort)
            try:
                self.client.delete(
                    collection_name=self.collection_name,
                    filter="id >= 0"
                )
            except Exception as e:
                print(f"Warning during collection clearing: {e}")

            # Prepare entities; the chunk's position doubles as its primary key
            entities = []
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                if chunk.strip():  # Only add non-empty chunks
                    entities.append({
                        "id": i,
                        "text": chunk,
                        "embedding": embedding.tolist()
                    })

            # Insert data
            if entities:
                self.client.insert(
                    collection_name=self.collection_name,
                    data=entities
                )
                self.document_loaded = True
                return len(entities)
            else:
                raise ValueError("No valid entities to insert")

        except Exception as e:
            print(f"Document processing error: {e}")
            self.document_loaded = False
            raise

    def generate_response(self, query: str, contexts: List[str]) -> str:
        """Generate response using Groq with retrieved contexts and chat history.

        The persistent chat history records only the bare question and the
        answer; the retrieved contexts are attached solely to the copy of the
        latest user message sent with this request, so they are not re-sent
        on every later turn.

        Args:
            query (str): The user's question
            contexts (List[str]): Retrieved chunk texts to ground the answer

        Returns:
            str: The model's answer, or an error message if the request fails
                (the error message is also recorded in the chat history).
        """
        try:
            # Format context for better readability
            formatted_context = "\n\n".join(
                f"Context {i}:\n{ctx}" for i, ctx in enumerate(contexts, 1)
            )

            # Add user query to chat history
            self.chat_history.append({
                "role": "user",
                "content": query
            })

            augmented_prompt = f"""Using the following contexts and our conversation history, please answer this question: {query}

Contexts:
{formatted_context}"""

            # Build the request from the earlier history plus a *new* dict for
            # the context-augmented question. Editing the last element of a
            # shallow copy would modify the dict stored in chat_history too.
            messages = self.chat_history[:-1] + [{
                "role": "user",
                "content": augmented_prompt
            }]

            # Generate response using Groq
            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model="llama3-70b-8192",
                temperature=0.7,
                max_tokens=1000,
                top_p=0.9
            )

            response = chat_completion.choices[0].message.content

            # Add assistant's response to chat history
            self.chat_history.append({
                "role": "assistant",
                "content": response
            })

            return response

        except Exception as e:
            print(f"Error in generate_response: {e}")
            error_msg = f"Error generating response: {str(e)}"

            # Add error response to chat history
            self.chat_history.append({
                "role": "assistant",
                "content": error_msg
            })

            return error_msg

    def query(self, query_text: str) -> str:
        """Process a query through the RAG pipeline.

        Args:
            query_text (str): The user's question

        Returns:
            str: The generated answer, or a human-readable message when no
                document is loaded, the question is blank, nothing relevant
                is found, or an error occurs.
        """
        try:
            if not self.document_loaded:
                return "Please upload and process a document first."

            if not query_text.strip():
                return "Please enter a valid question."

            # Retrieve relevant chunks using similarity search
            relevant_chunks = self.retrieve_relevant_chunks(query_text, top_k=3)

            # Store the retrieved chunks
            self.last_retrieved_chunks = relevant_chunks

            if not relevant_chunks:
                return "I couldn't find relevant information to answer your question."

            # Generate response using the retrieved chunks
            return self.generate_response(query_text, relevant_chunks)

        except Exception as e:
            print(f"Error in query method: {e}")
            return f"Error processing query: {str(e)}"

    def get_last_retrieved_chunks(self) -> str:
        """Get the chunks retrieved for the last query as numbered text."""
        if not self.last_retrieved_chunks:
            return "No chunks have been retrieved yet. Try asking a question first."

        return "\n".join(
            f"Chunk {i}:\n{chunk}\n"
            for i, chunk in enumerate(self.last_retrieved_chunks, 1)
        )

    def get_all_chunks(self) -> str:
        """Get all document chunks from the database.

        Returns:
            str: Up to 1000 stored chunks, each prefixed with its id, or a
                message explaining why none can be shown.
        """
        if not self.document_loaded:
            return "No document has been loaded yet. Please upload and process a document first."

        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="id >= 0",
                output_fields=["text", "id"],
                limit=1000
            )

            if not results:
                return "No chunks found in the database."

            chunks = []
            for result in results:
                if isinstance(result, dict) and "text" in result:
                    text = result["text"]
                    # Tolerate a list-wrapped value by taking its first item
                    if isinstance(text, list):
                        text = text[0]
                    if isinstance(text, str) and text.strip():
                        chunk_id = result.get("id", "unknown")
                        chunks.append(f"Chunk {chunk_id}: {text}")

            return "\n\n".join(chunks) if chunks else "No readable chunks found."

        except Exception as e:
            print(f"Error retrieving chunks: {e}")
            return f"Error retrieving chunks: {str(e)}"

    def clear_history(self):
        """Reset conversation history"""
        self.initialize_chat()

## Gradio Interface

Set up the Gradio interface for interactive use of the RAG system.

In [6]:
def create_gradio_interface():
    """Build the Gradio Blocks UI around a fresh RAGSystem instance.

    The UI offers PDF upload and processing, a chat box for questions, a
    Clear button, and two panels showing all stored chunks and the chunks
    retrieved for the last question.

    Returns:
        The assembled gr.Blocks app (not yet launched).
    """
    rag_system = RAGSystem()

    def process_file(file):
        """Ingest the uploaded PDF and return a status message."""
        if file is None:
            return "Please upload a document first."
        try:
            # Accept both an upload object exposing .name and a plain path
            # string, so either form of the upload value works.
            file_path = getattr(file, "name", file)
            num_chunks = rag_system.process_and_store_document(file_path)
            return f"Processed {num_chunks} chunks. Documents are stored in Milvus Lite database."
        except Exception as e:
            return f"Error: {str(e)}"

    def process_query(query, history):
        """Answer a question and append the exchange to the chat display."""
        if not query:
            return "", history

        try:
            # RAGSystem.query itself returns the "upload a document first"
            # message when nothing is loaded, so no separate check is needed.
            response = rag_system.query(query)
            return "", history + [[query, response]]
        except Exception as e:
            return "", history + [[query, f"Error processing query: {str(e)}"]]

    def clear_chat():
        """Empty the chat widgets and the model-side conversation history."""
        # Without this reset the model would keep seeing the conversation the
        # user has just cleared from the screen.
        rag_system.clear_history()
        return None, None

    with gr.Blocks() as demo:
        gr.Markdown("# RAG Question-Answering System with Groq and Milvus Lite")

        with gr.Row():
            file_input = gr.File(label="Upload PDF Document")
            process_button = gr.Button("Process Document")

        output_text = gr.Textbox(label="Processing Status")
        process_button.click(process_file, inputs=[file_input], outputs=[output_text])

        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Enter your question")
        clear = gr.Button("Clear")

        msg.submit(process_query, [msg, chatbot], [msg, chatbot])
        clear.click(clear_chat, None, [msg, chatbot], queue=False)

        with gr.Row():
            col1, col2 = gr.Column(), gr.Column()
            with col1:
                display_all_chunks_button = gr.Button("Display All Document Chunks")
                all_chunks_output = gr.Textbox(label="All Document Chunks", lines=10)
            with col2:
                display_retrieved_chunks_button = gr.Button("Display Last Retrieved Chunks")
                retrieved_chunks_output = gr.Textbox(label="Last Retrieved Chunks", lines=10)

        # Link buttons to their respective functions
        display_all_chunks_button.click(
            rag_system.get_all_chunks,
            outputs=[all_chunks_output]
        )

        display_retrieved_chunks_button.click(
            rag_system.get_last_retrieved_chunks,
            outputs=[retrieved_chunks_output]
        )

    return demo

## Launch the Application

Run this cell to start the Gradio interface.

In [None]:
if __name__ == "__main__":
    # Build the interface first, then start serving it with sharing and
    # debug mode both switched on.
    demo = create_gradio_interface()
    demo.launch(
        share=True,
        debug=True,
    )