From 00607129b50098979886268723d524764d021d46 Mon Sep 17 00:00:00 2001 From: 2002yy <15135142681@163.com> Date: Fri, 5 Jun 2026 18:38:12 +0800 Subject: [PATCH] Add production RAG embeddings and clear mypy debt --- .env.example | 10 +++ README.md | 30 ++++--- docs/INTERVIEW_NOTES.md | 4 +- docs/RAG.md | 22 +++--- docs/TECH_STACK.md | 12 +-- docs/TESTING.md | 14 ++-- requirements-dev.in | 1 + requirements-dev.txt | 2 + src/after_session.py | 4 +- src/llm_client.py | 7 +- src/mode_manager.py | 2 + src/news/article_extractor.py | 2 +- src/news/article_fetcher.py | 10 +-- src/news/rss_fetcher.py | 4 +- src/rag/__init__.py | 12 ++- src/rag/backends.py | 9 ++- src/rag/chroma_backend.py | 5 +- src/rag/embeddings.py | 145 +++++++++++++++++++++++++++++++++- src/router.py | 12 ++- src/ui/chat_panel.py | 2 +- src/wechat_memory.py | 2 +- tests/test_rag_backends.py | 104 +++++++++++++++++++++++- 22 files changed, 351 insertions(+), 64 deletions(-) diff --git a/.env.example b/.env.example index 82fc438..2513647 100644 --- a/.env.example +++ b/.env.example @@ -70,3 +70,13 @@ DEEPSEEK_MODEL_PRO_NAME=deepseek-v4-pro # RAG_VECTOR_BACKEND=local # RAG_CHROMA_PATH=logs/chroma # RAG_CHROMA_COLLECTION=study_agent + +# === RAG Embeddings(默认 local_hash,无需 API key)=== +# local_hash 适合本地开发和测试;openai 适合 Chroma 持久化向量检索。 +# RAG_EMBEDDING_PROVIDER=local_hash +# RAG_EMBEDDING_PROVIDER=openai +# RAG_EMBEDDING_MODEL=text-embedding-3-small +# RAG_EMBEDDING_DIMENSIONS=1536 +# RAG_EMBEDDING_API_KEY= +# RAG_EMBEDDING_BASE_URL= +# RAG_EMBEDDING_TIMEOUT_SECONDS=30 diff --git a/README.md b/README.md index 0821573..9e3d7bf 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
A local AI learning assistant with long-term memory, role-based group chat, @@ -17,7 +17,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **长期记忆**:Markdown memory + safe writer - **上下文分层**:fast / light / deep / archive - **联网搜索**:RSS / News fetch → article extraction → LLM digest → source tracing -- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid / backend-vector 检索、引用上下文、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG 接口 +- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma 持久化、引用上下文、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG 接口 - **工程安全**:SSRF protection、detect-secrets、配置模板 - **工程质量**:pytest 测试套件、Ruff、GitHub Actions CI、打包检查 @@ -27,11 +27,11 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **Model routing** with fast / light / deep / archive context tiers - **Long-term memory** based on Markdown files and safe-writer persistence - **Web search pipeline**: feed registry → URL safety checks → article extraction → LLM digest → auditable source trace -- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid / backend-vector retrieval, citation-first context formatting, source blocks, a Streamlit retrieval/debug panel, optional chat injection, and FastAPI RAG endpoints +- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid / backend-vector retrieval, configurable embedding providers, optional Chroma persistence, citation-first context formatting, source blocks, a Streamlit retrieval/debug panel, optional chat injection, and FastAPI RAG endpoints - **SSRF protection** for article fetching, **detect-secrets** in CI - **Batched session logging** and multi-layer caching for performance - **Performance budget**: mode-based `max_tokens` bounds on the main chat, WeChat, and news LLM paths -- **273 pytest tests**, Ruff clean, GitHub Actions CI workflow +- **277 pytest tests**, Ruff clean, mypy clean, GitHub Actions CI workflow For a detailed breakdown of the stack and engineering highlights, see [Technical Stack & Engineering Highlights](docs/TECH_STACK.md). @@ -207,13 +207,19 @@ pip-compile requirements-dev.in # 重新锁定开发依赖 参数优先级:代码显式参数 → 任务级环境变量 → 任务默认值 → 全局环境变量 → provider 级环境变量。完整配置见 [`.env.example`](.env.example) 和 [用户指南](USER_GUIDE.md)。 -RAG 向量后端默认使用 `local`,不需要额外服务;可选 `chroma` adapter 需要用户自行安装 `chromadb`: +RAG 向量后端默认使用 `local`,不需要额外服务;可选 `chroma` adapter 需要用户自行安装 `chromadb`。Embedding provider 默认 `local_hash`,生产检索可显式切到 OpenAI-compatible embeddings: ```bash RAG_VECTOR_BACKEND=local # RAG_VECTOR_BACKEND=chroma # RAG_CHROMA_PATH=logs/chroma # RAG_CHROMA_COLLECTION=study_agent + +RAG_EMBEDDING_PROVIDER=local_hash +# RAG_EMBEDDING_PROVIDER=openai +# RAG_EMBEDDING_MODEL=text-embedding-3-small +# RAG_EMBEDDING_DIMENSIONS=1536 +# RAG_EMBEDDING_API_KEY=... ``` --- @@ -243,7 +249,7 @@ RAG_VECTOR_BACKEND=local │ ├── config.py # 全局配置 │ ├── router.py # 路由配置 │ ├── news/ # 新闻聚合链路 -│ ├── rag/ # 本地 RAG MVP:加载、分块、索引、关键词/向量原型/可选后端检索 +│ ├── rag/ # 本地 RAG MVP:加载、分块、索引、关键词/向量原型/embedding/可选后端检索 │ └── ui/ # Streamlit UI 组件 ├── tests/ # pytest 测试套件 ├── docs/ # 设计文档与工程说明 @@ -264,13 +270,13 @@ RAG_VECTOR_BACKEND=local ## 测试 ```bash -pytest tests/ -v # current local baseline: 273 passed +pytest tests/ -v # current local baseline: 277 passed pytest tests/ --cov=src # 覆盖率 ruff check src/ tests/ # linting -mypy --explicit-package-bases src/ # CI soft check; may report type debt +mypy --explicit-package-bases src/ # type check ``` -CI 通过 GitHub Actions 在 push / pull request 上运行,集成 `pytest`、`ruff`、打包检查、`detect-secrets` 扫描,以及非阻断的 `mypy` soft check。当前验证状态见 [docs/TESTING.md](docs/TESTING.md)。 +CI 通过 GitHub Actions 在 push / pull request 上运行,集成 `pytest`、`ruff`、打包检查、`detect-secrets` 扫描,以及 `mypy` soft check。当前验证状态见 [docs/TESTING.md](docs/TESTING.md)。 --- @@ -307,9 +313,9 @@ CI 通过 GitHub Actions 在 push / pull request 上运行,集成 `pytest`、` 求职导向的技术演进路线: - [ ] FastAPI service layer (partial): `/health`, `/rag`, `/rag/index`, `/rag/query` implemented; `/chat` and `/memory` remain planned -- [x] RAG MVP: Markdown / TXT / DOCX / PDF loading, chunking, local keyword retrieval, local vector prototype, hybrid retrieval, citation context, source blocks, Streamlit retrieval panel, optional single-chat and WeChat interactive injection -- [ ] RAG document QA (partial): PDF parsing has file-size, page-count, extracted-text and encrypted-file guards; Chroma adapter scaffold exists; production embedding model retrieval remains planned -- [ ] Vector store: FAISS local prototype, pgvector engineering version +- [x] RAG MVP: Markdown / TXT / DOCX / PDF loading, chunking, local keyword retrieval, local vector prototype, hybrid retrieval, backend-vector retrieval, configurable embedding provider, optional Chroma adapter, citation context, source blocks, Streamlit retrieval panel, optional single-chat and WeChat interactive injection +- [ ] RAG document QA (partial): PDF parsing has file-size, page-count, extracted-text and encrypted-file guards; production embedding requires explicit API/env configuration and Chroma remains optional +- [ ] Vector store: Chroma optional adapter implemented; FAISS local prototype and pgvector engineering version remain planned - [ ] Web UI: TypeScript + Vue3 / React, streaming chat, source panel - [ ] Observability: trace_id, token usage, latency, provider fallback logs diff --git a/docs/INTERVIEW_NOTES.md b/docs/INTERVIEW_NOTES.md index 8a4a3bb..2fc3797 100644 --- a/docs/INTERVIEW_NOTES.md +++ b/docs/INTERVIEW_NOTES.md @@ -10,7 +10,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点在多 Provider 模 2. **长期记忆写入安全** — safe writer + preview/confirm 机制,防止不可逆的记忆污染 3. **联网搜索来源追溯** — Feed registry / RSS 多源聚合 → URL safety matrix → 文章正文三层提取 → LLM digest → pipeline trace 全过程来源可回溯 4. **Streamlit 重渲染性能优化** — 多层缓存策略、按模式批量落盘、主链路 token 预算控制 -5. **CI / Ruff / detect-secrets 工程检查** — 273 pytest tests、Ruff clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 +5. **CI / Ruff / detect-secrets 工程检查** — 277 pytest tests、Ruff clean、mypy local clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 ## 可讲亮点 @@ -23,7 +23,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点在多 Provider 模 ## 展示边界 -- `mypy` 已接入 CI soft check,但当前本地仍有类型错误,不能说类型检查 clean。 +- `mypy` 已接入 CI soft check,当前本地 `python -m mypy --explicit-package-bases src` clean;但 CI 配置仍是非阻断检查。 - `performance_budget.py` 覆盖主要 chat / WeChat / news LLM 路径,辅助 LLM 调用仍需继续收口。 - `article_fetcher.py` 负责真实网络读取前的 DNS/IP SSRF 校验;`link_resolver.py` 是网络无关的 URL 预检和跳转记录。 - `detect-secrets` 已接入 CI,并通过解析扫描 JSON 的 `results` 对未豁免发现硬阻断;测试里的 Basic Auth 形态 URL 样例已显式 allowlist。 diff --git a/docs/RAG.md b/docs/RAG.md index 9729ce4..b2e8c8d 100644 --- a/docs/RAG.md +++ b/docs/RAG.md @@ -2,7 +2,7 @@ ## Status -Current status: **MVP implemented with a local vector prototype, not a production vector-store RAG system yet**. +Current status: **MVP implemented with a local-first retrieval path, configurable embeddings and an optional Chroma adapter**. Implemented: @@ -21,12 +21,13 @@ Implemented: - UI source blocks for retrieved file paths, line ranges, scores and matched terms - FastAPI endpoints: `GET /health`, `POST /rag`, `POST /rag/index`, `POST /rag/query` - Streamlit knowledge/debug panel with index summary, document rows, chunk preview and score breakdowns -- Optional vector backend interface with local fallback and Chroma adapter scaffold +- Optional vector backend interface with local fallback and Chroma adapter +- Configurable embedding providers: deterministic `local_hash` by default, OpenAI-compatible embeddings when explicitly configured Not implemented yet: -- Production embedding model integration - FAISS, pgvector or managed vector stores +- Production-grade embedding evaluation, relevance tuning and re-index migration tooling - Automatic injection into every generation path; current injection covers single chat and WeChat interactive replies, but not news discussion or after-session feedback ## Module Map @@ -36,9 +37,9 @@ Not implemented yet: | `src/rag/loader.py` | Load supported local files into normalized `RagDocument` objects | | `src/rag/chunker.py` | Split documents into line-traceable `RagChunk` objects | | `src/rag/index.py` | Build, save, load and search a local JSON RAG index | -| `src/rag/embeddings.py` | Embedding provider contract and local hash embedding provider | +| `src/rag/embeddings.py` | Embedding provider contract, local hash provider and OpenAI-compatible provider | | `src/rag/backends.py` | Vector backend contract, local backend and environment-driven backend selection | -| `src/rag/chroma_backend.py` | Optional Chroma persistent backend adapter scaffold | +| `src/rag/chroma_backend.py` | Optional Chroma persistent backend adapter | | `src/rag/vector.py` | Deterministic local vector prototype and hybrid retrieval | | `src/rag/eval.py` | LLM-free retrieval quality evaluation over gold query fixtures | | `src/rag/service.py` | Application-facing helpers for indexing, querying and context formatting | @@ -67,7 +68,7 @@ Supported retrieval modes: - `lexical`: TF-IDF-style term scoring - `vector`: deterministic local hash-vector cosine similarity - `hybrid`: normalized lexical score plus vector similarity -- `backend_vector`: configured vector backend; defaults to local and can use the optional Chroma adapter +- `backend_vector`: configured vector backend; defaults to local and can use the optional Chroma adapter with configured embeddings Each result keeps: @@ -139,11 +140,12 @@ P4-C / P6 adds Streamlit inspection controls: P5 adds the first vector-backend abstraction: -- `EmbeddingProvider` protocol plus `LocalHashEmbeddingProvider` +- `EmbeddingProvider` protocol plus `LocalHashEmbeddingProvider` and `OpenAIEmbeddingProvider` - `VectorBackend` protocol plus `LocalVectorBackend` - `RAG_VECTOR_BACKEND=local|chroma` +- `RAG_EMBEDDING_PROVIDER=local_hash|openai`, `RAG_EMBEDDING_MODEL`, `RAG_EMBEDDING_DIMENSIONS`, `RAG_EMBEDDING_API_KEY`, `RAG_EMBEDDING_BASE_URL` - Optional `ChromaVectorBackend` using lazy `chromadb` import, `PersistentClient`, collection `upsert` and vector query -- `tests/test_rag_backends.py` verifies local backend behavior, environment config and Chroma fake-client upsert/query behavior +- `tests/test_rag_backends.py` verifies local backend behavior, embedding environment config, OpenAI-compatible embedding batching and Chroma fake-client upsert/query behavior ## Next Steps @@ -163,9 +165,9 @@ Goal: replace the local hash-vector prototype with optional real embeddings with - [x] Extract an embedding-provider and vector-backend contract. - [x] Keep JSON + lexical / hybrid retrieval as the zero-infrastructure fallback. -- [x] Add an optional Chroma adapter scaffold with lazy import and fake-client tests. +- [x] Add an optional Chroma adapter with lazy import and fake-client tests. - [x] Make vector backend selection explicit through config. -- [ ] Add a production embedding provider; current Chroma adapter uses the local hash embedding provider by default. +- [x] Add a production embedding provider path; current default remains `local_hash`, while OpenAI-compatible embeddings require explicit env/API configuration. ### P6: Knowledge UI diff --git a/docs/TECH_STACK.md b/docs/TECH_STACK.md index 4c10217..38a92a9 100644 --- a/docs/TECH_STACK.md +++ b/docs/TECH_STACK.md @@ -35,7 +35,7 @@ Study Agent 是一个本地运行的 AI 学习助理系统,面向个人学习 | Long-term Memory | Markdown files | 用 `summary.md`、`current_focus.md`、`learner_profile.md` 等文件保存长期记忆 | | Context Control | fast / light / deep / archive tiers | 按性能模式选择不同记忆文件组,控制 token 成本 | | Routing | Rule-based router + optional LLM router | 根据任务类型、用户选择和性能模式决定角色、学习模式和模型档位 | -| RAG MVP | `src/rag/*`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid / backend-vector 检索、引用上下文拼装、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG endpoints | +| RAG MVP | `src/rag/*`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、引用上下文拼装、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG endpoints | | News Search | Feed registry / RSS / Google News / Bing News / RSSHub-style sources | 多源新闻聚合、源健康记录、去重、排序、来源追溯 | | Article Extraction | `trafilatura`, `readability-lxml`, `lxml` | 新闻网页正文读取与降级解析 | | Security | URL safety matrix, SSRF validation, redirect checks, secret scanning | 防止读取本地/内网资源,降低密钥误提交风险 | @@ -273,7 +273,7 @@ User query - 带 `source_path`、标题、chunk 序号和行号范围的分块 - 本地关键词 / TF-IDF-style 检索 - deterministic hash-vector 本地向量原型与 hybrid 检索模式 -- `EmbeddingProvider` / `VectorBackend` 抽象,默认 local backend,可选 Chroma adapter scaffold +- `EmbeddingProvider` / `VectorBackend` 抽象,默认 local backend,可选 Chroma adapter,可显式配置 OpenAI-compatible embedding provider - 简单中文 CJK bigram 匹配 - JSON index 保存与加载,默认路径为 `logs/rag_index.json` - `build_rag_context()` 将检索结果拼装为带引用的 LLM 上下文块 @@ -284,7 +284,7 @@ User query 未实现边界: -- 尚未接入生产 embedding model、FAISS、pgvector 或其他生产向量库;Chroma 目前是 optional adapter scaffold +- 默认仍是 local-first;生产 embedding 需要显式 API/env 配置,Chroma 需要额外安装 `chromadb`;FAISS、pgvector 或其他生产向量库仍未接入 - FastAPI 目前覆盖 health 和 RAG;`/chat`、`/memory` 仍是后续服务化任务 - 尚未自动注入所有生成路径;当前覆盖单人聊天和微信群互动回复,不覆盖新闻讨论或课后反馈 @@ -302,7 +302,7 @@ User query - ruff - package helper - detect-secrets(hard gate for unallowlisted findings) -- mypy soft check(当前仍有类型债,不作为阻断条件) +- mypy soft check(当前本地 `python -m mypy --explicit-package-bases src` clean;CI 中仍作为非阻断检查运行) 自定义打包脚本会排除: @@ -329,7 +329,7 @@ User query - 设计 **OpenAI-compatible 多 Provider LLM 接入层**,支持 OpenAI、DeepSeek、OpenRouter、SiliconFlow 与本地模型,通过 `.env` 管理 base_url、模型名、超时、重试和任务级 token 预算。 - 实现 **规则路由 + 性能模式** 的模型选择机制,根据用户输入、任务类型、手动配置和 fast / standard / deep 模式动态选择角色、学习模式与 flash / pro 模型。 - 设计基于 **Markdown 文件的长期记忆系统**,按 fast / light / deep / archive 分层读取 `summary`、`current_focus`、`learner_profile`、`project_context` 等上下文,降低 token 消耗。 -- 实现 **本地 RAG MVP**,支持 Markdown / TXT / DOCX / PDF 加载、来源行号分块、关键词检索、本地向量原型、hybrid 检索、引用上下文拼装、来源块、Streamlit 检索面板、单人聊天和微信群互动回复注入,并提供 FastAPI health / RAG endpoints;embedding、生产向量库和更完整的来源面板仍作为后续演进。 +- 实现 **本地 RAG MVP**,支持 Markdown / TXT / DOCX / PDF 加载、来源行号分块、关键词检索、本地向量原型、hybrid / backend-vector 检索、可配置 embedding provider、可选 Chroma adapter、引用上下文拼装、来源块、Streamlit 检索面板、单人聊天和微信群互动回复注入,并提供 FastAPI health / RAG endpoints;FAISS、pgvector 和更完整的来源面板仍作为后续演进。 - 实现 **联网新闻检索与群聊讨论链路**,支持 RSS 聚合、链接解析、正文抽取、LLM 摘要、角色讨论和来源追溯。 - 在网页正文抓取模块加入 **SSRF 防护**,校验 HTTP scheme、DNS 解析结果、私网 IP、loopback、reserved 地址和重定向目标,提高本地联网模块安全性。 - 封装 **安全文件写入工具**,通过临时文件、原子替换、覆盖前备份和 PermissionError 重试保障记忆与日志写入可靠性。 @@ -343,7 +343,7 @@ User query - Implemented a unified **OpenAI-compatible multi-provider LLM layer** supporting OpenAI, DeepSeek, OpenRouter, SiliconFlow and local models, with environment-based model, timeout, retry and token-budget configuration. - Designed a **rule-based model routing system** with fast / standard / deep performance modes to dynamically select role, learning mode and flash / pro model tier based on task type and user configuration. - Built a **Markdown-based long-term memory system** with fast / light / deep / archive context tiers to balance personalization and token cost. -- Implemented a **local RAG MVP** for Markdown / TXT / DOCX / PDF loading, source-line chunking, lexical retrieval, a local vector prototype, hybrid retrieval, citation-first context formatting, source blocks, a Streamlit retrieval panel, optional single-chat / WeChat interactive injection, and FastAPI health / RAG endpoints; embedding models, production vector stores and a richer source panel remain planned. +- Implemented a **local RAG MVP** for Markdown / TXT / DOCX / PDF loading, source-line chunking, lexical retrieval, a local vector prototype, hybrid / backend-vector retrieval, configurable embedding providers, an optional Chroma adapter, citation-first context formatting, source blocks, a Streamlit retrieval panel, optional single-chat / WeChat interactive injection, and FastAPI health / RAG endpoints; FAISS, pgvector and a richer source panel remain planned. - Implemented a **source-traced news pipeline** covering RSS aggregation, link resolution, article extraction, LLM digest generation and role-based group discussion. - Added **SSRF protection** to the article-fetching module by validating URL scheme, DNS resolution results, private IP ranges and redirect targets. - Encapsulated **safe file persistence** with temporary writes, atomic replacement, automatic backup and retry on permission errors. diff --git a/docs/TESTING.md b/docs/TESTING.md index 23f670d..53138ba 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -6,10 +6,10 @@ Current verified baseline: | Check | Status | Evidence | |---|---|---| -| pytest | Passed | `273 passed` locally on 2026-06-05 | -| Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-04 | -| Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-04 | -| mypy | Soft check, not clean | `python -m mypy --explicit-package-bases src/` reported 18 errors locally on 2026-06-04 | +| pytest | Passed | `277 passed` locally on 2026-06-05 | +| Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-05 | +| Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-05 | +| mypy | Passed locally; CI soft check | `python -m mypy --explicit-package-bases src` clean locally on 2026-06-05 | | detect-secrets | CI hard gate configured | Workflow parses scan JSON and fails when `results` contains any unallowlisted finding; local tracked-file scan was empty on 2026-06-04 | | GitHub Actions | Recent main runs passing | Latest 6 CI runs on `main` were `success` when checked on 2026-06-03 | @@ -26,7 +26,7 @@ Current verified baseline: | **Feed registry / health** | `test_feed_registry.py`, `test_feed_diagnostics.py` | 9 | | **RAG MVP** | `test_rag.py` | 24 | | **RAG evaluation** | `test_rag_eval.py` | 5 | -| **RAG vector backends** | `test_rag_backends.py` | 6 | +| **RAG vector backends** | `test_rag_backends.py` | 10 | | **FastAPI RAG endpoints** | `test_api.py` | 6 | | **Architecture flows** | `test_architecture_flows.py` | 12 | | **WeChat decoupling** | `test_wechat_decoupling.py` | 4 | @@ -77,11 +77,11 @@ def test_flush_uses_safe_writer(): ## Running Tests ```bash -python -m pytest # current baseline: 273 passed +python -m pytest # current baseline: 277 passed pytest tests/ -v # Verbose pytest tests/ --cov=src # Coverage python -m ruff check . # Linting -python -m mypy --explicit-package-bases src/ # Soft check; currently has type debt +python -m mypy --explicit-package-bases src # Type check; CI currently runs it as a soft check ``` Tracked-file secret scan used for local verification: diff --git a/requirements-dev.in b/requirements-dev.in index bedbee9..a09575f 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -4,3 +4,4 @@ pytest>=8.0,<9 ruff>=0.6,<1 mypy>=1.0,<2 detect-secrets>=1.5,<2 +types-PyYAML>=6.0,<7 diff --git a/requirements-dev.txt b/requirements-dev.txt index a43e30f..eb61cee 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -235,6 +235,8 @@ typing-inspection==0.4.2 # via # fastapi # pydantic +types-pyyaml==6.0.12.20260518 + # via -r requirements-dev.in tzdata==2026.2 # via # pandas diff --git a/src/after_session.py b/src/after_session.py index 3987f57..a6806b2 100644 --- a/src/after_session.py +++ b/src/after_session.py @@ -1,6 +1,6 @@ import json import time -from src.llm_client import chat +from src.llm_client import ModelProfile, chat from src.text_utils import strip_code_fences from src.log_utils import get_logger from src.task_events import TaskEventCallback, emit_task_event @@ -99,7 +99,7 @@ def generate_after_session_updates( memory_bundle: dict[str, str], role: str, mode: str, - model_profile: str = "pro", + model_profile: ModelProfile = "pro", event_callback: TaskEventCallback | None = None, ) -> dict[str, str]: started_at = time.perf_counter() diff --git a/src/llm_client.py b/src/llm_client.py index 31c7c22..fe197c3 100644 --- a/src/llm_client.py +++ b/src/llm_client.py @@ -103,11 +103,8 @@ def reset_client() -> None: def _normalize_provider_profile(provider_profile: str | None = None) -> str: - profile = ( - (provider_profile or os.getenv("LLM_PROVIDER_PROFILE", "openai")) - .strip() - .lower() - ) + raw_profile = provider_profile or os.getenv("LLM_PROVIDER_PROFILE") or "openai" + profile = raw_profile.strip().lower() if profile not in _SUPPORTED_PROVIDER_PROFILES: raise RuntimeError(f"Unsupported provider profile: {profile}") return profile diff --git a/src/mode_manager.py b/src/mode_manager.py index 952a637..1ca92f3 100644 --- a/src/mode_manager.py +++ b/src/mode_manager.py @@ -84,6 +84,8 @@ def build_runtime_profile(modes: RuntimeModes) -> RuntimeProfile: memory_write_allowed = False memory_write_reason = modes.memory_mode + preferred_model: str | None + if modes.performance_mode == "fast": allow_llm_router = False llm_router_reason = "fast_mode" diff --git a/src/news/article_extractor.py b/src/news/article_extractor.py index 326f715..df1214b 100644 --- a/src/news/article_extractor.py +++ b/src/news/article_extractor.py @@ -135,7 +135,7 @@ def extract_article_text_with_readability( max_chars: int = 5000, ) -> str: try: - from readability import Document + from readability import Document # type: ignore[import-untyped] doc = Document(html) summary_html = doc.summary(html_partial=True) diff --git a/src/news/article_fetcher.py b/src/news/article_fetcher.py index cd0a521..ccb648a 100644 --- a/src/news/article_fetcher.py +++ b/src/news/article_fetcher.py @@ -6,7 +6,9 @@ import os import re import time +from collections.abc import MutableMapping from socket import getaddrinfo +from typing import Any from urllib.parse import urlparse from urllib.request import ( HTTPRedirectHandler, @@ -32,7 +34,7 @@ def _prune_cache( - cache: dict[str, tuple[float, ...]], + cache: MutableMapping[str, Any], ttl: float, max_size: int, now: float, @@ -163,11 +165,7 @@ def redirect_request(self, req, fp, code, msg, headers, newurl): # Shared opener with safe redirect handler (replaces default HTTPRedirectHandler) -_SAFE_OPENER = build_opener() -for i, h in enumerate(_SAFE_OPENER.handlers): - if isinstance(h, HTTPRedirectHandler): - _SAFE_OPENER.handlers[i] = _SafeHTTPRedirectHandler() - break +_SAFE_OPENER = build_opener(_SafeHTTPRedirectHandler()) # ── Article fetching ────────────────────────────────────────────────── diff --git a/src/news/rss_fetcher.py b/src/news/rss_fetcher.py index 5aa5fb0..87cd2d0 100644 --- a/src/news/rss_fetcher.py +++ b/src/news/rss_fetcher.py @@ -4,7 +4,9 @@ import re import time +from collections.abc import MutableMapping from email.utils import parsedate_to_datetime +from typing import Any from urllib.request import Request, urlopen import xml.etree.ElementTree as ET @@ -94,7 +96,7 @@ def _record_feed_result_safely( def _prune_cache( - cache: dict[str, tuple[float, ...]], + cache: MutableMapping[str, Any], ttl: float, max_size: int, now: float, diff --git a/src/rag/__init__.py b/src/rag/__init__.py index 1b3cecf..73386e4 100644 --- a/src/rag/__init__.py +++ b/src/rag/__init__.py @@ -13,7 +13,13 @@ get_vector_backend_from_env, vector_backend_config_from_env, ) -from src.rag.embeddings import LocalHashEmbeddingProvider +from src.rag.embeddings import ( + LocalHashEmbeddingProvider, + OpenAIEmbeddingProvider, + embedding_provider_config_from_env, + get_embedding_provider, + get_embedding_provider_from_env, +) from src.rag.eval import ( RagEvalCase, RagEvalResult, @@ -50,6 +56,10 @@ "get_vector_backend", "get_vector_backend_from_env", "LocalHashEmbeddingProvider", + "OpenAIEmbeddingProvider", + "embedding_provider_config_from_env", + "get_embedding_provider", + "get_embedding_provider_from_env", "LocalVectorBackend", "load_eval_cases", "load_rag_index", diff --git a/src/rag/backends.py b/src/rag/backends.py index d7abc8e..84ff13d 100644 --- a/src/rag/backends.py +++ b/src/rag/backends.py @@ -5,7 +5,11 @@ from pathlib import Path from typing import Protocol -from src.rag.embeddings import EmbeddingProvider, LocalHashEmbeddingProvider +from src.rag.embeddings import ( + EmbeddingProvider, + LocalHashEmbeddingProvider, + get_embedding_provider_from_env, +) from src.rag.schema import RagIndex, RagSearchResult from src.rag.vector import search_rag_index_vector @@ -112,9 +116,10 @@ def get_vector_backend_from_env( embedding_provider: EmbeddingProvider | None = None, ) -> VectorBackend: config = vector_backend_config_from_env() + resolved_embedding_provider = embedding_provider or get_embedding_provider_from_env() return get_vector_backend( config["name"], path=config["path"], collection=config["collection"], - embedding_provider=embedding_provider, + embedding_provider=resolved_embedding_provider, ) diff --git a/src/rag/chroma_backend.py b/src/rag/chroma_backend.py index 4323039..42d5fe7 100644 --- a/src/rag/chroma_backend.py +++ b/src/rag/chroma_backend.py @@ -64,10 +64,11 @@ def upsert_index(self, index: RagIndex) -> None: if not ids: return + texts = [chunk.text for chunk in index.chunks] collection.upsert( ids=ids, - embeddings=[list(self.embedding_provider.embed(chunk.text)) for chunk in index.chunks], - documents=[chunk.text for chunk in index.chunks], + embeddings=[list(vector) for vector in self.embedding_provider.embed_many(texts)], + documents=texts, metadatas=[_chunk_metadata(chunk) for chunk in index.chunks], ) diff --git a/src/rag/embeddings.py b/src/rag/embeddings.py index eab2c18..c17092e 100644 --- a/src/rag/embeddings.py +++ b/src/rag/embeddings.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Protocol +import os +from typing import Any, Protocol from src.rag.vector import VECTOR_DIMENSIONS, embed_text @@ -12,6 +13,9 @@ class EmbeddingProvider(Protocol): def embed(self, text: str) -> tuple[float, ...]: """Embed one text string for vector retrieval.""" + def embed_many(self, texts: list[str]) -> list[tuple[float, ...]]: + """Embed a batch of text strings in input order.""" + class LocalHashEmbeddingProvider: name = "local_hash" @@ -19,3 +23,142 @@ class LocalHashEmbeddingProvider: def embed(self, text: str) -> tuple[float, ...]: return embed_text(text, dimensions=self.dimensions) + + def embed_many(self, texts: list[str]) -> list[tuple[float, ...]]: + return [self.embed(text) for text in texts] + + +class OpenAIEmbeddingProvider: + name = "openai" + + def __init__( + self, + *, + model: str = "text-embedding-3-small", + dimensions: int | None = None, + api_key: str | None = None, + base_url: str | None = None, + timeout: float | None = None, + client: Any | None = None, + ) -> None: + self.model = model + self.name = f"openai:{model}" + self.dimensions = dimensions or 1536 + self._dimensions_override = dimensions + self._api_key = api_key + self._base_url = base_url + self._timeout = timeout + self._client = client + + def _get_client(self) -> Any: + if self._client is not None: + return self._client + + api_key = self._api_key or os.getenv("RAG_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("RAG_EMBEDDING_API_KEY or OPENAI_API_KEY is required for OpenAI embeddings") + + try: + from openai import OpenAI + except Exception as exc: + raise RuntimeError("openai is required for OpenAI embeddings") from exc + + kwargs: dict[str, Any] = {"api_key": api_key} + base_url = self._base_url or os.getenv("RAG_EMBEDDING_BASE_URL") or os.getenv("OPENAI_BASE_URL") + if base_url: + kwargs["base_url"] = base_url + if self._timeout is not None: + kwargs["timeout"] = self._timeout + self._client = OpenAI(**kwargs) + return self._client + + def embed(self, text: str) -> tuple[float, ...]: + return self.embed_many([text])[0] + + def embed_many(self, texts: list[str]) -> list[tuple[float, ...]]: + if not texts: + return [] + + request: dict[str, Any] = {"model": self.model, "input": texts} + if self._dimensions_override is not None: + request["dimensions"] = self._dimensions_override + + response = self._get_client().embeddings.create(**request) + return [tuple(float(value) for value in item.embedding) for item in response.data] + + +def _env_int(name: str) -> int | None: + value = os.getenv(name, "").strip() + if not value: + return None + try: + parsed = int(value) + except ValueError as exc: + raise ValueError(f"{name} must be an integer") from exc + if parsed <= 0: + raise ValueError(f"{name} must be positive") + return parsed + + +def _env_float(name: str) -> float | None: + value = os.getenv(name, "").strip() + if not value: + return None + try: + parsed = float(value) + except ValueError as exc: + raise ValueError(f"{name} must be a number") from exc + if parsed <= 0: + raise ValueError(f"{name} must be positive") + return parsed + + +def embedding_provider_config_from_env() -> dict[str, str]: + return { + "provider": os.getenv("RAG_EMBEDDING_PROVIDER", "local_hash").strip() or "local_hash", + "model": os.getenv("RAG_EMBEDDING_MODEL", "text-embedding-3-small").strip() + or "text-embedding-3-small", + "dimensions": os.getenv("RAG_EMBEDDING_DIMENSIONS", "").strip(), + "base_url_configured": "true" + if (os.getenv("RAG_EMBEDDING_BASE_URL") or os.getenv("OPENAI_BASE_URL")) + else "false", + "api_key_configured": "true" + if (os.getenv("RAG_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")) + else "false", + } + + +def get_embedding_provider( + name: str = "local_hash", + *, + model: str = "text-embedding-3-small", + dimensions: int | None = None, + api_key: str | None = None, + base_url: str | None = None, + timeout: float | None = None, + client: Any | None = None, +) -> EmbeddingProvider: + normalized = (name or "local_hash").strip().lower() + if normalized in {"local", "local_hash", "hash"}: + return LocalHashEmbeddingProvider() + if normalized in {"openai", "openai_compatible"}: + return OpenAIEmbeddingProvider( + model=model, + dimensions=dimensions, + api_key=api_key, + base_url=base_url, + timeout=timeout, + client=client, + ) + raise ValueError(f"Unsupported embedding provider: {name}") + + +def get_embedding_provider_from_env() -> EmbeddingProvider: + config = embedding_provider_config_from_env() + return get_embedding_provider( + config["provider"], + model=config["model"], + dimensions=_env_int("RAG_EMBEDDING_DIMENSIONS"), + base_url=os.getenv("RAG_EMBEDDING_BASE_URL") or os.getenv("OPENAI_BASE_URL"), + timeout=_env_float("RAG_EMBEDDING_TIMEOUT_SECONDS"), + ) diff --git a/src/router.py b/src/router.py index 8b9ce9e..6e4ea02 100644 --- a/src/router.py +++ b/src/router.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import Literal +from typing import Literal, cast from src.mode_manager import RuntimeModes from src.model_stats import suggest_model_by_rules @@ -79,7 +79,13 @@ def _load_rules_from_markdown(text: str) -> list: model_match = re.search(r"model:\s*(\S+)", block) reason_match = re.search(r"reason:\s*(.+)", block) pri_match = re.search(r"priority:\s*(\d+)", block) - if all([kw_match, role_match, mode_match, model_match, reason_match]): + if ( + kw_match is not None + and role_match is not None + and mode_match is not None + and model_match is not None + and reason_match is not None + ): priority = int(pri_match.group(1)) if pri_match else 50 rules.append( ( @@ -194,7 +200,7 @@ def route_request( # Model selection with performance_mode awareness if selected_model != "auto": - model_profile = selected_model + model_profile = cast(Model, selected_model) elif runtime_modes.performance_mode == "deep": model_profile = "pro" elif runtime_modes.performance_mode == "fast": diff --git a/src/ui/chat_panel.py b/src/ui/chat_panel.py index 98f9a8b..c3852c1 100644 --- a/src/ui/chat_panel.py +++ b/src/ui/chat_panel.py @@ -398,7 +398,7 @@ def _mark_first_token(): mode=resolved_mode, model=resolved_model, user_input=user_input, - agent_reply=reply, + agent_reply=str(reply), memory_enabled=bool(memory_bundle), route_info=st.session_state.current_route, ) diff --git a/src/wechat_memory.py b/src/wechat_memory.py index 4d07eb0..ee72d42 100644 --- a/src/wechat_memory.py +++ b/src/wechat_memory.py @@ -32,7 +32,7 @@ def extract_memory_candidates( memory_bundle: dict[str, str] | None = None, model_profile: ModelProfile = "pro", ) -> dict: - empty = {k: [] for k in CANDIDATE_KEYS} + empty: dict[str, list] = {k: [] for k in CANDIDATE_KEYS} modes = load_runtime_modes() if not modes.memory_capture_enabled: diff --git a/tests/test_rag_backends.py b/tests/test_rag_backends.py index 9ee87c8..54c2e55 100644 --- a/tests/test_rag_backends.py +++ b/tests/test_rag_backends.py @@ -9,9 +9,33 @@ vector_backend_config_from_env, ) from src.rag.chroma_backend import ChromaVectorBackend +from src.rag.embeddings import ( + LocalHashEmbeddingProvider, + OpenAIEmbeddingProvider, + embedding_provider_config_from_env, + get_embedding_provider, + get_embedding_provider_from_env, +) from src.rag.index import build_rag_index +class _RecordingEmbeddingProvider: + name = "recording" + dimensions = 3 + + def __init__(self) -> None: + self.embed_inputs = [] + self.embed_many_inputs = [] + + def embed(self, text: str) -> tuple[float, ...]: + self.embed_inputs.append(text) + return (0.1, 0.2, 0.3) + + def embed_many(self, texts: list[str]) -> list[tuple[float, ...]]: + self.embed_many_inputs.append(list(texts)) + return [(0.1, 0.2, 0.3) for _text in texts] + + class _FakeCollection: def __init__(self) -> None: self.upserts = [] @@ -53,6 +77,77 @@ def get_or_create_collection(self, name: str): return self.collection +class _FakeEmbeddingClient: + def __init__(self) -> None: + self.embeddings = self + self.requests = [] + + def create(self, **kwargs): + self.requests.append(kwargs) + + class Item: + embedding = [0.1, 0.2, 0.3] + + class Response: + pass + + response = Response() + response.data = [Item() for _text in kwargs["input"]] + return response + + + +def test_embedding_provider_config_reads_environment(monkeypatch): + monkeypatch.setenv("RAG_EMBEDDING_PROVIDER", "openai") + monkeypatch.setenv("RAG_EMBEDDING_MODEL", "text-embedding-3-small") + monkeypatch.setenv("RAG_EMBEDDING_DIMENSIONS", "768") + monkeypatch.setenv("RAG_EMBEDDING_API_KEY", "test-key") + monkeypatch.setenv("RAG_EMBEDDING_BASE_URL", "https://example.test/v1") + + config = embedding_provider_config_from_env() + + assert config == { + "provider": "openai", + "model": "text-embedding-3-small", + "dimensions": "768", + "base_url_configured": "true", + "api_key_configured": "true", + } + + +def test_get_embedding_provider_from_env_defaults_to_local(monkeypatch): + monkeypatch.delenv("RAG_EMBEDDING_PROVIDER", raising=False) + + provider = get_embedding_provider_from_env() + + assert isinstance(provider, LocalHashEmbeddingProvider) + + +def test_get_embedding_provider_rejects_unknown_provider(): + with pytest.raises(ValueError, match="Unsupported embedding provider"): + get_embedding_provider("unknown") + + +def test_openai_embedding_provider_batches_requests(): + fake_client = _FakeEmbeddingClient() + provider = OpenAIEmbeddingProvider( + model="text-embedding-3-small", + dimensions=3, + client=fake_client, + ) + + vectors = provider.embed_many(["first", "second"]) + + assert vectors == [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3)] + assert fake_client.requests == [ + { + "model": "text-embedding-3-small", + "input": ["first", "second"], + "dimensions": 3, + } + ] + + def test_local_vector_backend_queries_index(tmp_path): path = tmp_path / "notes.md" path.write_text("Vector backend retrieves local chunks.", encoding="utf-8") @@ -99,10 +194,12 @@ def test_chroma_backend_upserts_chunks_with_embeddings(tmp_path): path.write_text("Chroma adapter stores local chunks.", encoding="utf-8") index = build_rag_index([path], max_chars=200, overlap_chars=0) fake_client = _FakeClient() + embedding_provider = _RecordingEmbeddingProvider() backend = ChromaVectorBackend( path=tmp_path / "chroma", collection_name="study_agent_test", client=fake_client, + embedding_provider=embedding_provider, ) backend.upsert_index(index) @@ -110,17 +207,20 @@ def test_chroma_backend_upserts_chunks_with_embeddings(tmp_path): assert fake_client.collection_names == ["study_agent_test"] upsert = fake_client.collection.upserts[0] assert upsert["ids"] == [index.chunks[0].chunk_id] - assert len(upsert["embeddings"][0]) == 256 + assert upsert["embeddings"][0] == [0.1, 0.2, 0.3] assert upsert["documents"] == [index.chunks[0].text] assert upsert["metadatas"][0]["source_path"] == str(path) + assert embedding_provider.embed_many_inputs == [[index.chunks[0].text]] def test_chroma_backend_query_reconstructs_search_results(tmp_path): fake_client = _FakeClient() + embedding_provider = _RecordingEmbeddingProvider() backend = ChromaVectorBackend( path=tmp_path / "chroma", collection_name="study_agent_test", client=fake_client, + embedding_provider=embedding_provider, ) path = tmp_path / "notes.md" path.write_text("Local placeholder.", encoding="utf-8") @@ -130,6 +230,8 @@ def test_chroma_backend_query_reconstructs_search_results(tmp_path): assert backend.status().available is True assert fake_client.collection.last_query["n_results"] == 1 + assert fake_client.collection.last_query["query_embeddings"] == [[0.1, 0.2, 0.3]] + assert embedding_provider.embed_inputs == ["stored chunk"] assert results[0].score == 0.8 assert results[0].chunk.source_path == "notes.md" assert results[0].chunk.start_line == 1