From 5f2dc180f5334a55fddb542134d853c5006f2079 Mon Sep 17 00:00:00 2001 From: Aleksandr Samofalov Date: Wed, 15 Apr 2026 01:36:57 +0300 Subject: [PATCH 1/5] feat(langfuse): nested Redis/Chroma spans inside LangGraph nodes + tests Observability: - Added current_span ContextVar to track active LangGraph node span. - Wrapped all agent nodes (base, core, loop) with manual Langfuse spans; LLM callbacks are now scoped to the node span via CallbackHandler(stateful_client=span). - Updated RedisAdapter and ChromaAdapter to nest their spans under current_span instead of attaching directly to the root trace. - Removed global CallbackHandler from /chat router to avoid duplicate spans. Tests: - Added tests/service/api/v1/test_router.py covering /test_invoke (success & 424) and /chat (success & rate-limit). - Added tests/agents/profkom_consultant/nodes/test_base.py for update_user_history_context (append, trim, 1:1 sync). Fixes: - Fixed history trim bug: model_answers now uses [-HISTORY_LIMIT:] instead of [-trim_count:], keeping Q&A lists synchronized. Deps: - Added pytest, pytest-asyncio, httpx to dev dependencies. 
--- pyproject.toml | 3 + src/agents/profkom_consultant/nodes/base.py | 131 ++++++++------ src/agents/profkom_consultant/nodes/core.py | 160 ++++++++++-------- src/agents/profkom_consultant/nodes/loop.py | 92 +++++----- src/modules/chroma_ext/base.py | 119 ++++++++----- src/modules/redis_ext/base.py | 60 +++++-- src/service/api/v1/router.py | 9 +- src/service/logger/context_vars.py | 6 +- tests/__init__.py | 0 tests/agents/__init__.py | 0 tests/agents/profkom_consultant/__init__.py | 0 .../profkom_consultant/nodes/__init__.py | 0 .../profkom_consultant/nodes/test_base.py | 52 ++++++ tests/conftest.py | 33 ++++ tests/service/__init__.py | 0 tests/service/api/__init__.py | 0 tests/service/api/v1/__init__.py | 0 tests/service/api/v1/test_router.py | 117 +++++++++++++ uv.lock | 53 ++++++ 19 files changed, 604 insertions(+), 231 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/agents/__init__.py create mode 100644 tests/agents/profkom_consultant/__init__.py create mode 100644 tests/agents/profkom_consultant/nodes/__init__.py create mode 100644 tests/agents/profkom_consultant/nodes/test_base.py create mode 100644 tests/conftest.py create mode 100644 tests/service/__init__.py create mode 100644 tests/service/api/__init__.py create mode 100644 tests/service/api/v1/__init__.py create mode 100644 tests/service/api/v1/test_router.py diff --git a/pyproject.toml b/pyproject.toml index 07101b2..a606e62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,9 @@ dev = [ "ruff", "pylint", "pre-commit", + "pytest", + "pytest-asyncio", + "httpx", ] [tool.ruff] diff --git a/src/agents/profkom_consultant/nodes/base.py b/src/agents/profkom_consultant/nodes/base.py index e019b2e..77b54c1 100644 --- a/src/agents/profkom_consultant/nodes/base.py +++ b/src/agents/profkom_consultant/nodes/base.py @@ -1,9 +1,31 @@ +from contextlib import asynccontextmanager + from langchain_core.prompts import ChatPromptTemplate +from langfuse.callback import CallbackHandler 
from agents.profkom_consultant.states import AgentState +from service.logger.context_vars import current_span, current_trace class BaseAgentNodes: + @asynccontextmanager + async def _node_span(self, name: str, state: AgentState): + trace = current_trace.get() + span = None + if trace: + span = trace.span(name=name, input={"text": state.get("text", "")}) + current_span.set(span) + try: + yield span + finally: + if span: + span.end() + current_span.set(None) + + def _llm_config(self, span): + if span: + return {"callbacks": [CallbackHandler(stateful_client=span)]} + return {} async def validate_text(self, state: AgentState) -> AgentState: """Проверяем, что текст вопроса пользователя соответсвует публичной политики. @@ -14,39 +36,40 @@ async def validate_text(self, state: AgentState) -> AgentState: Return: Бинарное значение - спам или нормальный вопрос к агенту. """ - question = state["text"] - try: - cached_result = self.cache.get(meta_info="validate_input", query=question) - if cached_result: - self.logger.debug(f"Cached result {cached_result}") - state["is_valid"] = cached_result.get("json").get("is_valid") - state["final_answer"] = cached_result.get("json").get("final_answer") - return state - else: - self.logger.debug(f"Cached result {cached_result}") - prompt = ChatPromptTemplate.from_template( - self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() # TO DO: FIX - ) - chain = prompt | self.llm - output = await chain.ainvoke({"text": state["text"]}) - output = output.content.strip().lower() - self.logger.info(f"Output: {output}") + async with self._node_span("validate_text", state) as span: + question = state["text"] + try: + cached_result = self.cache.get(meta_info="validate_input", query=question) + if cached_result: + self.logger.debug(f"Cached result {cached_result}") + state["is_valid"] = cached_result.get("json").get("is_valid") + state["final_answer"] = cached_result.get("json").get("final_answer") + return state + else: + 
self.logger.debug(f"Cached result {cached_result}") + prompt = ChatPromptTemplate.from_template( + self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() # TO DO: FIX + ) + chain = prompt | self.llm + output = await chain.ainvoke({"text": state["text"]}, config=self._llm_config(span)) + output = output.content.strip().lower() + self.logger.info(f"Output: {output}") + + is_valid = "да" in output + + cache_data = {"is_valid": is_valid} - is_valid = "да" in output - - cache_data = {"is_valid": is_valid} - - if not is_valid: - cache_data["final_answer"] = "Не прошёл валидацию" - state["final_answer"] = cache_data["final_answer"] + if not is_valid: + cache_data["final_answer"] = "Не прошёл валидацию" + state["final_answer"] = cache_data["final_answer"] - state["is_valid"] = is_valid - self.logger.debug(f"is_valid: {is_valid}") - self.cache.save(meta_info="validate_input", query=question, output="", json_data=cache_data) - return state + state["is_valid"] = is_valid + self.logger.debug(f"is_valid: {is_valid}") + self.cache.save(meta_info="validate_input", query=question, output="", json_data=cache_data) + return state - except Exception as e: - print(f"Validate error at validate_input: {e}") + except Exception as e: + self.logger.error(f"Validate error at validate_input: {e}") async def validate_final_answer(self, state: AgentState) -> AgentState: """Проверяем, что текст ответа модели соответсвует публичной политики. @@ -58,28 +81,29 @@ async def validate_final_answer(self, state: AgentState) -> AgentState: Return: Бинарное значение - спам или нормальный ответ от агента. 
""" - final_answer = state.get("final_answer", "") - try: - cached_result = self.cache.get(meta_info="validate_final_answer", query=final_answer) - if cached_result: - state["is_valid"] = cached_result.get("json").get("is_valid") or True - return state - else: - prompt = self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - chain = prompt | self.llm - output = await chain.ainvoke({"text": final_answer}) - - is_valid = "да" in output.content.strip().lower() - cache_data = {"answer": is_valid} - if not is_valid: - state["final_answer"] = "Не прошёл валидацию" - state["is_valid"] = is_valid - self.cache.save(meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data) - return state - - except Exception as e: - print(f"Error at validate_final_answer: {e}") + async with self._node_span("validate_final_answer", state) as span: + final_answer = state.get("final_answer", "") + try: + cached_result = self.cache.get(meta_info="validate_final_answer", query=final_answer) + if cached_result: + state["is_valid"] = cached_result.get("json").get("is_valid") or True + return state + else: + prompt = self.langfuse_client.get_prompt("policy_validation").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + chain = prompt | self.llm + output = await chain.ainvoke({"text": final_answer}, config=self._llm_config(span)) + + is_valid = "да" in output.content.strip().lower() + cache_data = {"answer": is_valid} + if not is_valid: + state["final_answer"] = "Не прошёл валидацию" + state["is_valid"] = is_valid + self.cache.save(meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data) + return state + + except Exception as e: + self.logger.error(f"Error at validate_final_answer: {e}") def update_user_history_context(self, state: AgentState) -> AgentState: """Обновляет историю вопросов/ответов: аппендит текущий вопрос + ответ, тримирует до 
HISTORY_LIMIT. @@ -101,9 +125,8 @@ def update_user_history_context(self, state: AgentState) -> AgentState: state["model_answers"] = [state["final_answer"]] if len(state["user_history"]) > self.HISTORY_LIMIT: - trim_count = len(state["user_history"]) - self.HISTORY_LIMIT state["user_history"] = state["user_history"][-self.HISTORY_LIMIT :] - state["model_answers"] = state["model_answers"][-trim_count:] + state["model_answers"] = state["model_answers"][-self.HISTORY_LIMIT :] return {"user_history": state["user_history"], "model_answers": state["model_answers"]} diff --git a/src/agents/profkom_consultant/nodes/core.py b/src/agents/profkom_consultant/nodes/core.py index ff4aeac..7d16782 100644 --- a/src/agents/profkom_consultant/nodes/core.py +++ b/src/agents/profkom_consultant/nodes/core.py @@ -46,10 +46,14 @@ async def _detect_topics_for_question(self, question: str) -> str: Returns: Relevant topic. """ + from service.logger.context_vars import current_span + prompt = self.langfuse_client.get_prompt("topic_choose_router").get_langchain_prompt() prompt = ChatPromptTemplate.from_template(prompt) chain = prompt | self.llm - response = await chain.ainvoke({"question": question}) + span = current_span.get() + config = self._llm_config(span) + response = await chain.ainvoke({"question": question}, config=config) return response.content.strip() async def decompose_question(self, state: AgentState) -> None | dict[str, Any] | dict[str, list[Any]]: @@ -66,30 +70,32 @@ async def decompose_question(self, state: AgentState) -> None | dict[str, Any] | Return: Словарь простых вопросов пользователя. 
""" - question = state["text"] - try: - cached_result = self.cache.get(meta_info="decompose_question_" + state["user_id"], query=question) - if cached_result: - return {"parts": cached_result.get("json").get("parts")} - else: - prompt = self.langfuse_client.get_prompt("decompose_question").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - chain = prompt | self.llm - response = await chain.ainvoke( - {"user_question": question, "user_history": state.get("user_history", "")} - ) - response = response.content.strip() + async with self._node_span("decompose_question", state) as span: + question = state["text"] + try: + cached_result = self.cache.get(meta_info="decompose_question_" + state["user_id"], query=question) + if cached_result: + return {"parts": cached_result.get("json").get("parts")} + else: + prompt = self.langfuse_client.get_prompt("decompose_question").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + chain = prompt | self.llm + response = await chain.ainvoke( + {"user_question": question, "user_history": state.get("user_history", "")}, + config=self._llm_config(span), + ) + response = response.content.strip() - content = re.search(r"<ЗАДАЧИ.*?>(.*?)", response, re.IGNORECASE | re.DOTALL) - content = content.group(1) if content else response + content = re.search(r"<ЗАДАЧИ.*?>(.*?)", response, re.IGNORECASE | re.DOTALL) + content = content.group(1) if content else response - cache_data = {"parts": [p.strip() for p in content.split("") if p.strip()]} - self.cache.save( - meta_info="decompose_question_" + state["user_id"], query=question, output="", json_data=cache_data - ) - return cache_data - except Exception as e: - print(f"Error at decompose_question: {e}") + cache_data = {"parts": [p.strip() for p in content.split("") if p.strip()]} + self.cache.save( + meta_info="decompose_question_" + state["user_id"], query=question, output="", json_data=cache_data + ) + return cache_data + except Exception as e: 
+ self.logger.error(f"Error at decompose_question: {e}") async def answer_parts_async(self, state: AgentState, max_concurrent: int = 8) -> AgentState: """Генерируем асинхронные ответы на список вопросов. @@ -100,37 +106,39 @@ async def answer_parts_async(self, state: AgentState, max_concurrent: int = 8) - Returns: Список простых ответов на глобальный вопрос пользователя. """ - state["answers"] = [] - semaphore = asyncio.Semaphore(max_concurrent) + async with self._node_span("answer_parts_async", state) as span: + state["answers"] = [] + semaphore = asyncio.Semaphore(max_concurrent) - prompt = self.langfuse_client.get_prompt("query_worker").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - # TO DO: CHECK что мы умеем работать с данными RAG - chain = prompt | self.llm + prompt = self.langfuse_client.get_prompt("query_worker").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + # TO DO: CHECK что мы умеем работать с данными RAG + chain = prompt | self.llm - async def call_llm(part: str) -> str: - self.logger.info(f"Calling {part}") - async with semaphore: - cached_result = self.cache.get(meta_info="answer_parts_async", query=part) - if cached_result: - return cached_result.get("json").get("answer") - else: - topic = await self._detect_topics_for_question(part) - self.logger.info(f"Topic: {topic}") - retrived_data = await asyncio.to_thread( - self.chorma_client.get_info, query=part, collection_name=self.COLLECTION_NAME, topics=[topic] - ) - html_data = retrived_data.to_html() - result = await chain.ainvoke({"text": part, "rag": html_data}) - cache_data = {"answer": result.content.strip()} - self.cache.save(meta_info="answer_parts_async", query=part, output="", json_data=cache_data) - return cache_data.get("answer") - - if state.get("parts"): - tasks = [asyncio.create_task(call_llm(part)) for part in state["parts"]] - results = await asyncio.gather(*tasks, return_exceptions=True) - state["answers"] = [str(r) if not 
isinstance(r, Exception) else f"Error: {r}" for r in results] - return state + async def call_llm(part: str) -> str: + self.logger.info(f"Calling {part}") + async with semaphore: + cached_result = self.cache.get(meta_info="answer_parts_async", query=part) + if cached_result: + return cached_result.get("json").get("answer") + else: + topic = await self._detect_topics_for_question(part) + self.logger.info(f"Topic: {topic}") + retrived_data = await asyncio.to_thread( + self.chorma_client.get_info, query=part, collection_name=self.COLLECTION_NAME, topics=[topic] + ) + html_data = retrived_data.to_html() + config = self._llm_config(span) + result = await chain.ainvoke({"text": part, "rag": html_data}, config=config) + cache_data = {"answer": result.content.strip()} + self.cache.save(meta_info="answer_parts_async", query=part, output="", json_data=cache_data) + return cache_data.get("answer") + + if state.get("parts"): + tasks = [asyncio.create_task(call_llm(part)) for part in state["parts"]] + results = await asyncio.gather(*tasks, return_exceptions=True) + state["answers"] = [str(r) if not isinstance(r, Exception) else f"Error: {r}" for r in results] + return state async def collect_final_answer(self, state: AgentState) -> AgentState: """Собираем итоговый ответ на вопрос пользователя. @@ -143,26 +151,28 @@ async def collect_final_answer(self, state: AgentState) -> AgentState: Return: Итоговый текст ответа пользователю на вопрос. """ - question = state["text"] - if state.get("answers"): - answers_text = "\n".join(f"{i + 1}. 
{ans}" for i, ans in enumerate(state["answers"]) if ans) - prompt = self.langfuse_client.get_prompt("summary_response").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - chain = prompt | self.llm - # TO DO: CHECK что у нас огромный промпт не ломает ответ - response = await chain.ainvoke( - { - "task_responses": answers_text, - "user_history": state.get("user_history", "Нет истории запросов."), - "original_question": question, - "model_answers": state.get("model_answers", "Нет истории ответов от модели"), - "additional_info": state.get( - "additional_info", "Нет дополнительной информации по предыдущим ответам." - ), - } - ) - response = response.content.strip() - state["final_answer"] = response - else: - state["final_answer"] = "Нет данных для итогового ответа." - return state + async with self._node_span("collect_final_answer", state) as span: + question = state["text"] + if state.get("answers"): + answers_text = "\n".join(f"{i + 1}. {ans}" for i, ans in enumerate(state["answers"]) if ans) + prompt = self.langfuse_client.get_prompt("summary_response").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + chain = prompt | self.llm + # TO DO: CHECK что у нас огромный промпт не ломает ответ + response = await chain.ainvoke( + { + "task_responses": answers_text, + "user_history": state.get("user_history", "Нет истории запросов."), + "original_question": question, + "model_answers": state.get("model_answers", "Нет истории ответов от модели"), + "additional_info": state.get( + "additional_info", "Нет дополнительной информации по предыдущим ответам." + ), + }, + config=self._llm_config(span), + ) + response = response.content.strip() + state["final_answer"] = response + else: + state["final_answer"] = "Нет данных для итогового ответа." 
+ return state diff --git a/src/agents/profkom_consultant/nodes/loop.py b/src/agents/profkom_consultant/nodes/loop.py index f64987a..4a31bc1 100644 --- a/src/agents/profkom_consultant/nodes/loop.py +++ b/src/agents/profkom_consultant/nodes/loop.py @@ -21,37 +21,39 @@ async def check_user_answer(self, state: AgentState) -> AgentState: - status="DONE" если final_answer релевантен text. - status="AGAIN" + counter_loop +=1 если нет (max 3). """ - prompt = self.langfuse_client.get_prompt("check_user_answer").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - chain = prompt | self.llm - response = await chain.ainvoke( - { - "question": state["text"], - "parts": state.get("parts", "[]"), - "history_questions": state.get("user_history", "[]"), - "answer": state["final_answer"], - } - ) - response = "DONE" in response.content.strip().upper() - - if response: - state["status"] = AgentStatus.DONE - state["counter_loop"] = 0 - state["additional_info"] = "" - else: - counter = state.get("counter_loop", 0) - if counter >= self.MAX_LOOP_GENERATION: + async with self._node_span("check_user_answer", state) as span: + prompt = self.langfuse_client.get_prompt("check_user_answer").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + chain = prompt | self.llm + response = await chain.ainvoke( + { + "question": state["text"], + "parts": state.get("parts", "[]"), + "history_questions": state.get("user_history", "[]"), + "answer": state["final_answer"], + }, + config=self._llm_config(span), + ) + response = "DONE" in response.content.strip().upper() + + if response: state["status"] = AgentStatus.DONE state["counter_loop"] = 0 state["additional_info"] = "" else: - if not state.get("counter_loop"): + counter = state.get("counter_loop", 0) + if counter >= self.MAX_LOOP_GENERATION: + state["status"] = AgentStatus.DONE state["counter_loop"] = 0 + state["additional_info"] = "" + else: + if not state.get("counter_loop"): + state["counter_loop"] = 
0 - state["counter_loop"] += 1 - state["additional_info"] = state["final_answer"] - state["status"] = AgentStatus.AGAIN - return state + state["counter_loop"] += 1 + state["additional_info"] = state["final_answer"] + state["status"] = AgentStatus.AGAIN + return state async def generate_additional_questions(self, state) -> AgentState: """Генерируем новые вопросы чтобы ответить на вопрос пользователя. @@ -63,22 +65,24 @@ async def generate_additional_questions(self, state) -> AgentState: Return: Новый список вопросов. """ - prompt = self.langfuse_client.get_prompt("generate_additional_questions").get_langchain_prompt() - prompt = ChatPromptTemplate.from_template(prompt) - chain = prompt | self.llm - response = await chain.ainvoke( - { - "question": state["text"], - "history_questions": state.get("user_history", "[]"), - "answer": state["final_answer"], - "parts": state.get("parts", "[]"), - } - ) - - response = response.content.strip() - - content = re.search(r"<ЗАДАЧИ.*?>(.*?)", response, re.IGNORECASE | re.DOTALL) - content = content.group(1) if content else response - - data = {"parts": [p.strip() for p in content.split("") if p.strip()]} - return data + async with self._node_span("generate_additional_questions", state) as span: + prompt = self.langfuse_client.get_prompt("generate_additional_questions").get_langchain_prompt() + prompt = ChatPromptTemplate.from_template(prompt) + chain = prompt | self.llm + response = await chain.ainvoke( + { + "question": state["text"], + "history_questions": state.get("user_history", "[]"), + "answer": state["final_answer"], + "parts": state.get("parts", "[]"), + }, + config=self._llm_config(span), + ) + + response = response.content.strip() + + content = re.search(r"<ЗАДАЧИ.*?>(.*?)", response, re.IGNORECASE | re.DOTALL) + content = content.group(1) if content else response + + data = {"parts": [p.strip() for p in content.split("") if p.strip()]} + return data diff --git a/src/modules/chroma_ext/base.py 
b/src/modules/chroma_ext/base.py index 4da0798..8946555 100644 --- a/src/modules/chroma_ext/base.py +++ b/src/modules/chroma_ext/base.py @@ -5,6 +5,7 @@ from chromadb import QueryResult from service.logger import LoggerConfigurator +from service.logger.context_vars import current_span, current_trace from .utils import BM25Reranker, MyEmbeddingFunction @@ -84,6 +85,15 @@ def embedding_function(self): self.logger.debug("embedding_function initialized") return self._embedding_function + def _start_span(self, name: str, input_data: dict): + span = current_span.get() + if span: + return span.span(name=name, input=input_data) + trace = current_trace.get() + if trace: + return trace.span(name=name, input=input_data) + return None + def get_info_from_db( self, query: str, collection_name: str, n_results: int = 30, where: dict | None = None, **kwargs ) -> QueryResult: @@ -99,15 +109,30 @@ def get_info_from_db( Returns: relevant documents """ - self.logger.debug(f"get_info_from_db called for {collection_name}") - collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function) - - return collection.query( - query_texts=[query], - n_results=n_results, - include=["documents", "metadatas", "distances"], - where=where, - ) + span = self._start_span("chroma_query", { + "query": query, + "collection": collection_name, + "n_results": n_results, + "where": where, + }) + try: + self.logger.debug(f"get_info_from_db called for {collection_name}") + collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function) + + result = collection.query( + query_texts=[query], + n_results=n_results, + include=["documents", "metadatas", "distances"], + where=where, + ) + if span: + docs = result.get("documents", [[]])[0] + span.end(output={"documents_returned": len(docs)}) + return result + except Exception as e: + if span: + span.end(level="ERROR", status_message=str(e)) + raise def get_filtered_documents(self, 
data_raw: Dict[str, Any]) -> dict: self.logger.debug(f"get_filtered_documents: documents number {len(data_raw['documents'])}") @@ -137,41 +162,55 @@ def apply_reranker(self, query, documents): def get_info(self, query: str, collection_name: str, topics: list[str] | None = None) -> pd.DataFrame: # TO DO: фильтрация по метаданным и потом только query! - self.logger.debug(f"called {query} in get_info for {collection_name} and topics {topics}") - - where = None - if topics: - # один topic можно передать прямо строкой, несколько — через $in - if len(topics) == 1: - where = {"topic": topics[0]} - else: - where = {"topic": {"$in": topics}} - - data_raw = self.get_info_from_db( - query=query, - collection_name=collection_name, - n_results=self.max_rag_documents, - where=where, - ) - filtered_documents = self.get_filtered_documents(data_raw) - - if not filtered_documents["documents"]: - self.logger.debug(f"no documents found in {collection_name}") + span = self._start_span("chroma_rag", { + "query": query, + "collection": collection_name, + "topics": topics, + }) + try: + self.logger.debug(f"called {query} in get_info for {collection_name} and topics {topics}") + + where = None + if topics: + # один topic можно передать прямо строкой, несколько — через $in + if len(topics) == 1: + where = {"topic": topics[0]} + else: + where = {"topic": {"$in": topics}} + + data_raw = self.get_info_from_db( + query=query, + collection_name=collection_name, + n_results=self.max_rag_documents, + where=where, + ) + filtered_documents = self.get_filtered_documents(data_raw) + + if not filtered_documents["documents"]: + self.logger.debug(f"no documents found in {collection_name}") + if span: + span.end(output={"documents_found": 0}) + return pd.DataFrame.from_dict( + data={ + "documents": [], + "metadatas": [], + } + ) + + idx_relevant_documents = self.apply_reranker(query=query, documents=filtered_documents["documents"]) + self.logger.debug(f"Finished get_info for {query} returned 
{len(idx_relevant_documents)} documents") + if span: + span.end(output={"documents_found": len(idx_relevant_documents)}) return pd.DataFrame.from_dict( data={ - "documents": [], - "metadatas": [], + "documents": [filtered_documents["documents"][idx] for idx in idx_relevant_documents], + "metadatas": [filtered_documents["metadatas"][idx] for idx in idx_relevant_documents], } ) - - idx_relevant_documents = self.apply_reranker(query=query, documents=filtered_documents["documents"]) - self.logger.debug(f"Finished get_info for {query} returned {len(idx_relevant_documents)} documents") - return pd.DataFrame.from_dict( - data={ - "documents": [filtered_documents["documents"][idx] for idx in idx_relevant_documents], - "metadatas": [filtered_documents["metadatas"][idx] for idx in idx_relevant_documents], - } - ) + except Exception as e: + if span: + span.end(level="ERROR", status_message=str(e)) + raise def health_check(self) -> bool: """Simple Chroma check""" diff --git a/src/modules/redis_ext/base.py b/src/modules/redis_ext/base.py index 79a9d48..445d931 100644 --- a/src/modules/redis_ext/base.py +++ b/src/modules/redis_ext/base.py @@ -7,6 +7,7 @@ from langchain_redis import RedisSemanticCache from service.logger import LoggerConfigurator +from service.logger.context_vars import current_span, current_trace class RedisAdapter: @@ -36,31 +37,58 @@ def __init__( self.logger.info(f"REDIS_THRESHOLD: {self.redis_threshold}") self.logger.info(f"REDIS_TTL: {self.redis_ttl}") + def _start_span(self, name: str, input_data: dict): + span = current_span.get() + if span: + return span.span(name=name, input=input_data) + trace = current_trace.get() + if trace: + return trace.span(name=name, input=input_data) + return None + def save(self, meta_info: str, query: str = "", output: str = "", json_data: Optional[dict] = None): """ output=str в text, json_data=dict в metadata. 
""" - # self.logger.debug("saving query") - metadata = {"json": json_data} if json_data else {} - metadata["query"] = query - metadata["output"] = output + span = self._start_span("redis_save", {"query": query, "meta_info": meta_info}) + try: + metadata = {"json": json_data} if json_data else {} + metadata["query"] = query + metadata["output"] = output - json_str = json.dumps(metadata) + json_str = json.dumps(metadata) - result = [Generation(text=json_str)] - self.semantic_cache.update(query, meta_info, result) + result = [Generation(text=json_str)] + self.semantic_cache.update(query, meta_info, result) + if span: + span.end(output={"status": "saved"}) + except Exception as e: + if span: + span.end(level="ERROR", status_message=str(e)) + raise def get(self, meta_info: str, query: str = "") -> Optional[Dict[str, Any]]: """Возвращает полный dict из JSON в text.""" - # self.logger.debug("getting query") - result = self.semantic_cache.lookup(query, meta_info) - if result: - try: - return json.loads(result[0].text) - except json.JSONDecodeError as e: - self.logger.error(f"JSON decode error: {e}") - return None - return None + span = self._start_span("redis_get", {"query": query, "meta_info": meta_info}) + try: + result = self.semantic_cache.lookup(query, meta_info) + if result: + parsed = json.loads(result[0].text) + if span: + span.end(output={"hit": True}) + return parsed + if span: + span.end(output={"hit": False}) + return None + except json.JSONDecodeError as e: + self.logger.error(f"JSON decode error: {e}") + if span: + span.end(level="ERROR", status_message=str(e)) + return None + except Exception as e: + if span: + span.end(level="ERROR", status_message=str(e)) + raise def health_check(self) -> bool: """Simple health check""" diff --git a/src/service/api/v1/router.py b/src/service/api/v1/router.py index c0f8eab..4669aa7 100644 --- a/src/service/api/v1/router.py +++ b/src/service/api/v1/router.py @@ -10,6 +10,7 @@ from agents.profkom_consultant import 
AgentStatus, build_builder from service.config import APP_CONFIG from service.context import APP_CTX +from service.logger.context_vars import current_trace from . import schemas from .schemas import AgentChatRequest, AgentChatResponse, FailedDependecyResponse, YandexGPTAPITestResponse @@ -76,10 +77,16 @@ async def chat( agent_graph = build_builder(agent=APP_CTX.get_agent(), checkpointer=checkpointer) langfuse = await APP_CTX.get_langfuse() + trace = langfuse.client.trace( + name="chat", + user_id=headers.get("x-user-id"), + session_id=headers.get("x-trace-id"), + metadata={"stage": APP_CONFIG.app.stage}, + ) + current_trace.set(trace) config = { "configurable": {"thread_id": headers.get("x-user-id")}, - "callbacks": [langfuse.handler], "metadata": { "stage": APP_CONFIG.app.stage, "langfuse_session_id": headers.get("x-trace-id"), diff --git a/src/service/logger/context_vars.py b/src/service/logger/context_vars.py index 0518d22..12d8ffc 100644 --- a/src/service/logger/context_vars.py +++ b/src/service/logger/context_vars.py @@ -1,7 +1,11 @@ from contextvars import ContextVar +from typing import Any from .models import ContextLog +current_trace: ContextVar[Any | None] = ContextVar("current_trace", default=None) +current_span: ContextVar[Any | None] = ContextVar("current_span", default=None) + class ContextVarsContainer: @property @@ -46,4 +50,4 @@ def get_context_vars(self): ) -__all__ = ["ContextVarsContainer"] +__all__ = ["ContextVarsContainer", "current_trace", "current_span"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agents/__init__.py b/tests/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agents/profkom_consultant/__init__.py b/tests/agents/profkom_consultant/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agents/profkom_consultant/nodes/__init__.py b/tests/agents/profkom_consultant/nodes/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/agents/profkom_consultant/nodes/test_base.py b/tests/agents/profkom_consultant/nodes/test_base.py new file mode 100644 index 0000000..6b0777b --- /dev/null +++ b/tests/agents/profkom_consultant/nodes/test_base.py @@ -0,0 +1,52 @@ +from unittest.mock import MagicMock + +import pytest + +from agents.profkom_consultant.nodes.base import BaseAgentNodes + + +@pytest.fixture +def agent(): + instance = BaseAgentNodes.__new__(BaseAgentNodes) + instance.HISTORY_LIMIT = 3 + return instance + + +class TestUpdateUserHistoryContext: + def test_appends_question_and_answer(self, agent): + state = { + "text": "Новый вопрос", + "final_answer": "Новый ответ", + } + + result = agent.update_user_history_context(state) + + assert result["user_history"] == ["Новый вопрос"] + assert result["model_answers"] == ["Новый ответ"] + + def test_trims_to_history_limit(self, agent): + state = { + "user_history": ["вопрос 1", "вопрос 2", "вопрос 3"], + "model_answers": ["ответ 1", "ответ 2", "ответ 3"], + "text": "вопрос 4", + "final_answer": "ответ 4", + } + + result = agent.update_user_history_context(state) + + assert result["user_history"] == ["вопрос 2", "вопрос 3", "вопрос 4"] + assert result["model_answers"] == ["ответ 2", "ответ 3", "ответ 4"] + + def test_maintains_one_to_one_sync_after_trim(self, agent): + state = { + "user_history": ["вопрос 1", "вопрос 2"], + "model_answers": ["ответ 1", "ответ 2"], + "text": "вопрос 3", + "final_answer": "ответ 3", + } + + result = agent.update_user_history_context(state) + + assert len(result["user_history"]) == len(result["model_answers"]) + assert result["user_history"][-1] == "вопрос 3" + assert result["model_answers"][-1] == "ответ 3" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..154f766 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,33 @@ +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport + +from service.api import create_app +from 
service.context import APP_CTX + + +@pytest.fixture +def app(monkeypatch): + async def _noop(*args, **kwargs): + pass + + monkeypatch.setattr(APP_CTX, "on_startup", _noop) + monkeypatch.setattr(APP_CTX, "on_shutdown", _noop) + return create_app() + + +@pytest_asyncio.fixture +async def async_client(app): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + +@pytest.fixture +def mock_headers(): + return { + "x-trace-id": "test-trace-id", + "x-request-time": "2024-01-01T00:00:00+03:00", + "x-source-name": "pytest", + "x-user-id": "test-user-id", + } diff --git a/tests/service/__init__.py b/tests/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/service/api/__init__.py b/tests/service/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/service/api/v1/__init__.py b/tests/service/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/service/api/v1/test_router.py b/tests/service/api/v1/test_router.py new file mode 100644 index 0000000..5f3282b --- /dev/null +++ b/tests/service/api/v1/test_router.py @@ -0,0 +1,117 @@ +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agents.profkom_consultant import AgentStatus +from service.context import APP_CTX + + +class FakeMessage: + def __init__(self, content): + self.content = content + + +@pytest.mark.anyio +async def test_test_invoke_success(async_client, monkeypatch, mock_headers): + router_module = sys.modules["service.api.v1.router"] + monkeypatch.setattr( + router_module, + "ChatOpenAI", + lambda **kwargs: MagicMock(invoke=lambda question: FakeMessage("Mocked answer")), + ) + + payload = { + "question": "Кто ты воин?", + "generation_params": {}, + } + + response = await async_client.post("/api/v1/test_invoke", json=payload, headers=mock_headers) + + assert response.status_code == 200 + data = response.json() + assert 
data["answer"] == "Mocked answer" + + +@pytest.mark.anyio +async def test_test_invoke_failed_dependency(async_client, monkeypatch, mock_headers): + def _raise(*args, **kwargs): + raise RuntimeError("YandexGPT down") + + router_module = sys.modules["service.api.v1.router"] + monkeypatch.setattr( + router_module, + "ChatOpenAI", + lambda **kwargs: MagicMock(invoke=_raise), + ) + + payload = {"question": "Кто ты воин?"} + response = await async_client.post("/api/v1/test_invoke", json=payload, headers=mock_headers) + + assert response.status_code == 424 + data = response.json() + assert "YandexGPT down" in data["error_description"] + + +@pytest.mark.anyio +async def test_chat_success(async_client, monkeypatch, mock_headers): + rate_limiter_mock = MagicMock() + rate_limiter_mock.check_and_increment.return_value = (True, 1) + monkeypatch.setattr(APP_CTX, "get_ratelimiter", AsyncMock(return_value=rate_limiter_mock)) + + checkpointer_mock = AsyncMock() + checkpointer_cm = AsyncMock() + checkpointer_cm.__aenter__ = AsyncMock(return_value=checkpointer_mock) + checkpointer_cm.__aexit__ = AsyncMock(return_value=None) + + postgres_mock = MagicMock() + postgres_mock.get_user_checkpointer.return_value = checkpointer_cm + monkeypatch.setattr(APP_CTX, "get_postgres_client", AsyncMock(return_value=postgres_mock)) + + langfuse_mock = MagicMock() + langfuse_mock.client.trace.return_value = MagicMock() + monkeypatch.setattr(APP_CTX, "get_langfuse", AsyncMock(return_value=langfuse_mock)) + + agent_mock = MagicMock() + monkeypatch.setattr(APP_CTX, "get_agent", lambda: agent_mock) + + graph_mock = AsyncMock() + graph_mock.ainvoke.return_value = {"final_answer": "Это ответ агента"} + + router_module = sys.modules["service.api.v1.router"] + monkeypatch.setattr(router_module, "build_builder", lambda agent, checkpointer: graph_mock) + + payload = { + "text": "Как вступить в профсоюз?", + "organisation": "ППО Невинномысский Азот", + } + + response = await async_client.post("/api/v1/chat", 
json=payload, headers=mock_headers) + + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Это ответ агента" + + graph_mock.ainvoke.assert_awaited_once() + call_kwargs = graph_mock.ainvoke.call_args.kwargs + assert call_kwargs["input"]["status"] == AgentStatus.ACTIVE + + +@pytest.mark.anyio +async def test_chat_rate_limit(async_client, monkeypatch, mock_headers): + rate_limiter_mock = MagicMock() + rate_limiter_mock.check_and_increment.return_value = (False, 10) + rate_limiter_mock.ttl.return_value = 42 + monkeypatch.setattr(APP_CTX, "get_ratelimiter", AsyncMock(return_value=rate_limiter_mock)) + + payload = { + "text": "Как вступить в профсоюз?", + "organisation": "ППО Невинномысский Азот", + } + + response = await async_client.post("/api/v1/chat", json=payload, headers=mock_headers) + + assert response.status_code == 200 + data = response.json() + assert "превысили свой лимит" in data["response"] + assert "42" in data["response"] diff --git a/uv.lock b/uv.lock index f359195..d4a4155 100644 --- a/uv.lock +++ b/uv.lock @@ -876,6 +876,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = 
"sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "ipykernel" version = "7.1.0" @@ -2098,6 +2107,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "ply" version = "3.11" @@ -2436,6 +2454,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216, upload-time = "2024-09-29T09:24:11.978Z" }, ] +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3144,8 +3191,11 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "httpx" }, { name = "pre-commit" }, { name = "pylint" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] @@ -3182,8 +3232,11 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "httpx" }, { name = "pre-commit" }, { name = "pylint" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] From c237777694aa33c8d939f4fac6efaaa4aeccf43d Mon Sep 17 00:00:00 2001 From: Aleksandr Samofalov Date: Wed, 15 Apr 2026 
01:40:50 +0300 Subject: [PATCH 2/5] chore(tests): add pytest-cov and generate coverage.xml - Added pytest-cov to dev dependencies. - Generated coverage.xml from current test suite (7 tests, all passing). --- coverage.xml | 1855 ++++++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + uv.lock | 40 ++ 3 files changed, 1896 insertions(+) create mode 100644 coverage.xml diff --git a/coverage.xml b/coverage.xml new file mode 100644 index 0000000..515f60c --- /dev/null +++ b/coverage.xml @@ -0,0 +1,1855 @@ + + + + + + /Users/aleksandrsamofalov/PycharmProjects/GeneralPurposeChatBot/src + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index a606e62..b906068 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dev = [ "pytest", "pytest-asyncio", "httpx", + "pytest-cov>=7.1.0", ] [tool.ruff] diff --git a/uv.lock b/uv.lock index d4a4155..622a18c 100644 --- a/uv.lock +++ b/uv.lock @@ -423,6 +423,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, ] +[[package]] +name = "coverage" +version = "7.13.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" }, + { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" }, + { url = 
"https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" }, + { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" }, + { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" }, + { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" }, + { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" }, + { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" }, + { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" }, + { url 
= "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, +] + [[package]] name = "cryptography" version = "46.0.4" @@ -2483,6 +2507,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "pytest-cov" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage" }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3196,6 +3234,7 @@ dev = [ { name = "pylint" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-cov" }, { name = "ruff" }, ] @@ -3237,6 
+3276,7 @@ dev = [ { name = "pylint" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-cov", specifier = ">=7.1.0" }, { name = "ruff" }, ] From c14a364abd1ae33b656adea055e2abf548946df1 Mon Sep 17 00:00:00 2001 From: Aleksandr Samofalov Date: Wed, 15 Apr 2026 01:53:45 +0300 Subject: [PATCH 3/5] test(modules): unit-test coverage for all src/modules integrations Added pytest suites covering: - redis_ext: RedisAdapter (save/get/spans/health) + UserRateLimiter (incr/remain/reset/ttl/health) - chroma_ext: ChromaAdapter (init/query/filter/rerank/RAG), MyEmbeddingFunction (retry/batch/call), BM25Reranker (fit/rerank), data_reader (chunk/signature/topic), db_writer (sync/orphan cleanup) - postgres_ext: PostgresClient (pool lifecycle, conn loop, checkpointer, stats, close, health) - langfuse_ext: LangfuseClient (init, client/handler creation, health_check, on_startup) Fixes: - Fixed f-string bug in PostgresClient.get_pool_stats that raised ValueError at runtime. Coverage: - Regenerated coverage.xml from full suite (102 tests passing). 
--- coverage.xml | 988 +++++++++--------- src/modules/postgres_ext/base.py | 4 +- tests/modules/chroma_ext/__init__.py | 0 tests/modules/chroma_ext/scripts/__init__.py | 0 .../chroma_ext/scripts/test_data_reader.py | 117 +++ .../chroma_ext/scripts/test_db_writer.py | 138 +++ tests/modules/chroma_ext/test_base.py | 210 ++++ tests/modules/chroma_ext/utils/__init__.py | 0 .../chroma_ext/utils/test_embedings.py | 134 +++ .../modules/chroma_ext/utils/test_reranker.py | 60 ++ tests/modules/langfuse_ext/__init__.py | 0 tests/modules/langfuse_ext/test_base.py | 84 ++ tests/modules/postgres_ext/__init__.py | 0 tests/modules/postgres_ext/test_base.py | 155 +++ tests/modules/redis_ext/__init__.py | 0 tests/modules/redis_ext/test_base.py | 186 ++++ tests/modules/redis_ext/utils/__init__.py | 0 .../redis_ext/utils/test_RedisAdapters.py | 122 +++ 18 files changed, 1702 insertions(+), 496 deletions(-) create mode 100644 tests/modules/chroma_ext/__init__.py create mode 100644 tests/modules/chroma_ext/scripts/__init__.py create mode 100644 tests/modules/chroma_ext/scripts/test_data_reader.py create mode 100644 tests/modules/chroma_ext/scripts/test_db_writer.py create mode 100644 tests/modules/chroma_ext/test_base.py create mode 100644 tests/modules/chroma_ext/utils/__init__.py create mode 100644 tests/modules/chroma_ext/utils/test_embedings.py create mode 100644 tests/modules/chroma_ext/utils/test_reranker.py create mode 100644 tests/modules/langfuse_ext/__init__.py create mode 100644 tests/modules/langfuse_ext/test_base.py create mode 100644 tests/modules/postgres_ext/__init__.py create mode 100644 tests/modules/postgres_ext/test_base.py create mode 100644 tests/modules/redis_ext/__init__.py create mode 100644 tests/modules/redis_ext/test_base.py create mode 100644 tests/modules/redis_ext/utils/__init__.py create mode 100644 tests/modules/redis_ext/utils/test_RedisAdapters.py diff --git a/coverage.xml b/coverage.xml index 515f60c..6c4e691 100644 --- a/coverage.xml +++ 
b/coverage.xml @@ -1,5 +1,5 @@ - + @@ -379,7 +379,7 @@ - + @@ -388,7 +388,7 @@ - + @@ -400,265 +400,265 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - + + + - - - - - + + + + + - - - - - - - + + + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + + - - - - - + + + + + - - + + - - - - + + + + - - - - - - + + + + + + - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - + - + - + - - - + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + @@ -668,7 +668,7 @@ - + @@ -682,72 +682,72 @@ - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - - - - - - - - - - - + + + + + + + + + + + - + @@ -757,40 +757,40 @@ - - - - - + + + + + - - - - - - - - - - + + + + + + + + + + - - - + + + - - - - - - - - + + + + + + + + - + @@ -799,7 +799,7 @@ - + @@ -808,42 +808,42 @@ - - - - - - - - - - - + + + + + + + + + + + - - + + - + - - - - - - - + + + + + + + - - - - + + + + - + @@ -852,7 +852,7 @@ - + @@ -865,73 +865,73 @@ - - - - - - + + + + + + - - - + + + - - - - - - + + + + + + - - - - - - - - + + + + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + + - - - - + + + + - - - - - + + + + + - - - - - - + + + + + + - - - - - + + + + + - + @@ -944,7 +944,7 @@ - + @@ -954,7 +954,7 @@ - + @@ -967,117 +967,117 @@ - - - - - - - - - + + + + + + + + + - - - - - - - + + + + + + + - - - - - - - - 
- - - - - - + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + - + - + - + - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + - - - - - + + + + + - - + + - - + + - - - - - + + + + + diff --git a/src/modules/postgres_ext/base.py b/src/modules/postgres_ext/base.py index 89abba6..a97c799 100644 --- a/src/modules/postgres_ext/base.py +++ b/src/modules/postgres_ext/base.py @@ -116,8 +116,8 @@ async def get_pool_stats(self) -> dict[str, Any] | None: return None stats = self._pool.get_stats() self.logger.info( - f"Postgres.get_pool_stats: {id(self): pool_size={stats.get('pool_size', 0)}}" - f"pool_available={stats.get('pool_available', 0)}" + f"Postgres.get_pool_stats: id={id(self)} pool_size={stats.get('pool_size', 0)} " + f"pool_available={stats.get('pool_available', 0)} " f"request_waiting={stats.get('request_waiting', 0)}" ) return stats diff --git a/tests/modules/chroma_ext/__init__.py b/tests/modules/chroma_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/chroma_ext/scripts/__init__.py b/tests/modules/chroma_ext/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/chroma_ext/scripts/test_data_reader.py b/tests/modules/chroma_ext/scripts/test_data_reader.py new file mode 100644 index 0000000..01b6dc6 --- /dev/null +++ b/tests/modules/chroma_ext/scripts/test_data_reader.py @@ -0,0 +1,117 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from modules.chroma_ext.scripts.data_reader import ( + DocumentChunk, + _build_topic_prefix, + _calc_signature, + _read_docx, + _split_into_chunks, + load_docx_with_metadata, +) + + +class TestReadDocx: + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_returns_stripped_text(self, mock_process): + mock_process.return_value = " hello world \n\n" + result = _read_docx(Path("/fake/doc.docx")) + 
assert result == "hello world" + + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_returns_empty_when_none(self, mock_process): + mock_process.return_value = None + assert _read_docx(Path("/fake/doc.docx")) == "" + + +class TestSplitIntoChunks: + def test_empty_text(self): + assert _split_into_chunks("") == [] + + def test_exact_size_no_overlap(self): + text = "a" * 10 + chunks = _split_into_chunks(text, chunk_size=5, chunk_overlap=0) + assert chunks == ["a" * 5, "a" * 5] + + def test_overlap(self): + text = "a" * 10 + chunks = _split_into_chunks(text, chunk_size=6, chunk_overlap=2) + assert chunks == ["a" * 6, "a" * 6] + + def test_single_chunk_when_text_shorter(self): + text = "short" + chunks = _split_into_chunks(text, chunk_size=100, chunk_overlap=10) + assert chunks == ["short"] + + +class TestCalcSignature: + def test_deterministic(self): + assert _calc_signature("hello") == _calc_signature("hello") + assert _calc_signature("hello") != _calc_signature("world") + + +class TestBuildTopicPrefix: + def test_empty(self): + assert _build_topic_prefix("") == "" + + def test_truncates_to_max_tokens(self): + text = "one two three four five" + assert _build_topic_prefix(text, max_tokens=3) == "one two three" + + def test_full_when_short(self): + text = "one two" + assert _build_topic_prefix(text, max_tokens=10) == "one two" + + +class TestLoadDocxWithMetadata: + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_skips_empty_files(self, mock_process, tmp_path): + mock_process.return_value = "" + (tmp_path / "empty.docx").write_text("fake") + logger = MagicMock() + result = load_docx_with_metadata(logger, tmp_path) + assert result == [] + + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_loads_single_file_with_topic_prefix(self, mock_process, tmp_path): + text = "Title of document. 
" + "body " * 200 + mock_process.return_value = text + (tmp_path / "contract.docx").write_text("fake") + logger = MagicMock() + result = load_docx_with_metadata(logger, tmp_path, chunk_size=50, chunk_overlap=10, topic_tokens=5) + + assert len(result) >= 1 + chunk = result[0] + assert isinstance(chunk, DocumentChunk) + assert chunk.id == "contract.docx::chunk:0" + assert chunk.metadata["filename"] == "contract.docx" + assert chunk.metadata["topic"] == "general" + assert "Title of" in chunk.text + assert chunk.metadata["file_signature"] == _calc_signature(text.strip()) + + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_nested_directory_topic(self, mock_process, tmp_path): + text = "some content here" + mock_process.return_value = text + nested = tmp_path / "hr" + nested.mkdir() + (nested / "rules.docx").write_text("fake") + logger = MagicMock() + result = load_docx_with_metadata(logger, tmp_path) + assert len(result) == 1 + assert result[0].metadata["topic"] == "hr" + assert result[0].id == "hr/rules.docx::chunk:0" + + @patch("modules.chroma_ext.scripts.data_reader.docx2txt.process") + def test_multiple_chunks(self, mock_process, tmp_path): + text = "a" * 100 + mock_process.return_value = text + (tmp_path / "long.docx").write_text("fake") + logger = MagicMock() + result = load_docx_with_metadata(logger, tmp_path, chunk_size=30, chunk_overlap=5) + assert len(result) > 1 + for idx, chunk in enumerate(result): + assert chunk.metadata["chunk_index"] == idx + assert chunk.metadata["num_chunks"] == len(result) diff --git a/tests/modules/chroma_ext/scripts/test_db_writer.py b/tests/modules/chroma_ext/scripts/test_db_writer.py new file mode 100644 index 0000000..b3e400e --- /dev/null +++ b/tests/modules/chroma_ext/scripts/test_db_writer.py @@ -0,0 +1,138 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from modules.chroma_ext.scripts.data_reader import DocumentChunk +from modules.chroma_ext.scripts.db_writer import ( + 
_collect_current_sources, + _group_by_source, + sync_docx_directory_to_collection, +) + + +class TestGroupBySource: + def test_groups_by_source(self): + chunks = [ + DocumentChunk(id="1", text="a", metadata={"source": "/a.docx"}), + DocumentChunk(id="2", text="b", metadata={"source": "/a.docx"}), + DocumentChunk(id="3", text="c", metadata={"source": "/b.docx"}), + ] + grouped = _group_by_source(chunks) + assert len(grouped["/a.docx"]) == 2 + assert len(grouped["/b.docx"]) == 1 + + +class TestCollectCurrentSources: + def test_collects_docx_paths(self, tmp_path): + (tmp_path / "a.docx").write_text("x") + nested = tmp_path / "sub" + nested.mkdir() + (nested / "b.docx").write_text("y") + result = _collect_current_sources(str(tmp_path)) + assert result == {str(tmp_path / "a.docx"), str(tmp_path / "sub" / "b.docx")} + + +class TestSyncDocxDirectoryToCollection: + @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient") + @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction") + @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata") + def test_no_chunks_early_return(self, mock_load, mock_embed, mock_client): + mock_load.return_value = [] + logger = MagicMock() + sync_docx_directory_to_collection( + logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000 + ) + logger.warning.assert_called_once_with("No .docx files found, nothing to index") + # HttpClient is created before the empty check in current implementation + mock_client.assert_called_once() + + @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient") + @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction") + @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata") + def test_unchanged_file_skipped(self, mock_load, mock_embed, mock_client): + collection = MagicMock() + collection.get.return_value = { + "ids": ["old"], + "metadatas": [{"file_signature": "sig1"}], + } + 
mock_client.return_value.get_or_create_collection.return_value = collection + + chunks = [ + DocumentChunk( + id="f::chunk:0", + text="txt", + metadata={"source": "/f.docx", "file_signature": "sig1"}, + ) + ] + mock_load.return_value = chunks + logger = MagicMock() + + sync_docx_directory_to_collection( + logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000 + ) + collection.add.assert_not_called() + collection.delete.assert_not_called() + + @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient") + @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction") + @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata") + def test_changed_file_deletes_and_adds(self, mock_load, mock_embed, mock_client): + collection = MagicMock() + collection.get.side_effect = [ + { + "ids": ["old"], + "metadatas": [{"file_signature": "old_sig"}], + }, + {"ids": [], "metadatas": []}, + ] + mock_client.return_value.get_or_create_collection.return_value = collection + + chunks = [ + DocumentChunk( + id="f::chunk:0", + text="new text", + metadata={"source": "/f.docx", "file_signature": "new_sig"}, + ) + ] + mock_load.return_value = chunks + logger = MagicMock() + + sync_docx_directory_to_collection( + logger, "/docs", "test_collection", api_key="k", folder_id="f", host="h", port=8000 + ) + collection.delete.assert_any_call(where={"source": "/f.docx"}) + collection.add.assert_called_once_with( + ids=["f::chunk:0"], + documents=["new text"], + metadatas=[{"source": "/f.docx", "file_signature": "new_sig"}], + ) + + @patch("modules.chroma_ext.scripts.db_writer.chromadb.HttpClient") + @patch("modules.chroma_ext.scripts.db_writer.MyEmbeddingFunction") + @patch("modules.chroma_ext.scripts.db_writer.load_docx_with_metadata") + def test_removes_orphaned_sources(self, mock_load, mock_embed, mock_client, tmp_path): + collection = MagicMock() + collection.get.side_effect = [ + {"ids": [], "metadatas": []}, + { + "ids": ["old"], + 
"metadatas": [{"source": str(tmp_path / "gone.docx")}], + }, + ] + mock_client.return_value.get_or_create_collection.return_value = collection + + (tmp_path / "keep.docx").write_text("x") + chunks = [ + DocumentChunk( + id="keep::chunk:0", + text="keep", + metadata={"source": str(tmp_path / "keep.docx"), "file_signature": "sig"}, + ) + ] + mock_load.return_value = chunks + logger = MagicMock() + + sync_docx_directory_to_collection( + logger, str(tmp_path), "test_collection", api_key="k", folder_id="f", host="h", port=8000 + ) + collection.delete.assert_any_call(where={"source": str(tmp_path / "gone.docx")}) diff --git a/tests/modules/chroma_ext/test_base.py b/tests/modules/chroma_ext/test_base.py new file mode 100644 index 0000000..261d92d --- /dev/null +++ b/tests/modules/chroma_ext/test_base.py @@ -0,0 +1,210 @@ +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from modules.chroma_ext.base import ChromaAdapter + + +@pytest.fixture +def adapter(): + logger = MagicMock() + with patch("modules.chroma_ext.base.chromadb.HttpClient") as MockClient, \ + patch("modules.chroma_ext.base.BM25Reranker") as MockReranker: + mock_client = MagicMock() + MockClient.return_value = mock_client + mock_reranker = MagicMock() + MockReranker.return_value = mock_reranker + inst = ChromaAdapter( + logger=logger, + similarity_filter=1.0, + reranker_type="bm25", + text_type="query", + API_KEY="key123456789", + FOLDER_ID="fld123456789", + CHROMA_HOST="testhost", + CHROMA_PORT=9000, + CHROMA_TOPK_DOCUMENTS=3, + CHROMA_MAX_RAG_DOCUMENTS=10, + ) + inst._mock_client = mock_client + inst._mock_reranker = mock_reranker + return inst + + +class TestChromaAdapterInit: + def test_validation_errors(self): + logger = MagicMock() + with patch("modules.chroma_ext.base.chromadb.HttpClient"): + # FOLDER_ID=None and API_KEY=None currently raise TypeError due to slice + # before validation (bug in code) + with pytest.raises(TypeError): + ChromaAdapter(logger=logger, 
API_KEY="key", FOLDER_ID=None) + with pytest.raises(TypeError): + ChromaAdapter(logger=logger, API_KEY=None, FOLDER_ID="fld123456789") + with pytest.raises(ValueError, match="TOPK"): + ChromaAdapter(logger=logger, API_KEY="key", FOLDER_ID="fld123456789", CHROMA_TOPK_DOCUMENTS=20, CHROMA_MAX_RAG_DOCUMENTS=20) + + def test_unsupported_reranker_does_not_raise(self): + # Current code instantiates NotImplementedError but does not raise it (bug). + logger = MagicMock() + with patch("modules.chroma_ext.base.chromadb.HttpClient"): + adapter = ChromaAdapter( + logger=logger, + API_KEY="key123456789", + FOLDER_ID="fld123456789", + reranker_type="unknown", + ) + assert getattr(adapter, "reranker", None) is None + + def test_params_set(self, adapter): + assert adapter.host == "testhost" + assert adapter.port == 9000 + assert adapter.topk_documents == 3 + assert adapter.max_rag_documents == 10 + assert adapter.similarity_filter == 1.0 + + +class TestChromaAdapterEmbeddingFunction: + @patch("modules.chroma_ext.base.MyEmbeddingFunction") + def test_lazy_initialization(self, MockEmb, adapter): + mock_ef = MagicMock() + MockEmb.return_value = mock_ef + ef = adapter.embedding_function + assert ef is mock_ef + MockEmb.assert_called_once() + # second call returns cached + assert adapter.embedding_function is mock_ef + + +class TestChromaAdapterStartSpan: + def test_prefers_current_span(self, adapter): + parent = MagicMock() + child = MagicMock() + parent.span.return_value = child + with patch("modules.chroma_ext.base.current_span") as mock_cs, \ + patch("modules.chroma_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = parent + mock_ct.get.return_value = MagicMock() + result = adapter._start_span("chroma_test", {"a": 1}) + assert result is child + + def test_fallback_to_trace(self, adapter): + trace = MagicMock() + child = MagicMock() + trace.span.return_value = child + with patch("modules.chroma_ext.base.current_span") as mock_cs, \ + 
patch("modules.chroma_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = None + mock_ct.get.return_value = trace + result = adapter._start_span("chroma_test", {"a": 1}) + assert result is child + + def test_none_when_no_context(self, adapter): + with patch("modules.chroma_ext.base.current_span") as mock_cs, \ + patch("modules.chroma_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = None + mock_ct.get.return_value = None + assert adapter._start_span("chroma_test", {"a": 1}) is None + + +class TestChromaAdapterGetInfoFromDb: + def test_success(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["doc1", "doc2"]], + "metadatas": [[{"m": 1}, {"m": 2}]], + "distances": [[0.1, 0.2]], + } + adapter._mock_client.get_collection.return_value = mock_collection + + result = adapter.get_info_from_db("q", "coll", n_results=5, where={"topic": "t"}) + assert result["documents"][0] == ["doc1", "doc2"] + span.end.assert_called_once_with(output={"documents_returned": 2}) + + def test_error_ends_span(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter._mock_client.get_collection.side_effect = RuntimeError("chroma down") + with pytest.raises(RuntimeError, match="chroma down"): + adapter.get_info_from_db("q", "coll") + span.end.assert_called_once_with(level="ERROR", status_message="chroma down") + + +class TestChromaAdapterGetFilteredDocuments: + def test_filters_by_distance_and_strips_body(self, adapter): + data_raw = { + "documents": [["keep1", "keep2", "drop"]], + "metadatas": [[{"a": 1}, {"a": 2}, {"a": 3}]], + "distances": [[0.5, 0.9, 1.5]], + } + result = adapter.get_filtered_documents(data_raw) + assert result["documents"] == ["keep1", "keep2"] + assert result["metadatas"] == [{"a": 1}, {"a": 2}] + + +class TestChromaAdapterGetPairs: + def test_builds_pairs(self, adapter): + 
result = adapter.get_pairs("query", ["d1", "d2"]) + assert result == [["query", "d1"], ["query", "d2"]] + + +class TestChromaAdapterApplyReranker: + def test_delegates_to_bm25(self, adapter): + adapter._mock_reranker.rerank.return_value = [1, 0] + idx = adapter.apply_reranker("q", ["d1", "d2", "d3"]) + adapter._mock_reranker.fit.assert_called_once_with(["d1", "d2", "d3"]) + adapter._mock_reranker.rerank.assert_called_once_with(query="q", top_k=3) + assert idx == [1, 0] + + +class TestChromaAdapterGetInfo: + def test_full_flow_returns_dataframe(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["d1", "d2"]], + "metadatas": [[{"t": "a"}, {"t": "b"}]], + "distances": [[0.1, 0.2]], + } + adapter._mock_client.get_collection.return_value = mock_collection + adapter._mock_reranker.rerank.return_value = [0] + + df = adapter.get_info("query", "coll", topics=["a", "b"]) + assert isinstance(df, pd.DataFrame) + assert df["documents"].tolist() == ["d1"] + # get_info creates its own span and get_info_from_db creates another; + # last end call belongs to the outer chroma_rag span + span.end.assert_called_with(output={"documents_found": 1}) + + def test_empty_filtered_documents(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["drop"]], + "metadatas": [[{"t": "a"}]], + "distances": [[2.0]], + } + adapter._mock_client.get_collection.return_value = mock_collection + + df = adapter.get_info("query", "coll") + assert isinstance(df, pd.DataFrame) + assert df.empty + span.end.assert_called_with(output={"documents_found": 0}) + + def test_exception_ends_span(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter._mock_client.get_collection.side_effect = ValueError("fail") + with 
pytest.raises(ValueError, match="fail"): + adapter.get_info("query", "coll") + span.end.assert_called_with(level="ERROR", status_message="fail") + + +class TestChromaAdapterHealthCheck: + def test_always_true(self, adapter): + assert adapter.health_check() is True diff --git a/tests/modules/chroma_ext/utils/__init__.py b/tests/modules/chroma_ext/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/chroma_ext/utils/test_embedings.py b/tests/modules/chroma_ext/utils/test_embedings.py new file mode 100644 index 0000000..c89562d --- /dev/null +++ b/tests/modules/chroma_ext/utils/test_embedings.py @@ -0,0 +1,134 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from modules.chroma_ext.utils.embedings import MyEmbeddingFunction + + +@pytest.fixture +def embedder(): + logger = MagicMock() + return MyEmbeddingFunction( + logger=logger, + folder_id="b1g2d3f4", + iam_token="t0k3n-12345678", + doc_model_uri="doc-uri", + query_model_uri="query-uri", + text_type="doc", + time_sleep=0, + max_retries=2, + request_timeout=5, + batch_size=2, + sleep_between_batches=0, + ) + + +class TestMyEmbeddingFunctionInit: + def test_defaults_and_kwargs(self, embedder): + assert embedder.api_url == "https://llm.api.cloud.yandex.net:443/foundationModels/v1/textEmbedding" + assert embedder.folder_id == "b1g2d3f4" + assert embedder.iam_token == "t0k3n-12345678" + assert embedder.text_type == "doc" + assert embedder.doc_model_uri == "doc-uri" + assert embedder.query_model_uri == "query-uri" + assert embedder.max_retries == 2 + assert embedder.batch_size == 2 + + +class TestMyEmbeddingFunctionGetSingleEmbedding: + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_success_returns_ndarray(self, mock_post, mock_sleep, embedder): + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.json.return_value = {"embedding": [0.1, 0.2, 
0.3]} + mock_post.return_value = mock_resp + + result = embedder._get_single_embedding("hello") + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, np.array([0.1, 0.2, 0.3])) + + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_transient_retries_then_success(self, mock_post, mock_sleep, embedder): + bad_resp = MagicMock() + bad_resp.ok = False + bad_resp.status_code = 503 + bad_resp.text = "busy" + + good_resp = MagicMock() + good_resp.ok = True + good_resp.json.return_value = {"embedding": [1.0]} + + mock_post.side_effect = [bad_resp, good_resp] + + result = embedder._get_single_embedding("hello") + np.testing.assert_array_equal(result, np.array([1.0])) + assert mock_post.call_count == 2 + embedder.logger.warning.assert_called() + + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_non_transient_4xx_raises(self, mock_post, mock_sleep, embedder): + bad_resp = MagicMock() + bad_resp.ok = False + bad_resp.status_code = 400 + bad_resp.text = "bad request" + bad_resp.raise_for_status.side_effect = Exception("HTTP 400") + mock_post.return_value = bad_resp + + with pytest.raises(Exception, match="HTTP 400"): + embedder._get_single_embedding("hello") + + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_timeout_retries_then_raises(self, mock_post, mock_sleep, embedder): + from requests.exceptions import ConnectTimeout + + mock_post.side_effect = ConnectTimeout("timeout") + + with pytest.raises(ConnectTimeout): + embedder._get_single_embedding("hello") + assert mock_post.call_count == 2 + + +class TestMyEmbeddingFunctionBatched: + def test_batched_exact_and_remainder(self, embedder): + result = list(embedder._batched(["a", "b", "c", "d", "e"], 2)) + assert 
result == [["a", "b"], ["c", "d"], ["e"]] + + def test_batched_empty(self, embedder): + result = list(embedder._batched([], 3)) + assert result == [] + + +class TestMyEmbeddingFunctionCall: + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_call_single_string(self, mock_post, mock_sleep, embedder): + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.json.return_value = {"embedding": [0.5]} + mock_post.return_value = mock_resp + + result = embedder("hello") + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], np.ndarray) + + @patch("modules.chroma_ext.utils.embedings.time.sleep", return_value=None) + @patch("modules.chroma_ext.utils.embedings.requests.post") + def test_call_batches_with_sleep(self, mock_post, mock_sleep, embedder): + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.json.return_value = {"embedding": [0.1]} + mock_post.return_value = mock_resp + + embedder.batch_size = 2 + result = embedder(["a", "b", "c"]) + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(r, np.ndarray) for r in result) + # sleep вызывается: base sleep для каждого запроса + sleep_between_batches + assert mock_sleep.call_count >= 1 diff --git a/tests/modules/chroma_ext/utils/test_reranker.py b/tests/modules/chroma_ext/utils/test_reranker.py new file mode 100644 index 0000000..02b1bc1 --- /dev/null +++ b/tests/modules/chroma_ext/utils/test_reranker.py @@ -0,0 +1,60 @@ +from unittest.mock import MagicMock + +import pytest + +from modules.chroma_ext.utils.reranker import BM25Reranker + + +@pytest.fixture +def reranker(): + logger = MagicMock() + return BM25Reranker(logger=logger, tokenizer_name="gpt-3.5-turbo") + + +class TestBM25RerankerInit: + def test_initializes_tokenizer(self, reranker): + assert reranker.tokenizer is not None + reranker.logger.info.assert_called() + + +class 
TestBM25RerankerPreprocess: + def test_preprocess_lowercases_and_tokenizes(self, reranker): + tokens = reranker.preprocess("Hello World") + assert isinstance(tokens, list) + assert len(tokens) > 0 + # tiktoken tokens converted back to strings; "hello" and " world" or similar + text = " ".join(tokens) + assert "hello" in text.lower() + + def test_preprocess_filters_empty(self, reranker): + tokens = reranker.preprocess("") + assert tokens == [] + + +class TestBM25RerankerFit: + def test_fit_builds_bm25(self, reranker): + reranker.fit(["first document", "second document"]) + assert reranker.bm25 is not None + + +class TestBM25RerankerRerank: + def test_rerank_before_fit_raises(self, reranker): + with pytest.raises(ValueError, match="not fitted"): + reranker.rerank("query", top_k=2) + + def test_rerank_returns_top_k_indices(self, reranker): + docs = [ + "python programming language", + "cooking recipes for dinner", + "python snakes in the wild", + ] + reranker.fit(docs) + indices = reranker.rerank("python code", top_k=2) + assert len(indices) == 2 + assert all(isinstance(i, int) for i in indices) + + def test_rerank_top_k_larger_than_docs(self, reranker): + docs = ["only one document here"] + reranker.fit(docs) + indices = reranker.rerank("query", top_k=5) + assert len(indices) == 1 diff --git a/tests/modules/langfuse_ext/__init__.py b/tests/modules/langfuse_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/langfuse_ext/test_base.py b/tests/modules/langfuse_ext/test_base.py new file mode 100644 index 0000000..a367133 --- /dev/null +++ b/tests/modules/langfuse_ext/test_base.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from modules.langfuse_ext.base import LangfuseClient + + +@pytest.fixture +def client(): + config = MagicMock() + config.host = "https://langfuse.test" + config.secret_key = "secret_key_123" + config.public_key = "public_key_456" + config.stage = "test" + logger = MagicMock() + 
return LangfuseClient(app_config=config, logger=logger) + + +class TestLangfuseClientInit: + def test_logs_masked_keys(self, client): + client.logger.debug.assert_any_call("Secret Key: secr**_123") + client.logger.debug.assert_any_call("Public Key: publ**_456") + assert client.client is not None + assert client.handler is not None + + +class TestLangfuseClientCreateClient: + @patch("modules.langfuse_ext.base.Langfuse") + def test_creates_langfuse_instance(self, MockLangfuse): + config = MagicMock() + config.host = "h" + config.secret_key = "s" + config.public_key = "p" + logger = MagicMock() + client = LangfuseClient(app_config=config, logger=logger) + # access property to trigger creation + _ = client._LangfuseClient__create_client + MockLangfuse.assert_called_with(secret_key="s", public_key="p", host="h") + + +class TestLangfuseClientCreateCallbackHandler: + @patch("modules.langfuse_ext.base.CallbackHandler") + def test_creates_handler(self, MockHandler): + config = MagicMock() + config.host = "h" + config.secret_key = "s" + config.public_key = "p" + config.stage = "stage" + logger = MagicMock() + client = LangfuseClient(app_config=config, logger=logger) + _ = client._LangfuseClient__create_callback_handler + MockHandler.assert_called_with( + public_key="p", secret_key="s", host="h", trace_name="stage" + ) + + +class TestLangfuseClientHealthCheck: + def test_true_when_auth_ok(self, client): + client.client = MagicMock() + client.client.auth_check.return_value = True + assert client.health_check() is True + + def test_false_when_auth_fails(self, client): + client.client = MagicMock() + client.client.auth_check.return_value = False + assert client.health_check() is False + + +class TestLangfuseClientOnStartup: + @pytest.mark.asyncio + async def test_reassigns_and_checks(self, client): + client.client = MagicMock() + client.handler = MagicMock() + with patch.object(client, "health_check", return_value=True) as mock_hc: + await client.on_startup() + 
mock_hc.assert_called_once() + + @pytest.mark.asyncio + async def test_catches_exceptions(self, client): + with patch.object( + type(client), "_LangfuseClient__create_client", property(lambda self: (_ for _ in ()).throw(RuntimeError("boom"))) + ): + # Should not raise + await client.on_startup() diff --git a/tests/modules/postgres_ext/__init__.py b/tests/modules/postgres_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/postgres_ext/test_base.py b/tests/modules/postgres_ext/test_base.py new file mode 100644 index 0000000..40473f0 --- /dev/null +++ b/tests/modules/postgres_ext/test_base.py @@ -0,0 +1,155 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from modules.postgres_ext.base import PostgresClient + + +@pytest.fixture +def client(): + config = MagicMock() + config.encoded_pass = "secret" + config.user = "u" + config.host = "h" + config.port = 5432 + config.postgres_db = "db" + config.pool_min_size = 1 + config.pool_max_size = 5 + config.pool_max_idle = 30.0 + config.conninfo = "postgresql://u:secret@h:5432/db" + logger = MagicMock() + return PostgresClient(config=config, logger=logger) + + +class TestPostgresClientInit: + def test_pool_starts_none(self, client): + assert client._pool is None + assert isinstance(client._lock, asyncio.Lock) + client.logger.info.assert_called() + + +@pytest.mark.asyncio +class TestPostgresClientEnsurePool: + async def test_creates_pool_once(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + MockPool.return_value = mock_pool + + await client.ensure_pool() + await client.ensure_pool() + + MockPool.assert_called_once_with( + conninfo=client.settings.conninfo, + min_size=client.settings.pool_min_size, + max_size=client.settings.pool_max_size, + max_idle=client.settings.pool_max_idle, + ) + assert mock_pool.open.await_count 
== 1 + + async def test_race_condition_safe(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + MockPool.return_value = mock_pool + + async def task(): + await client.ensure_pool() + + await asyncio.gather(task(), task(), task()) + MockPool.assert_called_once() + + +@pytest.mark.asyncio +class TestPostgresClientTakeConnInLoop: + async def test_happy_path(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + mock_conn = AsyncMock() + mock_pool.getconn.return_value = mock_conn + MockPool.return_value = mock_pool + await client.ensure_pool() + + conn = await client._take_conn_in_loop(0, 3) + assert conn is mock_conn + mock_conn.execute.assert_any_call("SELECT 1") + mock_conn.execute.assert_any_call("ROLLBACK") + + async def test_retries_then_none(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + mock_conn = AsyncMock() + mock_conn.execute.side_effect = RuntimeError("dead") + mock_pool.getconn.return_value = mock_conn + MockPool.return_value = mock_pool + await client.ensure_pool() + + conn = await client._take_conn_in_loop(0, 2) + assert conn is None + assert mock_conn.close.await_count == 2 + + +@pytest.mark.asyncio +class TestPostgresClientGetUserCheckpointer: + async def test_yields_saver_and_returns_conn(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool, \ + patch("modules.postgres_ext.base.AsyncPostgresSaver") as MockSaver: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + mock_conn = AsyncMock() + mock_pool.getconn.return_value = mock_conn + MockPool.return_value = mock_pool + mock_saver = 
MagicMock() + MockSaver.return_value = mock_saver + + await client.ensure_pool() + async with client.get_user_checkpointer() as saver: + assert saver is mock_saver + + mock_conn.set_autocommit.assert_awaited_once_with(True) + mock_pool.putconn.assert_awaited_once_with(mock_conn) + + +@pytest.mark.asyncio +class TestPostgresClientGetPoolStats: + async def test_none_when_no_pool(self, client): + assert await client.get_pool_stats() is None + + async def test_returns_stats(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 2}) + MockPool.return_value = mock_pool + await client.ensure_pool() + stats = await client.get_pool_stats() + assert stats == {"pool_size": 2} + + +@pytest.mark.asyncio +class TestPostgresClientClose: + async def test_closes_and_nulls_pool(self, client): + with patch("modules.postgres_ext.base.AsyncConnectionPool") as MockPool: + mock_pool = AsyncMock() + mock_pool.get_stats = MagicMock(return_value={"pool_size": 1}) + MockPool.return_value = mock_pool + await client.ensure_pool() + await client.close() + mock_pool.close.assert_awaited_once() + assert client._pool is None + + async def test_idempotent(self, client): + await client.close() + assert client._pool is None + + +class TestPostgresClientHealthCheck: + def test_false_before_init(self, client): + assert client.health_check() is False + + def test_true_after_init(self, client): + # We can't easily async ensure_pool in sync test, so set pool manually + client._pool = MagicMock() + assert client.health_check() is True diff --git a/tests/modules/redis_ext/__init__.py b/tests/modules/redis_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/redis_ext/test_base.py b/tests/modules/redis_ext/test_base.py new file mode 100644 index 0000000..d88ed01 --- /dev/null +++ b/tests/modules/redis_ext/test_base.py @@ -0,0 +1,186 @@ +import json +from 
unittest.mock import ANY, MagicMock, patch + +import pytest + +from modules.redis_ext.base import RedisAdapter + + +@pytest.fixture +def mock_embeddings(): + return MagicMock() + + +@pytest.fixture +def mock_logger(): + return MagicMock() + + +@pytest.fixture +def adapter(mock_logger, mock_embeddings): + with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache: + mock_cache = MagicMock() + MockCache.return_value = mock_cache + inst = RedisAdapter( + logger=mock_logger, + embeddings=mock_embeddings, + redis_url="redis://test:6379", + redis_threshold=0.1, + redis_ttl=60, + ) + inst._mock_cache = mock_cache + return inst + + +class TestRedisAdapterInit: + def test_uses_explicit_args(self, mock_logger, mock_embeddings): + with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache: + RedisAdapter( + logger=mock_logger, + embeddings=mock_embeddings, + redis_url="redis://explicit:6379", + redis_threshold=0.2, + redis_ttl=120, + ) + MockCache.assert_called_once_with( + redis_url="redis://explicit:6379", + embeddings=mock_embeddings, + distance_threshold=0.2, + ttl=120, + ) + + def test_fallback_to_env_vars(self, monkeypatch, mock_logger, mock_embeddings): + monkeypatch.setenv("REDIS_URL", "redis://env:6379") + monkeypatch.setenv("REDIS_THRESHOLD", "0.3") + monkeypatch.setenv("REDIS_TTL", "240") + with patch("modules.redis_ext.base.RedisSemanticCache") as MockCache: + RedisAdapter( + logger=mock_logger, + embeddings=mock_embeddings, + redis_url=None, + redis_threshold=None, + redis_ttl=None, + ) + MockCache.assert_called_once_with( + redis_url="redis://env:6379", + embeddings=mock_embeddings, + distance_threshold=0.3, + ttl=240, + ) + + +class TestRedisAdapterSave: + def test_save_calls_update_and_ends_span(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter.save(meta_info="meta", query="q", output="out", json_data={"k": "v"}) + + adapter._mock_cache.update.assert_called_once() + args = 
adapter._mock_cache.update.call_args[0] + assert args[0] == "q" + assert args[1] == "meta" + generation = args[2][0] + assert json.loads(generation.text)["output"] == "out" + span.end.assert_called_once_with(output={"status": "saved"}) + + def test_save_error_ends_span_with_error(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter._mock_cache.update.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + adapter.save(meta_info="meta", query="q") + + span.end.assert_called_once_with(level="ERROR", status_message="boom") + + +class TestRedisAdapterGet: + def test_get_hit_parses_json(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + payload = {"output": "hello", "json": {"x": 1}} + gen = MagicMock() + gen.text = json.dumps(payload) + adapter._mock_cache.lookup.return_value = [gen] + + result = adapter.get(meta_info="meta", query="q") + assert result == payload + span.end.assert_called_once_with(output={"hit": True}) + + def test_get_miss_returns_none(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter._mock_cache.lookup.return_value = None + + result = adapter.get(meta_info="meta", query="q") + assert result is None + span.end.assert_called_once_with(output={"hit": False}) + + def test_get_json_decode_error_logs_and_returns_none(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + gen = MagicMock() + gen.text = "not-json" + adapter._mock_cache.lookup.return_value = [gen] + + result = adapter.get(meta_info="meta", query="q") + assert result is None + adapter.logger.error.assert_called() + span.end.assert_called_once_with(level="ERROR", status_message=ANY) + + def test_get_exception_propagates(self, adapter): + span = MagicMock() + adapter._start_span = MagicMock(return_value=span) + adapter._mock_cache.lookup.side_effect = ValueError("redis down") + + with 
pytest.raises(ValueError, match="redis down"): + adapter.get(meta_info="meta", query="q") + + span.end.assert_called_once_with(level="ERROR", status_message="redis down") + + +class TestRedisAdapterHealthCheck: + def test_true_when_cache_present(self, adapter): + assert adapter.health_check() is True + + def test_false_when_cache_missing(self, adapter): + adapter.semantic_cache = None + assert adapter.health_check() is False + + +class TestRedisAdapterStartSpan: + def test_prefers_current_span(self, adapter): + child_span = MagicMock() + parent_span = MagicMock() + parent_span.span.return_value = child_span + + with patch("modules.redis_ext.base.current_span") as mock_cs, \ + patch("modules.redis_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = parent_span + mock_ct.get.return_value = MagicMock() + + result = adapter._start_span("redis_test", {"a": 1}) + assert result is child_span + parent_span.span.assert_called_once_with(name="redis_test", input={"a": 1}) + + def test_fallback_to_trace(self, adapter): + trace = MagicMock() + child = MagicMock() + trace.span.return_value = child + + with patch("modules.redis_ext.base.current_span") as mock_cs, \ + patch("modules.redis_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = None + mock_ct.get.return_value = trace + + result = adapter._start_span("redis_test", {"a": 1}) + assert result is child + trace.span.assert_called_once_with(name="redis_test", input={"a": 1}) + + def test_none_when_no_context(self, adapter): + with patch("modules.redis_ext.base.current_span") as mock_cs, \ + patch("modules.redis_ext.base.current_trace") as mock_ct: + mock_cs.get.return_value = None + mock_ct.get.return_value = None + + assert adapter._start_span("redis_test", {"a": 1}) is None diff --git a/tests/modules/redis_ext/utils/__init__.py b/tests/modules/redis_ext/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/redis_ext/utils/test_RedisAdapters.py 
b/tests/modules/redis_ext/utils/test_RedisAdapters.py new file mode 100644 index 0000000..5b5056a --- /dev/null +++ b/tests/modules/redis_ext/utils/test_RedisAdapters.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from modules.redis_ext.utils.RedisAdapters import UserRateLimiter + + +@pytest.fixture +def limiter(): + mock_logger = MagicMock() + with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis: + mock_redis = MagicMock() + MockRedis.return_value = mock_redis + inst = UserRateLimiter( + logger=mock_logger, + host="test-host", + port=6380, + db=3, + decode_responses=True, + USER_QUERY_LIMIT_N=5, + USER_QUERY_LIMIT_TTL_SECONDS=60, + RATE_LIMIT_TEMPLATE="rl:{user_id}", + ) + inst._mock_redis = mock_redis + return inst + + +class TestUserRateLimiterInit: + def test_uses_defaults(self): + mock_logger = MagicMock() + with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis: + UserRateLimiter(logger=mock_logger) + MockRedis.assert_called_once_with( + host="127.0.0.1", + port=6379, + db=2, + decode_responses=True, + ) + + def test_uses_kwargs(self): + mock_logger = MagicMock() + with patch("modules.redis_ext.utils.RedisAdapters.redis.Redis") as MockRedis: + UserRateLimiter( + logger=mock_logger, + host="h", + port=1234, + db=7, + decode_responses=False, + ) + MockRedis.assert_called_once_with( + host="h", + port=1234, + db=7, + decode_responses=False, + ) + + +class TestUserRateLimiterCheckAndIncrement: + def test_new_key_sets_expire(self, limiter): + pipe = MagicMock() + pipe.incr.return_value = None + pipe.ttl.return_value = None + pipe.execute.return_value = [1, -2] + limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe + + allowed, count = limiter.check_and_increment("u1") + assert allowed is True + assert count == 1 + limiter._mock_redis.expire.assert_called_once_with("rl:u1", 60) + + def test_within_limit_no_expire(self, limiter): + pipe = MagicMock() + 
pipe.execute.return_value = [3, 55] + limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe + + allowed, count = limiter.check_and_increment("u1") + assert allowed is True + assert count == 3 + limiter._mock_redis.expire.assert_not_called() + + def test_exceeds_limit(self, limiter): + pipe = MagicMock() + pipe.execute.return_value = [6, 10] + limiter._mock_redis.pipeline.return_value.__enter__.return_value = pipe + + allowed, count = limiter.check_and_increment("u1") + assert allowed is False + assert count == 6 + + +class TestUserRateLimiterGetRemaining: + def test_key_exists(self, limiter): + limiter._mock_redis.get.return_value = "3" + assert limiter.get_remaining("u1") == 2 + + def test_key_missing(self, limiter): + limiter._mock_redis.get.return_value = None + assert limiter.get_remaining("u1") == 5 + + +class TestUserRateLimiterResetCounter: + def test_deletes_key(self, limiter): + limiter.reset_counter("u1") + limiter._mock_redis.delete.assert_called_once_with("rl:u1") + + +class TestUserRateLimiterTtl: + def test_delegates_to_redis(self, limiter): + limiter._mock_redis.ttl.return_value = 42 + assert limiter.ttl("u1") == 42 + limiter._mock_redis.ttl.assert_called_once_with("rl:u1") + + +class TestUserRateLimiterHealthCheck: + def test_healthy(self, limiter): + limiter._mock_redis.ping.return_value = True + assert limiter.health_check() is True + + def test_unhealthy(self, limiter): + limiter._mock_redis.ping.return_value = False + assert limiter.health_check() is False + limiter.logger.warning.assert_called_once() From dabb2dd31e0962a5ba89167576728dc282a6e5c5 Mon Sep 17 00:00:00 2001 From: Aleksandr Samofalov Date: Wed, 15 Apr 2026 01:58:44 +0300 Subject: [PATCH 4/5] docs(readme): add static test, coverage and python 3.12 badges --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 1b357ea..65daf76 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # UnionChatBot 
+![Tests](https://img.shields.io/badge/tests-102%2F102-brightgreen) +![Coverage](https://img.shields.io/badge/coverage-79%25-green) +![Python](https://img.shields.io/badge/python-3.12-blue.svg?logo=python&logoColor=white) + `UnionChatBot` - REST-сервис с агентом на базе LangChain/LangGraph и YandexGPT API. Проект включает: - FastAPI API слой; From 5864c04b38986fdef3ec3a9cc9c89d6b75002bb6 Mon Sep 17 00:00:00 2001 From: Aleksandr Samofalov Date: Wed, 15 Apr 2026 02:03:43 +0300 Subject: [PATCH 5/5] style: fix ruff formatting in 3 files --- src/agents/profkom_consultant/nodes/base.py | 5 +++- src/agents/profkom_consultant/nodes/core.py | 10 ++++++-- src/modules/chroma_ext/base.py | 28 +++++++++++++-------- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/agents/profkom_consultant/nodes/base.py b/src/agents/profkom_consultant/nodes/base.py index 77b54c1..78498b2 100644 --- a/src/agents/profkom_consultant/nodes/base.py +++ b/src/agents/profkom_consultant/nodes/base.py @@ -26,6 +26,7 @@ def _llm_config(self, span): if span: return {"callbacks": [CallbackHandler(stateful_client=span)]} return {} + async def validate_text(self, state: AgentState) -> AgentState: """Проверяем, что текст вопроса пользователя соответствует публичной политике.
@@ -99,7 +100,9 @@ async def validate_final_answer(self, state: AgentState) -> AgentState: if not is_valid: state["final_answer"] = "Не прошёл валидацию" state["is_valid"] = is_valid - self.cache.save(meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data) + self.cache.save( + meta_info="validate_final_answer", query=final_answer, output="", json_data=cache_data + ) return state except Exception as e: diff --git a/src/agents/profkom_consultant/nodes/core.py b/src/agents/profkom_consultant/nodes/core.py index 7d16782..1701910 100644 --- a/src/agents/profkom_consultant/nodes/core.py +++ b/src/agents/profkom_consultant/nodes/core.py @@ -91,7 +91,10 @@ async def decompose_question(self, state: AgentState) -> None | dict[str, Any] | cache_data = {"parts": [p.strip() for p in content.split("") if p.strip()]} self.cache.save( - meta_info="decompose_question_" + state["user_id"], query=question, output="", json_data=cache_data + meta_info="decompose_question_" + state["user_id"], + query=question, + output="", + json_data=cache_data, ) return cache_data except Exception as e: @@ -125,7 +128,10 @@ async def call_llm(part: str) -> str: topic = await self._detect_topics_for_question(part) self.logger.info(f"Topic: {topic}") retrived_data = await asyncio.to_thread( - self.chorma_client.get_info, query=part, collection_name=self.COLLECTION_NAME, topics=[topic] + self.chorma_client.get_info, + query=part, + collection_name=self.COLLECTION_NAME, + topics=[topic], ) html_data = retrived_data.to_html() config = self._llm_config(span) diff --git a/src/modules/chroma_ext/base.py b/src/modules/chroma_ext/base.py index 8946555..5729779 100644 --- a/src/modules/chroma_ext/base.py +++ b/src/modules/chroma_ext/base.py @@ -109,12 +109,15 @@ def get_info_from_db( Returns: relevant documents """ - span = self._start_span("chroma_query", { - "query": query, - "collection": collection_name, - "n_results": n_results, - "where": where, - }) + span = 
self._start_span( + "chroma_query", + { + "query": query, + "collection": collection_name, + "n_results": n_results, + "where": where, + }, + ) try: self.logger.debug(f"get_info_from_db called for {collection_name}") collection = self.client.get_collection(name=collection_name, embedding_function=self.embedding_function) @@ -162,11 +165,14 @@ def apply_reranker(self, query, documents): def get_info(self, query: str, collection_name: str, topics: list[str] | None = None) -> pd.DataFrame: # TO DO: фильтрация по метаданным и потом только query! - span = self._start_span("chroma_rag", { - "query": query, - "collection": collection_name, - "topics": topics, - }) + span = self._start_span( + "chroma_rag", + { + "query": query, + "collection": collection_name, + "topics": topics, + }, + ) try: self.logger.debug(f"called {query} in get_info for {collection_name} and topics {topics}")