In [2]:
from langchain.text_splitter import RecursiveCharacterTextSplitter # type: ignore
from langchain_core.prompts import PromptTemplate
from langchain_ollama import OllamaLLM

import weaviate

import weaviate.classes as wvc
from sentence_transformers import SentenceTransformer
from weaviate.classes.config import Property, DataType

from weaviate.collections import Collection
from weaviate.collections.classes.config import (
    Property, DataType
)

from enum import Enum
from typing import Dict
import numpy as np
from math import floor
from typing import List, Dict, Optional

from llmlingua import PromptCompressor

from jinja2 import Template

import torch  # local to the setup cell; torch is what sentence_transformers runs on

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also starts on machines without CUDA. Both models share the choice.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True allows the model repository to execute its own Python
# code while loading; only keep it for repositories you trust.
embedding_model = SentenceTransformer(
    "nomic-ai/nomic-embed-text-v1.5",
    trust_remote_code=True,
    device=DEVICE,
)
compressor = PromptCompressor(
    model_name='microsoft/llmlingua-2-xlm-roberta-large-meetingbank',
    use_llmlingua2=True,
    device_map=DEVICE,
)




<All keys matched successfully>


In [3]:
class BooksProcessor:
    """Split books into chunks, embed them and store them in a local Weaviate.

    Every book is stored once per entry of ``CHUNK_CONFIGS``, each time in its
    own collection named ``book_name + suffix``.

    The instance may be used as a context manager; the Weaviate connection is
    then opened on entry and closed on exit::

        with BooksProcessor(embedding_model) as processor:
            processor.process_book(name, text)
    """

    # (collection-name suffix, chunk_size, chunk_overlap) per stored granularity.
    CHUNK_CONFIGS = (
        ('_big_chunks', 3000, 1000),
        ('_medium_chunks', 1500, 500),
        ('_small_chunks', 400, 50),
    )

    def __init__(self, embedding_model):
        """Keep the embedding model; the Weaviate connection is opened lazily."""
        self.embedding_model = embedding_model
        self.wv_client = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the exception raised inside the block

    def connect(self):
        """Open the connection to the local Weaviate instance if none is open."""
        if not self.wv_client:
            self.wv_client = weaviate.connect_to_local()

    def close(self):
        """Close the Weaviate connection, if any, and forget the client.

        The reference is dropped so that a later ``connect()`` opens a fresh
        connection instead of silently reusing the closed client.
        """
        client, self.wv_client = self.wv_client, None
        if client:
            client.close()

    def create_collection_if_not_exists(self, collection_name):
        """Return the collection ``collection_name``, creating it when missing."""
        self.connect()
        if not self.wv_client.collections.exists(collection_name):
            # No vectorizer is configured here: process_book() supplies the
            # vector of every object explicitly.
            self.wv_client.collections.create(
                name=collection_name,
                properties=[
                    Property(name="chunk", data_type=DataType.TEXT),
                    Property(name="book_name", data_type=DataType.TEXT),
                    Property(name="chunk_num", data_type=DataType.INT)
                ],
            )
        return self.wv_client.collections.get(collection_name)

    def split_book(self, book_text, chunk_size, chunk_overlap):
        """Split ``book_text`` and return the chunk texts as a list of strings.

        ``chunk_size`` and ``chunk_overlap`` are forwarded unchanged to
        ``RecursiveCharacterTextSplitter``.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        return [doc.page_content for doc in splitter.create_documents([book_text])]

    def process_book(self, book_name, book_txt):
        """Chunk, embed and store ``book_txt`` under ``book_name``.

        Does nothing except print a notice when the ``_big_chunks`` collection
        of the book already exists. Each stored object carries the chunk text,
        the book name and the chunk's position (``chunk_num``) within the book.
        """
        self.connect()
        if self.wv_client.collections.exists(book_name + '_big_chunks'):
            print("Book already exists")
            return

        for suffix, chunk_size, overlap in self.CHUNK_CONFIGS:
            collection = self.create_collection_if_not_exists(book_name + suffix)
            chunks = self.split_book(book_txt, chunk_size, overlap)
            if not chunks:
                # Empty book: nothing to embed or insert for this granularity.
                continue

            # The 'search_document: ' prefix marks the texts as documents for
            # the embedding model (queries use 'search_query: ').
            embeddings = self.embedding_model.encode(
                ['search_document: ' + chunk for chunk in chunks],
                batch_size=15,
            ).tolist()

            objects = [
                wvc.data.DataObject(
                    properties={
                        "chunk": chunk,
                        "book_name": book_name,
                        "chunk_num": chunk_num,
                    },
                    vector=embedding,
                )
                for chunk_num, (chunk, embedding) in enumerate(zip(chunks, embeddings))
            ]
            collection.data.insert_many(objects)

    def delete_book(self, book_name: str) -> None:
        """
        Delete all collections associated with a book.

        Deletion is best-effort: a failure on one collection is reported and
        the remaining ones are still attempted. The success message is only
        printed when no deletion failed.
        """
        self.connect()
        failed = []
        for suffix, _chunk_size, _overlap in self.CHUNK_CONFIGS:
            collection_name = book_name + suffix
            if self.wv_client.collections.exists(collection_name):
                try:
                    self.wv_client.collections.delete(collection_name)
                except Exception as e:
                    print(f"Error deleting collection {collection_name}: {e}")
                    failed.append(collection_name)
        if not failed:
            print(f"Successfully deleted collections for {book_name}")
        

class ChunkSize(Enum):
    """Chunk granularities; each value is the suffix of the matching collection name."""

    SMALL = "_small_chunks"
    MEDIUM = "_medium_chunks"
    LARGE = "_big_chunks"

class Search:
    """Retrieve the stored chunks of a book that are closest to a query.

    An LLM first classifies the query into a ``ChunkSize``; the result scales
    how many chunks are fetched by vector similarity.

    The instance may be used as a context manager; connections are then opened
    on entry and the Weaviate connection is closed on exit.
    """

    def __init__(self, embedding_model, llm_name='llama3.2'):
        """Store the configuration and load the classifier prompt from disk.

        Neither the Weaviate client nor the LLM is created here; see ``connect``.
        """
        self.embedding_model = embedding_model
        self.llm_name = llm_name
        self.llm = None
        self.wv_client = None
        # Factor applied to ln(number of stored chunks) to get the retrieval limit.
        self.multiplier_mapping = {'_big_chunks': 1, '_medium_chunks': 1.5, '_small_chunks': 3}
        self._load_prompt_template()

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the exception raised inside the block

    def _load_prompt_template(self):
        """Read ``classifier_prompt.j2`` and wrap it in a jinja2 ``PromptTemplate``."""
        # Explicit encoding so the template reads the same on every platform.
        with open('classifier_prompt.j2', encoding='utf-8') as f:
            template = f.read()
        self._prompt_template = PromptTemplate(
            input_variables=["query"],
            template=template,
            template_format="jinja2"
        )

    def connect(self):
        """Create the Weaviate client and the Ollama LLM if they do not exist yet."""
        if not self.wv_client:
            self.wv_client = weaviate.connect_to_local()
        if not self.llm:
            self.llm = OllamaLLM(
                model=self.llm_name,
                temperature=0,
                base_url="http://localhost:11434"
            )

    def close(self):
        """Close the Weaviate connection, if any, and forget the client.

        The reference is dropped so that a later ``connect()`` opens a fresh
        connection instead of silently reusing the closed client.
        """
        client, self.wv_client = self.wv_client, None
        if client:
            client.close()

    def classify_query(self, query):
        """Ask the LLM which chunk size suits ``query`` and return its suffix.

        The reply is matched case-insensitively against the ``ChunkSize``
        member names after surrounding whitespace, quotes, markdown emphasis
        and punctuation are removed. Anything else falls back to
        ``ChunkSize.MEDIUM``.
        """
        self.connect()
        reply = self.llm.invoke(self._prompt_template.format(query=query))
        # Tolerate answers such as "LARGE." or "**large**".
        answer = reply.strip(" \t\r\n\"'`*.,:;!?").upper()
        return ChunkSize.__members__.get(answer, ChunkSize.MEDIUM).value

    def search(self, query, book_name):
        """Return the text of the chunks of ``book_name`` most similar to ``query``.

        The number of chunks is ``floor(multiplier * ln(total chunks))`` and at
        least 1. The chunk texts are stripped and joined with blank lines, in
        the order Weaviate returns them.
        """
        self.connect()
        collection_type = self.classify_query(query)
        print(f'Collection type: {collection_type}')
        # NOTE(review): the classification only scales the retrieval limit; the
        # lookup itself is pinned to the small-chunk collection. Confirm that
        # this is intended before switching to `book_name + collection_type`.
        book = self.wv_client.collections.get(book_name + '_small_chunks')

        total_count = book.aggregate.over_all(total_count=True).total_count
        # Clamp to 1 so an empty collection yields ln(1) = 0 instead of a
        # divide-by-zero warning; the resulting limit is 1 either way.
        scaled = self.multiplier_mapping[collection_type] * np.log(max(total_count or 0, 1))
        chunks_to_retrieve = floor(np.maximum(scaled, 1))
        print(f"Retrieving {chunks_to_retrieve} chunks")

        embedding = self.embedding_model.encode('search_query: ' + query, batch_size=1)
        # Send plain Python floats, as process_book() does for stored vectors.
        query_vector = embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
        response = book.query.near_vector(
            near_vector=query_vector,
            limit=chunks_to_retrieve,
            return_metadata=wvc.query.MetadataQuery(certainty=True),
        )
        relevant_chunks = response.objects
        relevant_text = '\n\n'.join(obj.properties['chunk'].strip() for obj in relevant_chunks)
        print(f'Len of relevant text: {len(relevant_text)}')
        return relevant_text


class RAGSystem:
    """Answer questions about stored books with retrieval, compression and an LLM.

    For every requested book the relevant text is retrieved through ``Search``,
    shortened with the prompt compressor, and the collected contexts are
    rendered into ``final_prompt.j2`` together with the dialogue history and
    the query. The rendered prompt is sent to the Ollama LLM.
    """

    def __init__(self, embedding_model, compressor, llm_name='llama3.2', compression_rate=0.55):
        """Create the LLM client and read the final prompt template from disk.

        Args:
            embedding_model: Model handed to ``Search`` for embedding queries.
            compressor: Object whose ``compress_prompt`` shortens the contexts.
            llm_name: Ollama model used to produce the final answer.
            compression_rate: Passed as ``rate`` to ``compress_prompt``.
        """
        self.embedding_model = embedding_model
        self.compression_rate = compression_rate
        self.compressor = compressor
        self.llm = OllamaLLM(
            model=llm_name,
            temperature=0,
            base_url="http://localhost:11434"
        )
        # Explicit encoding so the template reads the same on every platform.
        with open('final_prompt.j2', encoding='utf-8') as f:
            self._template = f.read()

    def query(self, query: str, book_names: List[str], 
             dialogue_history: Optional[List[Dict[str, str]]] = None) -> str:
        """Answer ``query`` using context retrieved from ``book_names``.

        Args:
            query: The user's question.
            book_names: Names of the stored books to search.
            dialogue_history: Earlier turns, passed to the prompt template
                as-is; ``None`` is treated as an empty history.

        Returns:
            The LLM's answer, or ``"No relevant information found."`` when no
            book yielded any context.
        """
        dialogue_history = dialogue_history or []
        compressed_contexts = []
        searcher = Search(self.embedding_model)
        try:
            for book_name in book_names:
                context = searcher.search(query, book_name)
                if not context:
                    continue
                compressed = self.compressor.compress_prompt(
                    context,
                    rate=self.compression_rate,
                    force_tokens=['\n', '?', '.', '!']
                )['compressed_prompt']
                compressed_contexts.append(f"From {book_name}:\n{compressed}")
        finally:
            # Release the Weaviate connection even when search or compression fails.
            searcher.close()

        if not compressed_contexts:
            return "No relevant information found."

        print(f'Len of compressed context: {len(compressed_contexts)}')
        final_prompt = Template(self._template).render(
            contexts=compressed_contexts,
            dialogue_history=dialogue_history,
            query=query
        )

        return self.llm.invoke(final_prompt)

In [4]:
# Ingest the book into Weaviate. try/finally guarantees that the connection is
# closed even when processing fails part-way.
processor = BooksProcessor(embedding_model)
with open('Sherlock Study in Scarlet.txt', 'r', encoding='utf8') as file:
    text = file.read()
try:
    processor.process_book('Sherlock_Study_in_Scarlet', text)
    # To re-ingest from scratch, delete the stored collections first:
    #processor.delete_book('Sherlock_Study_in_Scarlet')
finally:
    processor.close()

WeaviateConnectionError: Connection to Weaviate failed. Details: Error: All connection attempts failed. 
Is Weaviate running and reachable at http://localhost:8080?

In [None]:
search = Search(embedding_model)
try:
    # Search.search() prepends 'search_query: ' itself, so pass the bare question;
    # adding the prefix here would double it in the embedding and classifier prompt.
    rag_context = search.search(query='What happened in London?', book_name='Sherlock_Study_in_Scarlet')
finally:
    search.close()
rag_context

In [27]:
# Ask one question against a single ingested book and print the generated answer.
rag = RAGSystem(embedding_model, compressor)
response = rag.query(
    "Who is Sherlock Holmes",
    ['Sherlock_Study_in_Scarlet'],
    [],
)
print(response)

  rag = RAGSystem(embedding_model, compressor)


Collection type: _small_chunks
Retrieving 27 chunks
Len of relevant text: 7988
Len of compressed context: 1
Sherlock Holmes is a detective who has "amateur shown talent in detective line" and is expected to attain degree skill. He is described as having "quiet ways habits regular" and is known for his scientific approach to solving cases, often using his powers of observation and deduction.


  response = rag.query(
