In [1]:
import argparse
import os
import shutil
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
from get_embedding_function import get_embedding_function
from langchain.vectorstores.chroma import Chroma
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain.llms import Ollama

In [2]:
# Directory handed to Chroma as persist_directory (and deleted by clear_database).
CHROMA_PATH = "chroma"
# Directory that PyPDFDirectoryLoader reads the source PDFs from.
DATA_PATH = "data"


In [3]:
import argparse
import os
import shutil
from contextlib import contextmanager
import sys

# Import necessary classes and functions

from langchain.prompts import PromptTemplate
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.embeddings import OllamaEmbeddings
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import HumanMessage, AIMessage
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate, ChatPromptTemplate

In [4]:
def main():
    """Run the ingestion pipeline: load documents, split them, index the chunks.

    When the interpreter is not interactive, a ``--reset`` command-line flag is
    honoured first and wipes the existing database. Any exception raised along
    the way is caught and reported on stdout instead of propagating.
    """
    try:
        # sys.ps1 only exists in interactive interpreters; argparse is skipped there.
        running_interactively = hasattr(sys, 'ps1')
        if not running_interactively:
            cli = argparse.ArgumentParser()
            cli.add_argument("--reset", action="store_true", help="Reset the database.")
            if cli.parse_args().reset:
                print("Clearing Database")
                clear_database()

        # load -> split -> index, each stage feeding the next.
        add_to_chroma(split_documents(load_documents()))
    except Exception as e:
        print(f"An error occurred in main: {e}")

def load_documents():
    """Load every PDF found under DATA_PATH.

    Returns:
        The list produced by ``PyPDFDirectoryLoader(DATA_PATH).load()``, or an
        empty list if loading raised (the error is printed, not re-raised).
    """
    try:
        print("Loading documents...")
        loaded = PyPDFDirectoryLoader(DATA_PATH).load()
        print(f"Loaded {len(loaded)} documents")
    except Exception as e:
        print(f"Error loading documents: {e}")
        return []
    else:
        return loaded

def split_documents(documents):
    """Split loaded documents into overlapping text chunks.

    Args:
        documents: The documents to split (as returned by ``load_documents``).

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter`` configured with
        a chunk size of 800 and an overlap of 80 (measured with ``len``), or an
        empty list if splitting raised (the error is printed, not re-raised).
    """
    try:
        print("Splitting documents...")
        splitter_options = {
            "chunk_size": 800,
            "chunk_overlap": 80,
            "length_function": len,
            "is_separator_regex": False,
        }
        splitter = RecursiveCharacterTextSplitter(**splitter_options)
        pieces = splitter.split_documents(documents)
        print(f"Created {len(pieces)} chunks")
    except Exception as e:
        print(f"Error splitting documents: {e}")
        return []
    else:
        return pieces

def add_to_chroma(chunks):
    """Index chunks into the persisted Chroma store, skipping ones already stored.

    Each chunk gets a deterministic ID from ``calculate_chunk_ids``. Chunks whose
    ID is already present in the store are skipped; the rest are added in
    batches, using those same IDs as the store's document IDs.

    Args:
        chunks: Document chunks (objects with a ``metadata`` dict) to index.

    Errors are caught and printed rather than raised.
    """
    try:
        print("Adding chunks to Chroma...")
        db = Chroma(persist_directory=CHROMA_PATH, embedding_function=get_embedding_function())
        chunks_with_ids = calculate_chunk_ids(chunks)

        # include=[] asks only for the IDs (no documents/embeddings).
        existing_items = db.get(include=[])
        existing_ids = set(existing_items['ids'])

        new_chunks = [chunk for chunk in chunks_with_ids if chunk.metadata["id"] not in existing_ids]

        if new_chunks:
            print(f"Adding {len(new_chunks)} new chunks to the database")
            # Add chunks in smaller batches
            batch_size = 100
            for i in range(0, len(new_chunks), batch_size):
                batch = new_chunks[i:i+batch_size]
                # Store each chunk under its computed ID. Without explicit ids the
                # store assigns its own, so the "already exists" check above would
                # never match on a later run and every chunk would be re-added.
                batch_ids = [chunk.metadata["id"] for chunk in batch]
                db.add_documents(batch, ids=batch_ids)
                print(f"Added batch {i//batch_size + 1}")
            # Not every Chroma wrapper exposes persist(); call it only when it
            # exists so its absence does not surface as a swallowed AttributeError.
            persist = getattr(db, "persist", None)
            if callable(persist):
                persist()
        else:
            print("No new chunks to add.")
    except Exception as e:
        print(f"Error adding to Chroma: {e}")

def calculate_chunk_ids(chunks):
    """Tag every chunk with an ID of the form ``"<source>:<page>:<index>"``.

    The index counts consecutive chunks that share the same source and page,
    starting at 0 and restarting whenever the source/page pair changes. The ID
    is written to ``chunk.metadata["id"]`` in place.

    Args:
        chunks: Iterable of objects exposing a ``metadata`` dict.

    Returns:
        The same ``chunks`` object, or an empty list if an error occurred (the
        error is printed, not re-raised).
    """
    try:
        print("Calculating chunk IDs...")
        previous_page_key = None
        index_on_page = 0
        for chunk in chunks:
            meta = chunk.metadata
            page_key = f"{meta.get('source')}:{meta.get('page')}"
            if page_key == previous_page_key:
                index_on_page += 1
            else:
                index_on_page = 0
            meta["id"] = f"{page_key}:{index_on_page}"
            previous_page_key = page_key
        return chunks
    except Exception as e:
        print(f"Error calculating chunk IDs: {e}")
        return []

def clear_database():
    """Delete the CHROMA_PATH directory tree if it exists; otherwise do nothing."""
    if not os.path.exists(CHROMA_PATH):
        return
    shutil.rmtree(CHROMA_PATH)
    print("Database cleared")

def get_embedding_function(model="nomic-embed-text"):
    """Create the Ollama embedding object used for both indexing and retrieval.

    Args:
        model: Name of the Ollama embedding model. Defaults to
            ``"nomic-embed-text"``, the value that was previously hard-coded.

    Returns:
        An ``OllamaEmbeddings`` instance, or ``None`` if construction raised
        (the error is printed, not re-raised).
    """
    try:
        print("Initializing embedding function...")
        return OllamaEmbeddings(model=model)
    except Exception as e:
        print(f"Error initializing embedding function: {e}")
        return None


In [6]:
# Run the ingestion pipeline defined above (load PDFs, split, index into Chroma).
main()

Loading documents...


  from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4


Loaded 10 documents
Splitting documents...
Created 33 chunks
Adding chunks to Chroma...
Initializing embedding function...
Calculating chunk IDs...
Adding 33 new chunks to the database


: 

In [5]:


# Module-level RAG prompt. {context}, {chat_history} and {question} are the
# placeholders PromptTemplate fills in.
# NOTE(review): setup_qa_chain builds its own local template/rag_prompt, so this
# module-level prompt is not the one passed to the chain.
template = """
You are an AI named okaygpt and you are working for Mercedes Benz Buses and Truck.
Answer the question based on the context below and the chat history. If you can't 
answer the question based on the given information, reply 'Sorry, I don't know'.

Context: {context}

Chat History: {chat_history}

Question: {question}

Answer:
"""
rag_prompt = PromptTemplate.from_template(template)

In [6]:
def setup_qa_chain(model_name='llama3', num_ctx=9000, base_url='http://localhost:11434'):
    """Build the conversational retrieval chain over the persisted Chroma store.

    Args:
        model_name: Ollama chat model to answer with. Defaults to ``'llama3'``.
        num_ctx: Value passed to Ollama as ``num_ctx``. Defaults to ``9000``.
        base_url: URL of the Ollama server. Defaults to
            ``'http://localhost:11434'``.

    All defaults are the values that were previously hard-coded, so calling
    ``setup_qa_chain()`` with no arguments behaves as before.

    Returns:
        The object returned by ``ConversationalRetrievalChain.from_llm``,
        configured with conversation memory and to return source documents.
    """
    # Initialize embeddings and model
    embeddings = get_embedding_function()
    model = Ollama(model=model_name, num_ctx=num_ctx, base_url=base_url)

    # Load the Chroma database and expose it as a retriever
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
    retriever = db.as_retriever()

    # Define the RAG prompt template (text kept exactly as before, including
    # its indentation, so the prompt sent to the model is unchanged).
    template = """
    You are an AI named okaygpt and you are working for Mercedes Benz Buses and Truck.
    Answer the question based on the context below and the chat history. If you can't 
    answer the question based on the given information, reply 'Sorry, I don't know'.

    Context: {context}

    Chat History: {chat_history}

    Question: {question}

    Answer:
    """
    rag_prompt = PromptTemplate.from_template(template)

    # Conversation memory; output_key names which chain output ("answer") it
    # records, since the chain also returns source documents.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )

    # Create the conversational retrieval chain
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=model,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        return_generated_question=False,
    )

    return qa_chain

def start_app(qa_chain, max_message=3):
    """Run an interactive question/answer loop on stdin/stdout.

    Each turn reads a question, sends it to the chain as ``{"question": ...}``
    and prints the answer followed by the source of every retrieved document.
    The loop ends when the user types ``done`` (case-insensitive, surrounding
    whitespace ignored) or after ``max_message`` questions have been answered.

    Args:
        qa_chain: The chain to query. Its ``invoke`` method is used when
            available; otherwise the object itself is called.
        max_message: Maximum number of questions answered per session.
            Defaults to 3.
    """
    # Chain.__call__ is deprecated in favour of invoke(); fall back to calling
    # the object directly for chains (or plain callables) without invoke().
    run_chain = getattr(qa_chain, 'invoke', None)
    if not callable(run_chain):
        run_chain = qa_chain

    # range() answers exactly max_message questions; the previous `<=` comparison
    # allowed one more than the stated limit.
    for _ in range(max_message):
        question = input('You: ')
        print(f'User: {question}\n')

        if question.strip().lower() == 'done':
            print('Session ended.')
            # Return here so the "terminated itself" notice below is only shown
            # when the message limit, not the user, ended the session.
            return

        # Get response from the chain
        response = run_chain({"question": question})

        print('AI:', response['answer'])

        # Print sources; tolerate documents without source/page metadata.
        print("\nSources:")
        for doc in response.get('source_documents', []):
            source = doc.metadata.get('source', 'unknown source')
            page = doc.metadata.get('page', '?')
            print(f"- {source} (Page {page})")

    print('AI terminated itself!')

In [7]:
# Build the conversational retrieval chain, then start the interactive question loop.
qa_chain = setup_qa_chain()
start_app(qa_chain)

Initializing embedding function...
User: who wrote information form group 1 and can you give me their numbers also



  warn_deprecated(


AI: Sorry, I don't know. The context is about a project for Mercedes Benz Buses and Trucks, but there's no mention of "Information Form Group 1" or any specific author/number associated with it.

Sources:
- data\Final Report.pdf (Page 6)
- data\Thesis_Premise.pdf (Page 5)
- data\Final Report.pdf (Page 9)
- data\Thesis_Premise.pdf (Page 1)
User: done

Session ended.
AI terminated itself!
