In [1]:
%pip install -q chromadb openai pydantic sentence-transformers orjson

^C
Note: you may need to restart the kernel to use updated packages.


In [None]:
import asyncio
import orjson
import logging
from typing import Dict, List, Literal, Optional
import numpy as np
from chromadb import AsyncHttpClient, Settings
from openai import AsyncOpenAI
from pydantic import BaseModel, Field, field_validator, ValidationError
from sentence_transformers import SentenceTransformer
import functools
import json

# Configure logging for the session. The level is INFO, so the `logger.debug`
# call in create_embeddings() is not emitted unless this level is lowered.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model identifier passed as `model=` on every chat-completion request below.
QWEN_MODEL = "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8"
# Async OpenAI-compatible client pointed at a custom base URL. The API key is a
# placeholder (presumably the endpoint does not check it — TODO confirm).
vllm_client = AsyncOpenAI(
    base_url="http://65.109.137.0:60564/v1", api_key="dummy_key"
)

# Sentence-embedding model, loaded eagerly here and used by create_embeddings().
embedding_model = SentenceTransformer("cointegrated/LaBSE-en-ru")

# NOTE: both statements use top-level `await`, which a plain Python script does
# not accept; this presumably runs in a notebook/IPython session.
# Telemetry is explicitly disabled on the Chroma client.
chroma_client = await AsyncHttpClient(host="91.184.242.207", port=8000, settings=Settings(anonymized_telemetry=False))
# The "documents" collection is created on first use and reused afterwards.
collection = await chroma_client.get_or_create_collection("documents")

def create_embeddings(texts: List[str]):
    """Encode a batch of texts with the module-level embedding model.

    Args:
        texts: Strings to embed.

    Returns:
        The result of ``embedding_model.encode(texts)``; callers in this file
        index it per input text.
    """
    vectors = embedding_model.encode(texts)
    logger.debug(f"Created embeddings for texts: {texts}")
    return vectors

@functools.lru_cache(maxsize=1000)
def create_cached_embeddings(text: str) -> np.ndarray:
    """Return the embedding of a single text, memoised per distinct string.

    Up to 1000 results are kept. A cache hit returns the same object that was
    stored, so callers should treat the result as read-only.
    """
    single_item_batch = create_embeddings([text])
    return single_item_batch[0]

# One citation the model attaches to its answer: which document and section it
# came from, the quoted text, and a coarse relevance grade.
# NOTE(review): documented with `#` comments rather than a class docstring,
# because a docstring would likely surface as `description` in the JSON schema
# generated from this model — confirm before converting.
class SourceReference(BaseModel):
    document_title: str = Field(..., description="Title of the referenced document")
    section: str = Field(..., description="Section number or identifier")
    exact_quote: str = Field(..., description="Direct quote from the source")
    # Restricted to exactly these three string values.
    relevance: Literal["high", "medium", "low"] = Field(..., description="Relevance level of the reference")

# One step of the model's chain of reasoning: the reasoning text and the
# conclusion drawn from it. Both fields are required.
class ThinkStep(BaseModel):
    reasoning: str = Field(..., description="Step-by-step thought process")
    conclusion: str = Field(..., description="Intermediate or final conclusion")

# Validation flags reported alongside an answer. The five boolean fields are
# required; `additional_notes` is optional free text.
class Checklist(BaseModel):
    query_understood: bool = Field(..., description="Query is fully understood")
    context_analyzed: bool = Field(..., description="Relevant context found and analyzed")
    sources_verified: bool = Field(..., description="Sources properly referenced")
    reasoning_complete: bool = Field(..., description="Full analysis conducted")
    answer_validated: bool = Field(..., description="Answer checked for accuracy")
    additional_notes: Optional[str] = Field(None, description="Any additional verification notes")

    @field_validator("additional_notes", mode="before")
    @classmethod
    def validate_notes(cls, value):
        """Replace a boolean ``additional_notes`` input with ``None``.

        Runs before Pydantic's own validation (``mode="before"``), so a
        ``true``/``false`` supplied where a string is expected is dropped
        instead of failing validation. Any other value is passed through
        unchanged for normal validation.
        """
        if isinstance(value, bool):
            return None
        return value

# Top-level structured answer: citations, reasoning steps, the short answer,
# an optional long answer, and the validation checklist. Its JSON schema is
# embedded in the prompt by Prompts.get_schema_definition() and the class is
# passed as `response_format` in ask_question().
class Answer(BaseModel):
    source_references: List[SourceReference] = Field(..., description="List of relevant source references")
    thinking_steps: List[ThinkStep] = Field(..., description="Chain of reasoning steps")
    brief_answer: str = Field(..., description="Concise answer to the query")
    # Optional; defaults to None when omitted.
    detailed_answer: Optional[str] = Field(None, description="Detailed explanation if needed")
    checklist: Checklist = Field(..., description="Validation checklist")

class Prompts:
    """Prompt templates and helpers for the question-answering calls.

    ``SYSTEM`` and ``CLARIFICATION`` are used verbatim as system messages;
    ``CONTEXT`` is a ``str.format`` template with three placeholders that
    ``format_context`` fills in.
    """

    # System message for the structured-answer request.
    SYSTEM = """
    <system>
        <task>
            <primary>Process documentation queries with structured output</primary>
            <approach>Chain-of-thought reasoning with validation checklist</approach>
        </task>
        
        <rules>
            <output>
                <format>Strict JSON adherence to provided schema</format>
                <validation>Must pass all checklist items</validation>
            </output>
            <reasoning>
                <steps>Sequential thought process</steps>
                <verification>Cross-reference with documentation</verification>
            </reasoning>
        </rules>
        
        <quality>
            <accuracy>Verify against source material</accuracy>
            <completeness>Address all question aspects</completeness>
            <structure>Follow provided schema exactly</structure>
        </quality>
    </system>
    """

    # User-message template: {context_text}, {schema_definition}, {question}.
    CONTEXT = """
    <context>
        <documentation>
            {context_text}
        </documentation>
        <schema>
            {schema_definition}
        </schema>
        <query>{question}</query>
    </context>
    """

    # System message for the clarifying-question request.
    CLARIFICATION = """
    <clarification>
        <task>
            <primary>Determine if query needs clarification</primary>
            <output_format>Single question or "No clarification needed"</output_format>
        </task>
        
        <rules>
            <analysis>
                <check>Query completeness</check>
                <check>Technical specificity</check>
                <check>Context sufficiency</check>
            </analysis>
            <response>
                <format>Clear, specific question</format>
                <language>English</language>
            </response>
        </rules>
        
        <validation>
            <condition>If query is clear and complete</condition>
            <response>"No clarification needed"</response>
        </validation>
    </clarification>
    """

    @staticmethod
    def get_schema_definition() -> str:
        """Return the ``Answer`` JSON schema as 2-space-indented JSON text.

        The top-level ``title`` and ``description`` keys are omitted; nested
        definitions are left as generated.
        """
        full_schema = Answer.model_json_schema()
        trimmed_schema = {
            key: value
            for key, value in full_schema.items()
            if key not in ("title", "description")
        }
        return json.dumps(trimmed_schema, indent=2)

    @classmethod
    def format_context(cls, context_text: str, question: str) -> str:
        """Fill the ``CONTEXT`` template with documentation, schema and query.

        Args:
            context_text: Retrieved documentation text to embed.
            question: The user's query.

        Returns:
            The formatted prompt string.
        """
        placeholders = {
            "context_text": context_text,
            "schema_definition": cls.get_schema_definition(),
            "question": question,
        }
        return cls.CONTEXT.format(**placeholders)

def validate_clarification(clarification: str) -> bool:
    """Check that a clarification reply has one of the two accepted forms.

    Returns ``True`` when the text ends with a question mark or is exactly
    the sentinel ``"No clarification needed"``; ``False`` otherwise.
    """
    is_question = clarification.endswith('?')
    is_sentinel = clarification == "No clarification needed"
    return is_question or is_sentinel

def clean_response(response_text: str) -> str:
    """Drop non-printable characters, keeping newlines and tabs.

    A character survives when ``str.isprintable`` accepts it or when it is
    ``"\\n"`` or ``"\\t"``; everything else is removed.
    """
    allowed_whitespace = "\n\t"
    kept_chars = [
        ch
        for ch in response_text
        if ch in allowed_whitespace or ch.isprintable()
    ]
    return "".join(kept_chars)

def create_error_answer(error_message: str) -> Answer:
    """Build a placeholder ``Answer`` describing a processing failure.

    The result has no references or reasoning steps, every checklist flag set
    to ``False``, and a brief answer that carries ``error_message``.
    """
    flag_names = (
        "query_understood",
        "context_analyzed",
        "sources_verified",
        "reasoning_complete",
        "answer_validated",
    )
    failed_checklist = Checklist(
        **dict.fromkeys(flag_names, False),
        additional_notes=None,
    )
    return Answer(
        source_references=[],
        thinking_steps=[],
        brief_answer=f"Error processing response: {error_message}",
        detailed_answer=None,
        checklist=failed_checklist,
    )

async def get_relevant_documents(
    query: str, *, n_results: int = 200, max_documents: int = 15
) -> List[Dict]:
    """Retrieve documentation chunks for ``query``, one per section.

    The query is embedded and sent to the Chroma collection. Hits are walked
    in the order returned; the first hit of each section is kept, and for
    numeric section numbers the first hits of the adjacent sections
    (``section - 1`` and ``section + 1``) are pulled in right after it when
    they appear among the results.

    Args:
        query: The user's question.
        n_results: How many hits to request from the collection.
        max_documents: Upper bound on the number of entries returned.

    Returns:
        A list of dicts with keys ``"content"``, ``"metadata"`` and
        ``"section_number"``, at most ``max_documents`` long.
    """
    query_embedding = create_embeddings([query])[0]

    results = await collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=n_results,
        include=["documents", "metadatas"],
    )

    # A missing metadata entry is treated as "no section number" instead of
    # crashing on `.get`.
    ranked_hits = [
        (doc, meta, (meta or {}).get("section_number"))
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    ]

    # First hit per section, built once so neighbour lookups do not rescan
    # the whole result list.
    first_hit_by_section = {}
    for hit in ranked_hits:
        first_hit_by_section.setdefault(hit[2], hit)

    relevant_docs: List[Dict] = []
    seen_sections = set()

    def _keep(doc, meta, section) -> None:
        relevant_docs.append({
            "content": doc,
            "metadata": meta,
            "section_number": section,
        })
        seen_sections.add(section)

    for doc, meta, section_number in ranked_hits:
        if len(relevant_docs) >= max_documents:
            break  # enough collected; later hits cannot change the result
        if section_number in seen_sections:
            continue
        # NOTE(review): hits without a section number share the key None, so
        # only the first of them is kept (same as before) — confirm intent.
        _keep(doc, meta, section_number)

        # Neighbour expansion needs arithmetic on the section number. Section
        # 0 is a valid number (a plain truthiness test would skip it), while
        # strings, bools and None cannot be offset.
        is_numeric = isinstance(section_number, (int, float)) and not isinstance(
            section_number, bool
        )
        if not is_numeric:
            continue

        for offset in (-1, 1):
            nearby_section = section_number + offset
            if nearby_section in seen_sections:
                continue
            nearby_hit = first_hit_by_section.get(nearby_section)
            if nearby_hit is not None:
                _keep(nearby_hit[0], nearby_hit[1], nearby_section)

    return relevant_docs[:max_documents]

async def generate_clarifying_question(original_question: str) -> str:
    """Ask the model whether ``original_question`` needs clarification.

    Args:
        original_question: The user's original query.

    Returns:
        The model's reply with surrounding whitespace removed. When the model
        returns no text at all, the sentinel ``"No clarification needed"`` is
        returned so callers comparing against it skip the follow-up.
    """
    # Plain-text reply: use the regular completion call and read the message
    # content; the completion object itself has no `.strip()`.
    response = await vllm_client.chat.completions.create(
        model=QWEN_MODEL,
        messages=[
            {"role": "system", "content": Prompts.CLARIFICATION},
            {"role": "user", "content": f"Query: {original_question}"}
        ],
        temperature=0.3,
        max_tokens=256,
        response_format={"type": "text"}
    )
    content = response.choices[0].message.content if response.choices else None
    clarification = (content or "").strip()
    return clarification or "No clarification needed"

async def ask_question(question: str) -> Answer:
    """Answer ``question`` from retrieved documentation as a structured ``Answer``.

    Retrieves context, sends it with the schema and the question to the model,
    and returns the parsed ``Answer``.

    Args:
        question: The user's query.

    Returns:
        The parsed ``Answer``. If retrieval, the request or parsing fails, or
        the model produces no parsed output, an error ``Answer`` built by
        ``create_error_answer`` is returned instead; this function does not
        raise for those cases.
    """
    try:
        context = await get_relevant_documents(question)
        context_text = "\n".join(
            f"Section {doc['section_number']}: {doc['content']}"
            for doc in context
        )

        response = await vllm_client.chat.completions.parse(
            model=QWEN_MODEL,
            messages=[
                {"role": "system", "content": Prompts.SYSTEM},
                {"role": "user", "content": Prompts.format_context(
                    context_text=context_text,
                    question=question
                )}
            ],
            response_format=Answer,
            temperature=0.3,
            max_tokens=2048
        )

        # The call returns a completion wrapper; the validated Answer lives on
        # the first choice's message.
        message = response.choices[0].message
        if message.parsed is None:
            reason = getattr(message, "refusal", None) or "model returned no parsed answer"
            logger.error("No parsed answer in ask_question: %s", reason)
            return create_error_answer(reason)
        return message.parsed

    except Exception as e:
        # Boundary handler: callers always receive an Answer. The traceback is
        # logged so the failure is still diagnosable.
        logger.exception("Error in ask_question: %s", e)
        return create_error_answer(str(e))

In [None]:
# CLI interface
async def chat_loop():
    """Run the interactive question/answer loop on stdin/stdout.

    Each round reads a question, prints the structured answer, and asks
    whether the user is satisfied. On "n" a clarifying question is generated
    and the user may submit a refined question. Typing ``exit`` or ``quit``,
    closing stdin, or interrupting the prompt ends the loop. Other errors are
    logged and reported, and the loop continues.
    """

    def to_jsonable(value):
        """orjson ``default`` hook: convert NumPy values to native Python."""
        if isinstance(value, np.generic):
            return value.item()  # keeps ints as ints and floats as floats
        if isinstance(value, np.ndarray):
            return value.tolist()
        # Unsupported types must raise so the serializer fails fast instead of
        # calling the hook again on the same object.
        raise TypeError(f"Type is not JSON serializable: {type(value).__name__}")

    def format_response(response: Answer) -> str:
        """Render ``response`` as indented JSON, or ``str()`` of its dict on failure."""
        try:
            return orjson.dumps(
                response.model_dump(),
                default=to_jsonable,
                option=orjson.OPT_INDENT_2
            ).decode('utf-8')
        except Exception as e:
            logger.error(f"Error formatting response: {e}")
            return str(response.model_dump())

    def print_response(response: Answer) -> None:
        """Print the full structured response, then the brief and detailed answers."""
        print("\nStructured response:")
        print(format_response(response))

        if response.brief_answer:
            print("\nBrief answer:")
            print(response.brief_answer)

        if response.detailed_answer:
            print("\nDetailed answer:")
            print(response.detailed_answer)

    async def handle_user_feedback(question: str) -> None:
        """Show a clarifying question and optionally answer a refined query."""
        clarifying_question = await generate_clarifying_question(question)
        print("\nClarifying question:")
        print(clarifying_question)

        if clarifying_question != "No clarification needed":
            follow_up = input("\nWould you like to ask a clarifying question? (y/n): ").strip().lower()
            if follow_up == "y":
                new_question = input("\nEnter clarified question: ").strip()
                if new_question:
                    response = await ask_question(new_question)
                    print_response(response)

    print("CLI Chat Interface (type 'exit' to quit)")

    while True:
        try:
            question = input("\nYour question: ").strip()

            if not question:
                print("Please enter a question.")
                continue

            if question.lower() in ["exit", "quit"]:
                print("Exiting...")
                break

            response = await ask_question(question)
            print_response(response)

            satisfaction = input("\nAre you satisfied with the answer? (y/n): ").strip().lower()
            if satisfaction == "n":
                await handle_user_feedback(question)

        except (EOFError, KeyboardInterrupt):
            # Closed stdin makes every later input() fail again, so retrying
            # would spin forever; treat it, and Ctrl-C, as a request to quit.
            print("\nExiting...")
            break
        except Exception as e:
            logger.exception("Error in chat loop: %s", e)
            print(f"\nError occurred: {str(e)}")

# Start the interactive loop. Top-level `await` is not accepted by a plain
# Python script; this presumably relies on a notebook/IPython event loop.
await chat_loop()