In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter # type: ignore
from langchain_core.prompts import PromptTemplate
from langchain_ollama import OllamaLLM

import weaviate

import weaviate.classes as wvc
from sentence_transformers import SentenceTransformer
from weaviate.classes.config import Property, DataType

from weaviate.collections import Collection
from weaviate.collections.classes.config import (
    Property, DataType
)

from enum import Enum
from typing import Dict
import numpy as np
from math import floor
from typing import List, Dict, Optional
from llmlingua import PromptCompressor
from jinja2 import Template
import dotenv
import os

dotenv.load_dotenv()


def _require_env(var_name: str) -> str:
    """Return a required environment variable, failing fast with a clear message if it is unset/empty."""
    value = os.getenv(var_name)
    if not value:
        raise RuntimeError(f"Environment variable {var_name!r} is not set (check your .env file)")
    return value


llm_name = _require_env("LLM")                       # Ollama model name used by RAGSystem / the graph
prompts_folder = _require_env("PROMPTS_FOLDER")      # must contain final_prompt.j2
embedding_model_path = _require_env("ENCODER_MODEL")
# Device for the sentence-transformer encoder; defaults to CUDA as before.
encoder_device = os.getenv("ENCODER_DEVICE", "cuda")

# NOTE: trust_remote_code=True executes model code fetched with the model -- only use a trusted ENCODER_MODEL.
embedding_model = SentenceTransformer(embedding_model_path, trust_remote_code=True, device=encoder_device)
compressor = PromptCompressor(model_name='microsoft/llmlingua-2-xlm-roberta-large-meetingbank', use_llmlingua2=True)
# Call wv_client.close() when finished with the notebook to release the connection.
wv_client = weaviate.connect_to_local()




<All keys matched successfully>


In [None]:
class BooksProcessor:
    """Chunk, embed and store book texts in Weaviate, one collection per chunk size.

    Collections are named ``<book_name><suffix>`` where ``suffix`` is one of
    ``CHUNK_SUFFIXES`` (only ``_medium_chunks`` is currently enabled in
    ``process_book``). Vectors are computed locally and supplied on insert.
    """

    # Every suffix a book's collections may use; delete_book() cleans up all of them.
    CHUNK_SUFFIXES = ('_big_chunks', '_medium_chunks', '_small_chunks')

    def __init__(self, wv_client, embedding_model):
        """
        Args:
            wv_client: connected Weaviate client.
            embedding_model: encoder with an ``encode(texts, batch_size=...)`` method returning
                an array (a SentenceTransformer in this notebook).
        """
        self.embedding_model = embedding_model
        self.wv_client = wv_client

    def create_collection_if_not_exists(self, collection_name):
        """Create ``collection_name`` with the chunk schema if it is missing and return a handle to it."""
        if not self.wv_client.collections.exists(collection_name):
            self.wv_client.collections.create(
                name=collection_name,
                properties=[
                    Property(name="chunk", data_type=DataType.TEXT),
                    Property(name="book_name", data_type=DataType.TEXT),
                    Property(name="chunk_num", data_type=DataType.INT)
                ],
                # No vectorizer is configured here: vectors are passed explicitly in process_book().
            )
        return self.wv_client.collections.get(collection_name)

    def split_book(self, book_text, chunk_size, chunk_overlap):
        """Split ``book_text`` into chunks targeting ``chunk_size`` characters with ``chunk_overlap`` overlap.

        Returns:
            list[str]: the chunk texts, in document order.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        return [i.page_content for i in splitter.create_documents([book_text])]

    def process_book(self, book_name, book_txt):
        """Index ``book_txt`` under ``book_name`` unless it is already indexed.

        For each enabled chunk configuration: create the collection, split the text,
        embed every chunk with the ``search_document: `` prefix (paired with the
        ``search_query: `` prefix used by Search.search) and insert chunk + vector.

        Raises:
            RuntimeError: if Weaviate reports errors for any inserted object.
        """
        # Only the enabled configuration is checked; extend this if big/small chunks are re-enabled.
        if self.wv_client.collections.exists(book_name + '_medium_chunks'):
            print("Book already exists")
            return
        chunk_configs = [
            # (suffix, chunk_size, chunk_overlap), sizes in characters
            # ('_big_chunks', 3000, 1000),
            ('_medium_chunks', 1000, 100),
            # ('_small_chunks', 400, 50),
        ]

        for suffix, chunk_size, overlap in chunk_configs:
            collection_name = book_name + suffix
            collection = self.create_collection_if_not_exists(collection_name)
            chunks = self.split_book(book_txt, chunk_size, overlap)
            if not chunks:
                print(f"No chunks produced for {collection_name}; nothing to insert")
                continue
            embeddings = self.embedding_model.encode(
                ['search_document: ' + chunk for chunk in chunks], batch_size=15
            ).tolist()
            data_objects = [
                wvc.data.DataObject(
                    properties={
                        "chunk": chunk,
                        "book_name": book_name,
                        "chunk_num": chunk_num
                    },
                    vector=embedding
                )
                for chunk_num, (chunk, embedding) in enumerate(zip(chunks, embeddings))
            ]
            result = collection.data.insert_many(data_objects)
            # insert_many reports per-object failures on the result instead of raising; surface them,
            # otherwise a partially indexed book would later be reported as "already exists".
            if result.has_errors:
                raise RuntimeError(
                    f"{len(result.errors)} of {len(data_objects)} chunks failed to insert into "
                    f"{collection_name}; call delete_book({book_name!r}) before retrying"
                )

    def delete_book(self, book_name: str) -> None:
        """Delete every collection associated with ``book_name`` and report what happened.

        Deletion errors are printed per collection (best effort) rather than raised.
        """
        deleted, failed = [], []
        for suffix in self.CHUNK_SUFFIXES:
            collection_name = book_name + suffix
            if not self.wv_client.collections.exists(collection_name):
                continue
            try:
                self.wv_client.collections.delete(collection_name)
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")
                failed.append(collection_name)
            else:
                deleted.append(collection_name)
        if failed:
            print(f"Failed to delete some collections for {book_name}: {failed}")
        elif deleted:
            print(f"Successfully deleted collections for {book_name}")
        else:
            print(f"No collections found for {book_name}")

class Search:
    """Semantic search over the per-book chunk collections created by BooksProcessor."""

    def __init__(self, wv_client, embedding_model):
        """
        Args:
            wv_client: connected Weaviate client.
            embedding_model: encoder with an ``encode(text, batch_size=...)`` method
                (a SentenceTransformer in this notebook).
        """
        self.embedding_model = embedding_model
        self.wv_client = wv_client
        # Multiplier on log(collection size) that sets how many chunks are retrieved:
        # smaller chunks carry less text each, so more of them are fetched.
        self.multiplier_mapping = {'_big_chunks': 0.7, '_medium_chunks': 1, '_small_chunks': 1.9}

    def process_chunks(self, relevant_chunks):
        """Join retrieved chunk objects into one context string, each preceded by a ``CHUNK <n>`` header.

        Args:
            relevant_chunks: Weaviate result objects with ``chunk_num`` and ``chunk`` properties.

        Returns:
            str: the joined context text (also prints its length).
        """
        # Double quotes outside, single quotes inside: nesting the same quote type in an
        # f-string is a SyntaxError before Python 3.12.
        relevant_text = '\n'.join(
            f"\nCHUNK {obj.properties['chunk_num']}\n" + obj.properties['chunk'].strip()
            for obj in relevant_chunks
        )
        print(f'Len of relevant text: {len(relevant_text)}')
        return relevant_text

    def search(self, query, book_name):
        """Return the chunks of ``book_name`` nearest to ``query``.

        ``query`` must be the raw question: the ``search_query: `` prefix is added here.
        The number of chunks retrieved is ``floor(max(multiplier * log(count), 1))``.

        Returns:
            list: Weaviate result objects (with ``metadata.certainty``).
        """
        collection_type = '_medium_chunks'
        print(f'Collection type: {collection_type}')
        book = self.wv_client.collections.get(book_name + collection_type)

        total_count = book.aggregate.over_all(total_count=True).total_count
        # Clamp before the log: an empty collection would give log(0) = -inf plus a NumPy warning.
        # The clamp yields the same final value (1) without the warning.
        chunks_to_retrieve = floor(np.maximum(self.multiplier_mapping[collection_type] * np.log(max(total_count, 1)), 1))
        print(f"Retrieving {chunks_to_retrieve} chunks from book {book_name}")

        embedding = self.embedding_model.encode('search_query: ' + query, batch_size=1)
        response = book.query.near_vector(near_vector=list(embedding), limit=chunks_to_retrieve, return_metadata=wvc.query.MetadataQuery(certainty=True))
        # Results come back in similarity order, not chunk order.
        return response.objects

    def search_multiple_books(self, query, book_names):
        """Search each book in ``book_names`` and flatten the hits into plain dicts.

        Returns:
            list[dict]: ``{'chunk', 'chunk_num', 'book_name'}`` per hit, grouped by book in input order.
        """
        result = []
        for book_name in book_names:
            chunks = self.search(query, book_name)
            result.extend([{'chunk': i.properties['chunk'].strip(),
                            'chunk_num': i.properties['chunk_num'],
                            'book_name': book_name} for i in chunks])
        return result

class RAGSystem:
    """Answer questions over indexed books: retrieve chunks, compress them with LLMLingua, ask an Ollama LLM."""

    def __init__(self, wv_client, embedding_model, compressor, llm_name, prompts_folder, compression_rate=0.75,
                 ollama_base_url="http://localhost:11434"):
        """
        Args:
            wv_client: connected Weaviate client.
            embedding_model: encoder used to embed queries.
            compressor: LLMLingua ``PromptCompressor``.
            llm_name: Ollama model name.
            prompts_folder: directory containing the ``final_prompt.j2`` Jinja template.
            compression_rate: ``rate`` passed to ``compress_prompt`` (roughly the fraction of tokens kept).
            ollama_base_url: URL of the Ollama server.
        """
        self.embedding_model = embedding_model
        self.searcher = Search(wv_client, self.embedding_model)
        self.compression_rate = compression_rate
        self.compressor = compressor
        self.llm = OllamaLLM(
            model=llm_name,
            temperature=0,
            base_url=ollama_base_url
        )
        with open(os.path.join(prompts_folder, 'final_prompt.j2')) as f:
            self._template = f.read()

    @staticmethod
    def _format_chunks(chunks) -> str:
        """Join Weaviate result objects into one text block, each chunk preceded by a ``CHUNK <n>`` header."""
        return '\n'.join(
            f"\nCHUNK {obj.properties['chunk_num']}\n" + obj.properties['chunk'].strip()
            for obj in chunks
        )

    def query(self, query: str, book_names: List[str],
              dialogue_history: Optional[List[Dict[str, str]]] = None) -> str:
        """Answer ``query`` using context retrieved from ``book_names``.

        Args:
            query: the user's question (raw; Search.search adds the ``search_query: `` prefix).
            book_names: books to search; each must have been indexed by BooksProcessor.
            dialogue_history: prior turns, rendered by the prompt template (format defined by final_prompt.j2).

        Returns:
            The LLM's answer, or ``"No relevant information found."`` when no book yields chunks.
        """
        dialogue_history = dialogue_history or []
        compressed_contexts = []

        for book_name in book_names:
            chunks = self.searcher.search(query, book_name)
            if not chunks:
                continue
            # Search.search returns Weaviate objects; the compressor needs plain text.
            context = self._format_chunks(chunks)
            compressed = self.compressor.compress_prompt(
                context,
                rate=self.compression_rate,
                # Keep line breaks, sentence punctuation and the CHUNK markers so answers can cite chunks.
                force_tokens=['\n', '?', '.', '!', 'CHUNK']
            )['compressed_prompt']
            compressed_contexts.append(f"From {book_name}:\n{compressed}")

        if not compressed_contexts:
            return "No relevant information found."

        print(f'Len of compressed context: {sum(len(i) for i in compressed_contexts)}')
        final_prompt = Template(self._template).render(
            contexts=compressed_contexts,
            dialogue_history=dialogue_history,
            query=query
        )

        return self.llm.invoke(final_prompt)

In [None]:
# Index the book once; process_book() returns early if its collection already exists.
processor = BooksProcessor(wv_client, embedding_model)
with open('Sherlock Study in Scarlet.txt', 'r', encoding='utf8') as book_file:
    text = book_file.read()
processor.process_book(book_name='Sherlock_Study_in_Scarlet', book_txt=text)
# To re-index from scratch, first run: processor.delete_book('Sherlock_Study_in_Scarlet')

In [None]:
# Stand-alone searcher for ad-hoc queries. Search.search() adds the 'search_query: '
# prefix itself, so pass the raw question, e.g.:
#   rag_context = search.search(query='What happened in London?', book_name='Sherlock_Study_in_Scarlet')
search = Search(embedding_model=embedding_model, wv_client=wv_client)

Collection type: _medium_chunks
Retrieving 5 chunks
Len of relevant text: 4644




In [10]:
rag = RAGSystem(
    wv_client,
    embedding_model,
    compressor,
    llm_name=llm_name,
    prompts_folder=prompts_folder,
)

# Spot-check questions about "A Study in Scarlet", each asked without dialogue history.
queries = [
    "What word was written in blood on the wall near Enoch Drebber's body?",
    "What clue did Sherlock Holmes find near the body that aided in the investigation?",
    "What method did Sherlock Holmes use to deduce Dr. Watson's profession and background upon their first meeting?",
    "Who was the actual murderer of Enoch Drebber, and what was the motive?",
    "What roles did Inspectors Gregson and Lestrade from Scotland Yard play in the investigation?"
]

for query in queries:
    response = rag.query(query=query, book_names=['Sherlock_Study_in_Scarlet'], dialogue_history=[])
    print(response)

  rag = RAGSystem(wv_client, embedding_model, compressor, llm_name=llm_name, prompts_folder=prompts_folder)


Collection type: _medium_chunks
Retrieving 5 chunks
Len of relevant text: 4645
Len of compressed context: 3696
According to CHUNK 70, the single word "RACHE" was written in blood-red letters on the wall near Enoch Drebber's body.
Collection type: _medium_chunks
Retrieving 5 chunks
Len of relevant text: 4450
Len of compressed context: 3551
According to CHUNK 64, Sherlock Holmes found numerous gouts and splashes of blood around the body, indicating that there was a second individual involved, presumably the murderer.
Collection type: _medium_chunks
Retrieving 5 chunks
Len of relevant text: 4553
Len of compressed context: 3645
According to CHUNK 12, when introducing himself, Dr. Watson asked how Sherlock Holmes knew that he had been in Afghanistan. Sherlock Holmes replied that he had noticed a scar on Dr. Watson's hand, which suggested that he had been in Afghanistan (CHUNK 2 mentions that Dr. Watson completed his studies attached to the Fifth Northumberland Fusiliers as Assistant Surgeon

In [35]:
from typing_extensions import TypedDict

class GraphState(TypedDict):
    """
    State passed between the nodes of the LangGraph workflow.

    Field meanings are inferred from their names; the nodes that read and write
    them are not implemented yet (TODO: confirm once they are).

    NOTE(review): retrieve()/collect_info() read state["question"] and return a
    "documents" key, neither of which is declared here -- reconcile before
    wiring them into the graph.

    Attributes:
        user_question: the original question asked by the user.
        current_question: the question currently being searched for.
        all_questions: questions asked so far (typed ``str`` -- TODO: should this be ``List[str]``?).
        facts: facts collected from retrieved chunks.
        watched_documents: ids of chunks already examined (presumably ``chunk_num`` values; verify).
    """

    user_question: str
    current_question: str
    all_questions: str
    facts: List[str]
    watched_documents: List[int]

# Shared components for the (work-in-progress) LangGraph workflow below.
book_names = ['Sherlock_Study_in_Scarlet']
retriever = Search(wv_client, embedding_model)
llm = OllamaLLM(model=llm_name, temperature=0, base_url="http://localhost:11434")

  llm = OllamaLLM(


In [34]:
from json import JSONDecodeError, loads

from langchain_ollama import OllamaLLM

# Quick check that a small model can act as a JSON relevance grader.
# A separate name keeps the workflow's `llm` (configured from the LLM env var) from being overwritten.
grader_llm = OllamaLLM(
    model="llama3.2",
    temperature=0,
    base_url="http://localhost:11434"
)

# Input data
question = "agent memory"
doc_txt = "This document discusses agent memory in multi-agent systems."

# One-shot prompt asking for a strict JSON verdict.
prompt = (
    "Evaluate the relevance of the document to the user's question. Respond strictly in JSON format:\n"
    '{"binary_score": "yes"} or {"binary_score": "no"}.\n\n'
    "Example:\n"
    'Question: What animals live in the Arctic?\n'
    'Document: The Arctic is home to polar bears and walruses.\n'
    'Answer: {"binary_score": "yes"}\n\n'
    f"Question: {question}\nDocument: {doc_txt}\n"
)

response = grader_llm.invoke(prompt)
print("Model response:", response)

# Validate the output: it must be a JSON object whose binary_score is "yes" or "no".
try:
    parsed_response = loads(response)
except JSONDecodeError as e:
    print(f"Parsing error: {e}")
else:
    if isinstance(parsed_response, dict) and parsed_response.get("binary_score") in ("yes", "no"):
        print("Parsed successfully:", parsed_response)
    else:
        print("Error: Invalid 'binary_score' value.")

Model response: {"binary_score": "yes"}
Parsed successfully: {'binary_score': 'yes'}


In [None]:
def _retrieve_documents(state):
    """Search every book in ``book_names`` for ``state["question"]`` and return it as a state update."""
    # Pass the raw question: Search.search() already prepends the 'search_query: ' prefix
    # the encoder expects, so adding it here produced a doubled prefix.
    # NOTE(review): GraphState declares neither "question" nor "documents" -- TODO reconcile.
    documents = retriever.search_multiple_books(state["question"], book_names=book_names)
    return {"documents": documents}


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    return _retrieve_documents(state)


def collect_info(state):
    """
    Collect information for the current question (currently the same retrieval as ``retrieve``).

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        dict: ``{"documents": [...]}`` with the retrieved chunk dicts.
    """
    print("---COLLECTING INFO---")
    return _retrieve_documents(state)

In [None]:
from langgraph.graph import END, StateGraph, START

# Nodes/routers the graph needs besides `retrieve`. None of them is defined in this notebook's
# cells yet (TODO: implement them); fail with one clear message instead of a bare NameError.
_required_graph_callables = (
    "grade_documents",
    "generate",
    "transform_query",
    "decide_to_generate",
    "grade_generation_v_documents_and_question",
)
_missing_graph_callables = [name for name in _required_graph_callables if name not in globals()]
if _missing_graph_callables:
    raise NameError(
        "Define these workflow nodes/routers before building the graph: "
        + ", ".join(_missing_graph_callables)
    )

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

# Build graph: retrieve -> grade -> (generate | rewrite the query and retrieve again)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
# "not supported" loops back to generate; only LangGraph's recursion_limit bounds that cycle.
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()