# Retrieval-Augmented Generation with Milvus

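This notebook demonstrates how to use Milvus for Retrieval-Augmented Generation (RAG). It crawls the CSUSB ITS website and stores embeddings of the page text in a local Milvus database. It then answers questions with a Groq-hosted LLM that is instructed to rely only on the retrieved context.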

In [None]:
# Import all required libraries for the chatbot
# Operating system interface for file/directory operations
import os
# Milvus database connection and utility functions
from pymilvus import connections, utility
# LangChain component for combining multiple documents into one context
from langchain.chains.combine_documents import create_stuff_documents_chain
# Base document class for storing text and metadata
from langchain.schema import Document
# Template system for creating chat prompts
from langchain_core.prompts import ChatPromptTemplate
# Groq's language model interface
from langchain_groq.chat_models import ChatGroq
# Milvus vector database integration for LangChain
from langchain_milvus import Milvus
# Tool for downloading web pages recursively
from langchain_community.document_loaders import RecursiveUrlLoader
# Library for parsing HTML content
from bs4 import BeautifulSoup
# Tool for splitting text into smaller chunks
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Component for creating document retrieval systems
from langchain.chains import create_retrieval_chain
# Interface for HuggingFace's embedding models
from langchain_huggingface import HuggingFaceEmbeddings

# Configuration shared by every function below.
# Start page of the site that is crawled to build the knowledge base.
WEBSITE_URL: str = 'https://www.csusb.edu/its'
# Local file path used as the Milvus connection URI.
DATABASE_PATH: str = "milvus/jupyter_milvus_vector3.db"
# HuggingFace sentence-transformers model that turns text into vectors.
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L12-v2"

def get_api_key():
    """Return the Groq API key.

    The ``GROQ_API_KEY`` environment variable is used when it is set and
    non-blank. Otherwise the user is prompted interactively via ``getpass``,
    so the secret is not echoed to the terminal or left in notebook output.

    Returns:
        The API key with surrounding whitespace removed. Pasted keys often
        carry a trailing newline or space.
    """
    # Local import keeps this helper self-contained.
    import getpass

    # Prefer a key supplied through the environment (non-interactive runs).
    env_key = os.environ.get("GROQ_API_KEY", "").strip()
    if env_key:
        return env_key
    # Unlike input(), getpass does not display the typed characters.
    return getpass.getpass("Please enter your Groq API key: ").strip()

def setup_vector_database(database_path, collection_name="IT_support"):
    """Connect to the Milvus database and report whether a collection exists.

    Args:
        database_path: Path used as the Milvus connection URI. Its parent
            directory is created first if it does not exist.
        collection_name: Collection to look for. Defaults to ``"IT_support"``,
            the collection name the chatbot uses.

    Returns:
        True if ``collection_name`` exists in the connected database, else False.
    """
    # For a bare filename, dirname() is "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(database_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Register the connection under the "default" alias, which the utility
    # call below uses implicitly.
    connections.connect("default", uri=database_path)

    return utility.has_collection(collection_name)

def create_embedding_model():
    """Build the HuggingFace embedding model named by ``EMBEDDING_MODEL``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance, used both to vectorize the
        stored document chunks and to embed incoming questions.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    return embeddings

def clean_webpage_text(html_content):
    """Extract readable plain text from a page's HTML.

    ``script``, ``style``, ``header``, ``footer`` and ``nav`` elements are
    discarded. If the page has a ``<main>`` element, only its text is kept;
    otherwise the whole document's text is used. Every line is stripped of
    surrounding whitespace and blank lines are dropped.

    Args:
        html_content: Raw HTML markup of the page.

    Returns:
        The cleaned text, with text blocks separated by newlines.
    """
    parsed = BeautifulSoup(html_content, 'html.parser')

    # Remove page chrome and non-content markup before extracting text.
    boilerplate_tags = ['script', 'style', 'header', 'footer', 'nav']
    for node in parsed(boilerplate_tags):
        node.decompose()

    # Prefer the <main> region; fall back to the entire document.
    region = parsed.find('main') or parsed
    raw_text = region.get_text('\n')

    stripped_lines = (line.strip() for line in raw_text.splitlines())
    return '\n'.join(line for line in stripped_lines if line)

def download_website_content(url=None):
    """Crawl a website and return its pages as cleaned-text Documents.

    Args:
        url: Start page for the recursive crawl. Defaults to ``WEBSITE_URL``
            when omitted. The same value is used as ``base_url``, which
            ``prevent_outside`` compares discovered links against, so links
            outside this prefix are not followed.

    Returns:
        A list of ``Document`` objects, one per downloaded page. Each
        ``page_content`` is the output of ``clean_webpage_text`` on the page's
        raw content. Each ``metadata`` is a copy of the loader's metadata for
        that page (e.g. its ``"source"`` URL).
    """
    # Resolved at call time so the module constant stays the single default.
    start_url = WEBSITE_URL if url is None else url

    loader = RecursiveUrlLoader(
        url=start_url,
        # Do not follow links that leave base_url.
        prevent_outside=True,
        base_url=start_url,
    )

    return [
        Document(
            page_content=clean_webpage_text(page.page_content),
            # Copy so the cleaned Document does not share a mutable dict
            # with the raw page returned by the loader.
            metadata=dict(page.metadata),
        )
        for page in loader.load()
    ]

def split_into_chunks(documents):
    """Break documents into overlapping chunks suitable for embedding.

    Uses a recursive character splitter with a target chunk size of 1000
    characters. Consecutive chunks overlap by up to 300 characters, so text
    near a split point is carried into the neighbouring chunk. Separators
    are matched literally, not as regular expressions.

    Args:
        documents: Documents to split.

    Returns:
        The list of chunk Documents produced by the splitter.
    """
    settings = dict(chunk_size=1000, chunk_overlap=300, is_separator_regex=False)
    return RecursiveCharacterTextSplitter(**settings).split_documents(documents)

def create_chatbot_prompt():
    """Build the chat prompt template that keeps answers grounded in context.

    Returns:
        A ``ChatPromptTemplate`` with two messages. The first is a system
        message listing the answering rules. The second is a human message
        with ``{input}`` (the user's question) and ``{context}`` (the retrieved
        documents, filled in by the stuff-documents chain) placeholders.
    """
    instructions = """
    You are an AI assistant that provides answers strictly based on the provided context. Adhere to these guidelines:
     - Only answer questions based on the content within the <context> tags.
     - If the <context> doesn't contain relevant information, respond with: "I don't have enough information to answer this question."
     - Ask for clarification if questions are unclear.
     - Provide specific, concise answers with relevant statistics when available.
     - Don't add external information or make assumptions.
    """
    question_template = "<question>{input}</question>\n\n<context>{context}</context>"

    return ChatPromptTemplate.from_messages(
        [("system", instructions), ("human", question_template)]
    )

def _load_or_build_vector_store():
    """Return the Milvus store for the "IT_support" collection, building it if absent."""
    # Must match the collection name that setup_vector_database() checks for.
    collection_name = "IT_support"
    connection_args = {"uri": DATABASE_PATH}

    # setup_vector_database() creates the DB directory, connects, and reports
    # whether the collection already exists.
    if setup_vector_database(DATABASE_PATH):
        vector_store = Milvus(
            collection_name=collection_name,
            embedding_function=create_embedding_model(),
            connection_args=connection_args,
        )
        print("Loading existing knowledge base...")
        return vector_store

    print("Creating new knowledge base...")
    chunks = split_into_chunks(download_website_content())
    return Milvus.from_documents(
        documents=chunks,
        embedding=create_embedding_model(),
        collection_name=collection_name,
        connection_args=connection_args,
        # Defensive: drop any same-named collection before inserting.
        drop_old=True,
    )


def initialize_chatbot(api_key, model_name="llama-3.3-70b-versatile"):
    """Build the retrieval-augmented QA chain used by the chatbot.

    The "IT_support" collection in the Milvus database at ``DATABASE_PATH``
    is reused when it exists. Otherwise ``WEBSITE_URL`` is crawled, and the
    pages are chunked, embedded and stored in a new collection.

    Args:
        api_key: Groq API key passed to ``ChatGroq``.
        model_name: Groq model identifier. Defaults to
            ``llama-3.3-70b-versatile``, Groq's replacement for the
            decommissioned ``llama-3.1-70b-versatile``.

    Returns:
        A retrieval chain. Its ``invoke({"input": question})`` result holds the
        generated ``"answer"`` and the retrieved ``"context"`` documents.
    """
    model = ChatGroq(
        model=model_name,
        # The prompt requires answers strictly from the retrieved context,
        # so use deterministic decoding rather than high-randomness sampling.
        temperature=0,
        api_key=api_key,
    )

    vector_store = _load_or_build_vector_store()

    # MMR re-ranks a larger candidate set to return k results that are
    # relevant but not redundant. No score_threshold is passed: that option
    # belongs to the "similarity_score_threshold" search type, not "mmr".
    retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 3})

    document_chain = create_stuff_documents_chain(model, create_chatbot_prompt())
    return create_retrieval_chain(retriever, document_chain)

def format_sources(context_documents, max_sources=4):
    """Format the distinct source URLs of retrieved documents as a numbered list.

    URLs come from each document's ``metadata["source"]``. They are
    de-duplicated and listed in the order they first appear in
    ``context_documents``, which for retriever output is its result order.
    Documents with a missing or empty source are skipped.

    Args:
        context_documents: Iterable of documents, each exposing a
            ``metadata`` mapping.
        max_sources: Maximum number of URLs to list (default 4).

    Returns:
        A string to append to an answer: a blank line, a ``Sources:`` header,
        then one ``N. url`` line per source. Returns an empty string when
        there are no sources, so callers never show a dangling header.
    """
    seen = set()
    urls = []
    for doc in context_documents:
        metadata = getattr(doc, "metadata", None) or {}
        url = metadata.get("source")
        # A set makes the duplicate check O(1) instead of rescanning a list.
        if not url or url in seen:
            continue
        seen.add(url)
        urls.append(url)

    # First-seen order already reflects ranking, so no re-sorting is needed.
    top_urls = urls[:max_sources]
    if not top_urls:
        return ""

    numbered = "\n".join(f"{rank}. {url}" for rank, url in enumerate(top_urls, start=1))
    return "\n\nSources:\n" + numbered

def answer_question(chain, question):
    """Answer a question with the retrieval chain, appending source URLs.

    Args:
        chain: Object whose ``invoke({"input": question})`` returns a mapping
            with an ``"answer"`` string and, normally, a ``"context"`` list of
            retrieved documents (as built by ``initialize_chatbot``).
        question: The user's question.

    Returns:
        The stripped answer. The formatted sources are appended unless the
        answer is one of the model's canned non-answers (refusal, greeting,
        "no question"), which are not grounded in any document. Errors are
        returned as user-facing messages rather than raised.
    """
    # Canned replies not backed by retrieved documents; listing sources
    # after them would be misleading.
    non_answer_phrases = (
        "I don't have enough information to answer this question.",
        "You didn't ask a question.",
        "It appears that the question",
        "Hello, it seems like you're looking for something related to the information provided in the context",
    )
    try:
        response = chain.invoke({"input": question})
        answer = response["answer"].strip()

        # Greetings are matched by prefix only. A bare substring test for
        # "Hello" would also drop sources from real answers mentioning it.
        if answer.startswith("Hello") or any(p in answer for p in non_answer_phrases):
            return answer
        return answer + format_sources(response.get("context", []))

    except Exception as e:
        # Top-level boundary for one question: turn failures into messages.
        message = str(e)
        if "503" in message and "Service Unavailable" in message:
            return "I apologize, but the server is currently unavailable. This is a temporary issue. Please wait a few moments and try your question again."
        return f"I encountered an error while processing your question. Error details: {message}"

def main():
    """Run the interactive command-line chatbot.

    Prompts for the Groq API key and builds the retrieval chain. It then
    answers questions until one of these happens:
    - the user types ``exit`` or ``quit`` (case-insensitive, surrounding
      whitespace ignored);
    - the user presses Ctrl+C;
    - input ends (EOF).
    Blank input is ignored.
    """
    print(f"Welcome to the CSUSB ITS Support Chatbot!\nThis bot answers questions about: {WEBSITE_URL}")

    try:
        api_key = get_api_key()
        chain = initialize_chatbot(api_key)

        while True:
            question = input("\nEnter your question (or 'exit' to quit): ").strip()
            if not question:
                # Nothing was asked; prompt again instead of calling the model.
                continue
            if question.lower() in ("exit", "quit"):
                print("Goodbye!")
                break

            answer = answer_question(chain, question)
            print("\nResponse:", answer)

    except (KeyboardInterrupt, EOFError):
        # Ctrl+C, or stdin closed (Ctrl+D / end of piped input): a normal exit.
        print("\nGoodbye!")
    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        print(f"\nAn unexpected error occurred: {str(e)}")
        print("Please restart the program and try again.")


# Start the chatbot only when run as a script, not when imported.
if __name__ == "__main__":
    main()