In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter # type: ignore
from langchain_core.prompts import PromptTemplate
from langchain_ollama import OllamaLLM, OllamaEmbeddings

import weaviate
from weaviate.classes.data import DataObject # type: ignore

import weaviate.classes.query as wq
from weaviate.classes.query import Filter
from weaviate.classes.query import Rerank, MetadataQuery
from weaviate.classes.config import Property, DataType

from weaviate.util import generate_uuid5
from weaviate.collections import Collection
from weaviate.collections.classes.config import (
    Property, Configure, DataType, VectorDistances
)

from enum import Enum
from typing import Dict, Tuple
import numpy as np
import json
from math import floor
import gc
import weakref
import asyncio
import time
from functools import lru_cache
from contextlib import suppress
from pathlib import Path
from typing import List, Dict, Optional, Set
from dataclasses import dataclass
import socket

from llmlingua import PromptCompressor

# Host of the Ollama service.
# NOTE(review): none of the cells shown here pass this value on — BooksProcessor,
# Search and RAGConfig each fall back to their own 'localhost' default.
ollama_url = 'localhost'


In [2]:
class BooksProcessor:
    """
    Processor for managing books in a vector database with different chunk sizes.

    Every book is stored in three Weaviate collections, ``<book_name>_big_chunks``,
    ``<book_name>_medium_chunks`` and ``<book_name>_small_chunks``, each split with a
    different chunk size / overlap (see ``_CHUNK_CONFIGS``).

    Uses a singleton pattern: every ``BooksProcessor(...)`` call returns the same
    instance, and the constructor arguments only take effect on the first call.
    """
    _instance = None
    _template = None  # NOTE(review): not read by any method of this class

    # (collection suffix, chunk_size, chunk_overlap) for each stored granularity.
    # Sizes are in the text splitter's length unit (its default is characters).
    _CHUNK_CONFIGS = (
        ('_big_chunks', 3000, 1000),
        ('_medium_chunks', 1500, 500),
        ('_small_chunks', 750, 250),
    )

    def __new__(cls, *args, **kwargs):
        # Singleton: allocate once, then always hand back the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, ollama_url: str = 'localhost',
                 embedding_model_name: str = 'nomic-embed-text',
                 wv_port_rest: int = 8080,
                 wv_port_grpc: int = 50051):
        """
        Initialize the BooksProcessor (only on the first construction).

        Args:
            ollama_url: Host passed to ``weaviate.connect_to_local`` — despite the
                name, it is used as the Weaviate host.
            embedding_model_name: Ollama embedding model configured on newly
                created collections
            wv_port_rest: Weaviate REST API port
            wv_port_grpc: Weaviate gRPC port
        """
        # __new__ always returns the same object, so guard against re-initialising it.
        if not hasattr(self, '_initialized'):
            self.embedding_model_name = embedding_model_name
            self.ollama_url = ollama_url
            self.wv_port_rest = wv_port_rest
            self.wv_port_grpc = wv_port_grpc
            self.wv_client = None
            self._transports = weakref.WeakSet()
            self._initialized = True

    def _track_transport(self, transport) -> None:
        """Add transport to tracking set for cleanup (ignored if it has no ``close``)."""
        if transport and hasattr(transport, 'close'):
            self._transports.add(transport)

    def _cleanup_transports(self) -> None:
        """Close every tracked transport that is not already closing, then forget them."""
        for transport in list(self._transports):
            with suppress(Exception):
                is_closing = getattr(transport, 'is_closing', None)
                # Transports without is_closing() are closed unconditionally.
                if is_closing is None or not is_closing():
                    transport.close()
        self._transports.clear()
        gc.collect()

    def _ensure_weaviate(self) -> None:
        """Open the Weaviate connection if needed and register its transports for cleanup."""
        if self.wv_client is None:
            self.wv_client = weaviate.connect_to_local(
                host=self.ollama_url,
                port=self.wv_port_rest,
                grpc_port=self.wv_port_grpc,
            )

            # The transport attribute name differs between client internals; track both.
            if hasattr(self.wv_client, '_connection'):
                if hasattr(self.wv_client._connection, 'transport'):
                    self._track_transport(self.wv_client._connection.transport)
                if hasattr(self.wv_client._connection, '_transport'):
                    self._track_transport(self.wv_client._connection._transport)

    def close(self) -> None:
        """
        Close all connections and clean up resources.

        Tracked transports are cleaned up even if closing the Weaviate client
        raises; that error is re-raised afterwards.
        """
        try:
            if self.wv_client is not None:
                try:
                    self.wv_client.close()
                finally:
                    self.wv_client = None
        finally:
            self._cleanup_transports()
            time.sleep(0.2)  # Allow time for connections to close
            gc.collect()

    def __enter__(self):
        """Context manager entry: connect to Weaviate."""
        self._ensure_weaviate()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close everything (exceptions are not suppressed)."""
        self.close()

    def create_collection_if_not_exists(self, collection_name: str) -> Collection:
        """
        Create a new collection if it doesn't exist, or get the existing one.

        New collections get the properties ``chunk`` (text), ``book_name`` (text)
        and ``chunk_num`` (int), plus a named vector ``book_vectorizer`` that embeds
        the ``chunk`` property with Ollama using cosine distance. Collections that
        already exist are returned unchanged (their configuration is not updated).

        Args:
            collection_name: Name for the collection

        Returns:
            Weaviate collection object

        Raises:
            Exception: Any Weaviate error, after printing it.
        """
        self._ensure_weaviate()

        try:
            if self.wv_client.collections.exists(collection_name):
                print(f"Getting '{collection_name}'")
            else:
                print(f"Creating '{collection_name}'")
                self.wv_client.collections.create(
                    name=collection_name,
                    properties=[
                        Property(name="chunk", data_type=DataType.TEXT),
                        Property(name="book_name", data_type=DataType.TEXT),
                        Property(name="chunk_num", data_type=DataType.INT)
                    ],
                    vectorizer_config=[
                        Configure.NamedVectors.text2vec_ollama(
                            name="book_vectorizer",
                            # Must name an existing property: the chunk text is stored in "chunk".
                            source_properties=["chunk"],
                            # Resolved by the Weaviate server, not by this process
                            # (presumably a container/service host name — confirm).
                            api_endpoint="http://ollama:11434",
                            model=self.embedding_model_name,
                            vector_index_config=Configure.VectorIndex.hnsw(
                                distance_metric=VectorDistances.COSINE
                            )
                        )
                    ]
                )
            return self.wv_client.collections.get(collection_name)

        except Exception as e:
            print(f"Error in create_collection_if_not_exists: {e}")
            raise

    def split_book(self, book_text: str, chunk_size: int, chunk_overlap: int) -> List:
        """
        Split book text into overlapping chunks.

        Args:
            book_text: Full text of the book
            chunk_size: Maximum size of each chunk
            chunk_overlap: Overlap between consecutive chunks

        Returns:
            List of LangChain documents, one per chunk, in text order
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        return splitter.create_documents([book_text])

    def send_to_db(self, collection: Collection, chunks: List, book_name: str) -> None:
        """
        Insert chunks into a collection using fixed-size batches of 10.

        Objects the server rejects are reported with a printed warning instead of
        being dropped silently.

        Args:
            collection: Weaviate collection
            chunks: Document chunks (objects with a ``page_content`` attribute)
            book_name: Name of the book, stored on every object
        """
        with collection.batch.fixed_size(batch_size=10) as batch:
            for chunk_num, doc in enumerate(chunks):
                batch.add_object({
                    "chunk": doc.page_content,
                    "book_name": book_name,
                    # Position in the book, so results can be put back in reading order.
                    "chunk_num": int(chunk_num),
                })

        # Per-object batch failures do not raise; inspect them explicitly.
        failed = getattr(collection.batch, 'failed_objects', None) or []
        if failed:
            first = failed[0]
            print(f"Warning: {len(failed)} of {len(chunks)} chunks failed to insert "
                  f"into '{collection.name}': {getattr(first, 'message', first)}")

    def process_book(self, book_name: str, book_txt: str) -> None:
        """
        Split a book at every granularity in ``_CHUNK_CONFIGS`` and store it.

        The book is skipped if ``<book_name>_big_chunks`` already exists; the
        other two collections are not checked.

        Args:
            book_name: Name of the book (used as the collection name prefix)
            book_txt: Full text of the book

        Raises:
            Exception: Any error while creating collections or inserting, after printing it.
        """
        self._ensure_weaviate()

        try:
            if self.wv_client.collections.exists(book_name + '_big_chunks'):
                print("Book already exists")
                return

            print("Processing book")

            for suffix, chunk_size, overlap in self._CHUNK_CONFIGS:
                collection = self.create_collection_if_not_exists(book_name + suffix)
                chunks = self.split_book(book_txt, chunk_size, overlap)
                self.send_to_db(collection, chunks, book_name)
                gc.collect()  # Free the chunk list before the next granularity

            print("Book successfully processed")

        except Exception as e:
            print(f"Error in process_book: {e}")
            raise

    def delete_book(self, book_name: str) -> None:
        """
        Delete all collections associated with a book.

        Each deletion is attempted independently; failures are printed, and the
        final message reports whether every deletion succeeded.

        Args:
            book_name: Name of the book to delete
        """
        self._ensure_weaviate()

        failed = []
        for suffix, _, _ in self._CHUNK_CONFIGS:
            collection_name = book_name + suffix
            try:
                self.wv_client.collections.delete(collection_name)
            except Exception as e:
                print(f"Error deleting collection {book_name}{suffix}: {e}")
                failed.append(collection_name)

        if failed:
            print(f"Failed to delete {len(failed)} collection(s) for {book_name}: "
                  f"{', '.join(failed)}")
        else:
            print(f"Successfully deleted collections for {book_name}")


# Chunk granularities a query can be routed to. Member names are the labels the
# classifier LLM is expected to answer with; values are collection-name suffixes.
ChunkSize = Enum(
    'ChunkSize',
    [
        ('SMALL', '_small_chunks'),
        ('MEDIUM', '_medium_chunks'),
        ('LARGE', '_big_chunks'),
    ],
)

class Search:
    """
    Search class for querying book content with proper resource management.

    A query is first classified by the LLM (prompt from ``classifier_prompt.j2``)
    into a ``ChunkSize``; a hybrid search then runs over the matching
    ``<book_name><suffix>`` collection.

    Uses a singleton pattern: every ``Search(...)`` call returns the same
    instance, and the constructor arguments only take effect on the first call.
    """
    _instance = None
    _prompt_template = None  # shared classifier PromptTemplate, loaded lazily

    def __new__(cls, *args, **kwargs):
        # Singleton: allocate once, then always hand back the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, ollama_url: str = 'localhost', llm_name: str = 'Llama3.2',
                 wv_port_rest: int = 8080, wv_port_grpc: int = 50051):
        """
        Initialize the Search instance (only on the first construction).

        Args:
            ollama_url: Host of the Ollama service (port 11434 is appended);
                also used as the Weaviate host
            llm_name: Name of the Ollama model used to classify queries
            wv_port_rest: Weaviate REST API port
            wv_port_grpc: Weaviate gRPC port
        """
        if not hasattr(self, '_initialized'):
            self.ollama_url = ollama_url
            self.llm_name = llm_name
            self.wv_port_rest = wv_port_rest
            self.wv_port_grpc = wv_port_grpc
            self.llm = None
            self.wv_client = None
            self._transports = weakref.WeakSet()
            self._initialized = True
            self._load_prompt_template()

    @staticmethod
    @lru_cache(maxsize=1)
    def _load_prompt_template() -> None:
        """Read ``classifier_prompt.j2`` into ``Search._prompt_template`` (at most once)."""
        if Search._prompt_template is None:
            with open('classifier_prompt.j2') as f:
                template = f.read()
            Search._prompt_template = PromptTemplate(
                input_variables=["query"],
                template=template,
                template_format="jinja2"
            )

    @staticmethod
    def _match_chunk_label(response: str, labels) -> Optional[str]:
        """
        Return the first word of ``response`` that is one of ``labels``.

        Words are the runs of letters in ``response``, compared upper-cased, so
        replies such as ``"Small."`` or ``'Answer: "LARGE"'`` still match.

        Args:
            response: LLM reply text.
            labels: Iterable of upper-case label names (e.g. enum member names).

        Returns:
            The matching label, or None when no word matches.
        """
        words = ''.join(ch if ch.isalpha() else ' ' for ch in response).upper().split()
        wanted = set(labels)
        return next((word for word in words if word in wanted), None)

    @staticmethod
    def _chunks_to_retrieve(total_count) -> int:
        """
        Number of chunks to fetch from a collection holding ``total_count`` objects.

        Grows logarithmically: ``max(1, floor(1.5 * ln(total_count)))``. Empty or
        unknown counts (0 / None) give 1 instead of evaluating ``log(0)``.
        """
        if not total_count or total_count < 1:
            return 1
        return floor(np.maximum(1.5 * np.log(total_count), 1))

    def _track_transport(self, transport) -> None:
        """Add transport to tracking set for cleanup (ignored if it has no ``close``)."""
        if transport and hasattr(transport, 'close'):
            self._transports.add(transport)

    def _cleanup_transports(self) -> None:
        """Close every tracked transport that is not already closing, then forget them."""
        for transport in list(self._transports):
            with suppress(Exception):
                is_closing = getattr(transport, 'is_closing', None)
                # Transports without is_closing() are closed unconditionally.
                if is_closing is None or not is_closing():
                    transport.close()
        self._transports.clear()
        gc.collect()

    def _ensure_llm(self) -> None:
        """Create the Ollama LLM client if needed (temperature 0 for stable labels)."""
        if self.llm is None:
            self.llm = OllamaLLM(
                model=self.llm_name,
                temperature=0,
                base_url=f"http://{self.ollama_url}:11434"
            )
            # Track LLM client transport if available
            if hasattr(self.llm, 'client') and hasattr(self.llm.client, '_transport'):
                self._track_transport(self.llm.client._transport)

    def _ensure_weaviate(self) -> None:
        """Open the Weaviate connection if needed and register its transports for cleanup."""
        if self.wv_client is None:
            self.wv_client = weaviate.connect_to_local(
                host=self.ollama_url,
                port=self.wv_port_rest,
                grpc_port=self.wv_port_grpc,
            )
            # The transport attribute name differs between client internals; track both.
            if hasattr(self.wv_client, '_connection'):
                if hasattr(self.wv_client._connection, 'transport'):
                    self._track_transport(self.wv_client._connection.transport)
                if hasattr(self.wv_client._connection, '_transport'):
                    self._track_transport(self.wv_client._connection._transport)

    def _close_llm(self) -> None:
        """Close the LLM's ``client`` if it exposes ``close()``, then drop the LLM."""
        llm, self.llm = self.llm, None
        client = getattr(llm, 'client', None)
        if client is not None and hasattr(client, 'close'):
            client.close()

    def _close_weaviate(self) -> None:
        """Close and drop the Weaviate client, if any."""
        client, self.wv_client = self.wv_client, None
        if client is not None:
            client.close()

    def close(self) -> None:
        """
        Close all connections and clean up resources.

        Every step runs even if an earlier one raises (the Weaviate client is
        still closed when closing the LLM client fails); the error is re-raised
        afterwards.
        """
        try:
            self._close_llm()
        finally:
            try:
                self._close_weaviate()
            finally:
                self._cleanup_transports()
                time.sleep(0.2)  # Allow time for connections to close
                gc.collect()

    def __enter__(self):
        """Context manager entry: make sure both the LLM and Weaviate are connected."""
        self._ensure_llm()
        self._ensure_weaviate()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close everything (exceptions are not suppressed)."""
        self.close()

    def classify_query(self, query: str) -> str:
        """
        Classify the query to determine the appropriate chunk size.

        The LLM reply is matched against the ``ChunkSize`` member names,
        tolerating case, punctuation and surrounding words; an unrecognised
        reply falls back to MEDIUM.

        Args:
            query: The search query

        Returns:
            The collection suffix of the chosen ``ChunkSize``
        """
        self._ensure_llm()
        prompt = self._prompt_template.format(query=query)
        response = self.llm.invoke(prompt).strip().upper()
        print(f'Chunk size chosen: {response}')
        label = self._match_chunk_label(response, ChunkSize.__members__)
        if label is None:
            return ChunkSize.MEDIUM.value
        return ChunkSize[label].value

    def search(self, query: str, book_name: str) -> str:
        """
        Search for relevant content in the book.

        Args:
            query: The search query
            book_name: Name of the book to search in (collection name prefix)

        Returns:
            The retrieved chunks in book order, joined by blank lines
            (empty string if nothing was found)

        Raises:
            Exception: Any LLM/Weaviate error, after printing it.
        """
        self._ensure_llm()
        self._ensure_weaviate()

        try:
            collection_type = self.classify_query(query)
            book = self.wv_client.collections.get(book_name + collection_type)

            total_count = book.aggregate.over_all(total_count=True).total_count
            chunks_to_retrieve = self._chunks_to_retrieve(total_count)

            response = book.query.hybrid(
                query=query,
                limit=chunks_to_retrieve,
            )

            print(f'Chunks retrieved: {chunks_to_retrieve}')
            # Re-sort by chunk_num so the context reads in the book's own order.
            relevant_chunks = sorted(response.objects, key=lambda obj: obj.properties['chunk_num'])
            return '\n\n'.join(obj.properties['chunk'].strip() for obj in relevant_chunks)

        except Exception as e:
            print(f"Error in search: {e}")
            raise


@dataclass
class RAGConfig:
    """
    Configuration for the RAG system.

    Attributes:
        compression_rate: Value passed as ``rate`` to LLMLingua's
            ``compress_prompt``; must satisfy 0 < compression_rate <= 1.
        force_tokens: Tokens LLMLingua must keep. ``None`` (the default) is
            replaced by a fresh list of newline, '?', '.' and '!'.
        llmlingua_model: Model name passed to ``PromptCompressor``.
        ollama_model: Ollama model that writes the final answer.
        ollama_url: Host of the Ollama service (port 11434 is appended).
        temperature: Sampling temperature for the answering model.

    Raises:
        ValueError: If ``compression_rate`` is outside (0, 1].
    """
    compression_rate: float = 0.75
    force_tokens: Optional[List[str]] = None
    llmlingua_model: str = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"
    ollama_model: str = "llama3.2"
    ollama_url: str = "localhost"
    temperature: float = 0

    def __post_init__(self):
        # A new list per instance, so configs never share a mutable default.
        if self.force_tokens is None:
            self.force_tokens = ['\n', '?', '.', '!']
        if not 0 < self.compression_rate <= 1:
            raise ValueError(
                f"compression_rate must be in (0, 1], got {self.compression_rate!r}"
            )

class SocketManager:
    """
    Track sockets created while this manager is active so they can be force-closed.

    Tracking wraps ``socket.socket.__init__`` instead of replacing the
    ``socket.socket`` class with a plain function. As a result,
    ``isinstance(obj, socket.socket)`` and subclasses such as ``ssl.SSLSocket``
    keep working while tracking is on. Sockets are held weakly, so tracking
    never keeps an otherwise-unreferenced socket alive.

    NOTE(review): the patch is process-wide — any socket created by any library
    while the manager is active is closed by ``cleanup()``.
    """
    def __init__(self):
        self.sockets = weakref.WeakSet()
        # Kept for backward compatibility; the class itself is no longer replaced.
        self._original_socket = socket.socket
        self._original_init = None
        self._tracked_init = None
        self._patch_socket()

    def _patch_socket(self):
        """Wrap ``socket.socket.__init__`` so every newly created socket is tracked."""
        if self._tracked_init is not None:
            return  # already patched by this manager
        original_init = socket.socket.__init__
        tracked = self.sockets

        def _tracked_init(sock, *args, **kwargs):
            original_init(sock, *args, **kwargs)
            # Only reached when construction succeeded; skip non-weakrefable subclasses.
            with suppress(TypeError):
                tracked.add(sock)

        self._original_init = original_init
        self._tracked_init = _tracked_init
        socket.socket.__init__ = _tracked_init

    def _unpatch_socket(self):
        """Restore the original ``socket.socket.__init__`` if our wrapper is still installed."""
        if self._tracked_init is None:
            return
        # If another patch was installed on top of ours, restoring would silently
        # remove it, so only restore when our wrapper is the current one.
        if socket.socket.__dict__.get('__init__') is self._tracked_init:
            socket.socket.__init__ = self._original_init
        self._tracked_init = None
        self._original_init = None

    def cleanup(self):
        """Close every still-open tracked socket and remove the patch (safe to call twice)."""
        try:
            for sock in list(self.sockets):
                try:
                    if not getattr(sock, '_closed', False):
                        sock.close()
                except Exception:
                    pass  # best effort: one failing socket must not stop the rest
            self.sockets.clear()
        finally:
            self._unpatch_socket()

class ResourceManager:
    """Track transports and sockets owned by the RAG pipeline and release them on cleanup."""
    def __init__(self):
        self.socket_manager = SocketManager()
        self._transports = weakref.WeakSet()

    def _add_transport(self, transport):
        """Track one candidate transport, skipping ``None`` and non-weakrefable objects."""
        if transport is None:
            return
        with suppress(TypeError):
            self._transports.add(transport)

    def track_transport(self, obj):
        """Track any transports exposed by ``obj`` or by ``obj.client``."""
        candidates = [getattr(obj, '_transport', None), getattr(obj, 'transport', None)]
        client = getattr(obj, 'client', None)
        if client is not None:
            candidates.append(getattr(client, '_transport', None))
            candidates.append(getattr(client, 'transport', None))
        for transport in candidates:
            self._add_transport(transport)

    @staticmethod
    def _cleanup_event_loop():
        """Cancel pending tasks and close this thread's event loop, unless a loop is running."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no running loop: safe to inspect/close the thread's loop
        else:
            # A loop is running in this thread (presumably a notebook kernel's own
            # loop): cancelling its tasks or closing it would break the host, and
            # run_until_complete() would fail anyway.
            return

        import warnings
        try:
            with warnings.catch_warnings():
                # Newer Pythons warn when get_event_loop() has no current loop.
                warnings.simplefilter('ignore', DeprecationWarning)
                loop = asyncio.get_event_loop()
            if loop.is_closed() or loop.is_running():
                return
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                # Let cancelled tasks actually finish before closing the loop.
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
        except Exception:
            pass  # best-effort cleanup

    def cleanup(self):
        """Clean up all resources: transports, tracked sockets, then the event loop."""
        for transport in list(self._transports):
            try:
                is_closing = getattr(transport, 'is_closing', None)
                already_closing = is_closing is not None and is_closing()
                if hasattr(transport, 'close') and not already_closing:
                    transport.close()
            except Exception:
                pass  # best effort: keep closing the remaining transports
        self._transports.clear()

        self.socket_manager.cleanup()

        self._cleanup_event_loop()

        # Give closed connections a moment, then force collection of leftovers.
        time.sleep(0.2)
        gc.collect()

class RAGSystem:
    """
    RAG system with proper resource management.

    For each requested book it retrieves context through ``Search``, compresses
    it with LLMLingua, renders ``final_prompt.j2`` and sends the result to an
    Ollama LLM.

    Uses a singleton pattern: every ``RAGSystem(...)`` call returns the same
    instance. The config is applied by the first call, and again by the first
    call after ``close()`` (which marks the instance uninitialised).
    """
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton: allocate once, then always hand back the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[RAGConfig] = None):
        """
        Initialise resources unless the singleton is already initialised.

        Args:
            config: System configuration; defaults to ``RAGConfig()``.
        """
        # close() sets _initialized to False, so check its value, not just its presence.
        if not getattr(self, '_initialized', False):
            self.config = config or RAGConfig()
            self._initialized = True
            self.resource_manager = ResourceManager()
            self.compressor = None
            self.llm = None
            self._template = None
            self._initialize_resources()

    def _initialize_resources(self):
        """Create the compressor, the LLM and load the prompt template if missing."""
        if self.compressor is None:
            self.compressor = PromptCompressor(
                model_name=self.config.llmlingua_model,
                use_llmlingua2=True,
                device_map="cpu"
            )
            self.resource_manager.track_transport(self.compressor)

        if self.llm is None:
            self.llm = OllamaLLM(
                model=self.config.ollama_model,
                temperature=self.config.temperature,
                base_url=f"http://{self.config.ollama_url}:11434"
            )
            self.resource_manager.track_transport(self.llm)

        if self._template is None:
            with open('final_prompt.j2') as f:
                self._template = f.read()

    def _release_compressor(self):
        """Close the compressor if it supports ``close()``, then drop it."""
        compressor, self.compressor = self.compressor, None
        if compressor is not None and hasattr(compressor, 'close'):
            compressor.close()

    def _release_llm(self):
        """Close the LLM's ``client`` if it exposes ``close()``, then drop the LLM."""
        llm, self.llm = self.llm, None
        client = getattr(llm, 'client', None)
        if client is not None and hasattr(client, 'close'):
            client.close()

    def _release_tracked_resources(self):
        """Run the ResourceManager cleanup (transports, sockets, event loop)."""
        manager = getattr(self, 'resource_manager', None)
        if manager is not None:
            manager.cleanup()

    def close(self):
        """
        Cleanup all resources.

        Every step is attempted even if an earlier one fails; failures are
        printed, not raised. Afterwards the instance is marked uninitialised, so
        the next ``RAGSystem(...)`` call re-runs ``__init__`` with its config.
        """
        for step in (self._release_compressor,
                     self._release_llm,
                     self._release_tracked_resources):
            try:
                step()
            except Exception as e:
                print(f"Error during cleanup: {e}")
        self._initialized = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _get_context(self, query: str, book_name: str) -> str:
        """
        Get RAG context for a single book.

        Uses the ``Search`` singleton as a context manager, so its connections
        are closed after each lookup. NOTE(review): Search is created with its own
        defaults on first use, not with ``self.config.ollama_url`` — confirm intended.
        """
        with Search() as searcher:
            return searcher.search(query, book_name)

    def _compress_context(self, context: str) -> str:
        """Compress context using LLMLingua; empty context yields an empty string."""
        if not context:
            return ""
        result = self.compressor.compress_prompt(
            context,
            rate=self.config.compression_rate,
            force_tokens=self.config.force_tokens
        )
        return result['compressed_prompt']

    def _format_final_prompt(self,
                           compressed_contexts: List[str],
                           dialogue_history: List[Dict[str, str]],
                           query: str) -> str:
        """Render ``final_prompt.j2`` with the contexts, dialogue history and query."""
        from jinja2 import Template
        template = Template(self._template)
        return template.render(
            contexts=compressed_contexts,
            dialogue_history=dialogue_history,
            query=query
        )

    def query(self,
             query: str,
             book_names: List[str],
             dialogue_history: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Execute RAG query across multiple books with dialogue history.

        Books that fail to search or compress are reported and skipped.

        Args:
            query: User question.
            book_names: Collection name prefixes of the books to search.
            dialogue_history: Previous turns passed to the prompt template.

        Returns:
            The LLM answer, or a fixed message when no book produced context.

        Raises:
            Exception: Errors outside the per-book loop, after printing them.
        """
        try:
            self._initialize_resources()  # Recreate anything released by close()
            dialogue_history = dialogue_history or []

            # Get and compress context from each book
            compressed_contexts = []
            for book_name in book_names:
                try:
                    context = self._get_context(query, book_name)
                    if context:
                        compressed = self._compress_context(context)
                        compressed_contexts.append(f"From {book_name}:\n{compressed}")
                except Exception as e:
                    print(f"Error processing book {book_name}: {e}")
                    continue

            if not compressed_contexts:
                return "I couldn't find relevant information in the provided books."

            final_prompt = self._format_final_prompt(
                compressed_contexts=compressed_contexts,
                dialogue_history=dialogue_history,
                query=query
            )

            response = self.llm.invoke(final_prompt)
            return response

        except Exception as e:
            print(f"Error in query: {e}")
            raise

In [3]:
# Approach 2: use the processor directly as a context manager
# (the Weaviate connection is closed on exit). process_book() skips the book
# when its '_big_chunks' collection already exists.
with BooksProcessor() as processor:
    with open('Sherlock Study in Scarlet.txt', 'r', encoding='utf8') as file:
        text = file.read()
    processor.process_book('Sherlock_Study_in_Scarlet', text)
    # Uncomment to delete all three of this book's collections:
    #processor.delete_book('Sherlock_Study_in_Scarlet')

Book already exists


In [4]:
# Search is a singleton: this is the same instance RAGSystem uses internally.
# RAGSystem closes its connections after each lookup; search() reopens them lazily.
search = Search()
query = "Who was the criminal?"

In [5]:
# Classify the query into a chunk size, then run a hybrid search in that collection.
rag_context = search.search(query=query, book_name='Sherlock_Study_in_Scarlet')

Chunk size chosen: SMALL
Chunks retrieved: 9


In [6]:
# Step 1: the books must already be in the vector DB (loaded with BooksProcessor), e.g.:
#with BooksProcessor() as processor:
#    processor.process_book('book1', text1)
#    processor.process_book('book2', text2)

# Step 2: answer questions over those books with RAGSystem.
# RAGSystem.__exit__ calls close(), so resources are released even if the query raises.
config = RAGConfig(compression_rate=0.75)
with RAGSystem(config) as rag:
    response = rag.query(
        query="What happened in London?",
        book_names=['Sherlock_Study_in_Scarlet'],
        dialogue_history=[],
    )
    print(response)

Chunk size chosen: MEDIUM
Chunks retrieved: 8


  gc.collect()
Token indices sequence length is longer than the specified maximum sequence length for this model (2902 > 512). Running this sequence through the model will result in indexing errors


Based on the provided context from "Sherlock_Study_in_Scarlet", it appears that Sherlock Holmes and Dr. Watson were discussing various cases they had worked on.

One case involved a man named Mr. Drebber, who was found dead under mysterious circumstances. Mrs. Charpentier reported that her son, Lieutenant Charpentier, returned home around 11 pm, but she didn't know what he did during the two hours he was gone. Holmes believed this was the key to solving the case and arrested Lieutenant Charpentier.

Additionally, there was an incident involving a group of dirty street Arabs who were sent by Wiggins, a member of the Baker Street division of detective police force, to report on some suspicious activity. However, they didn't find anything significant.

It's also mentioned that Holmes had previously dealt with a case involving a man named Mr. Drebber, but the details are not provided in this context.

There is no mention of any other significant events or happenings in London beyond these 

  pass
