diff --git a/api/db/init_data.py b/api/db/init_data.py index 458e4ba257..1bf6f01a60 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -21,6 +21,8 @@ from api.db import LLMType, UserTenantRole from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM from api.db.services import UserService +from api.db.services.document_service import DocumentService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle from api.db.services.user_service import TenantService, UserTenantService from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL @@ -406,6 +408,8 @@ def init_llm_factory(): except Exception as e: pass break + for kb_id in KnowledgebaseService.get_all_ids(): + KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)}) """ drop table llm; drop table llm_factories; diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 1bb5015928..c4dc4db042 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -265,3 +265,9 @@ def update_progress(cls): except Exception as e: stat_logger.error("fetch task exception:" + str(e)) + @classmethod + @DB.connection_context() + def get_kb_doc_count(cls, kb_id): + return len(cls.model.select(cls.model.id).where( + cls.model.kb_id == kb_id).dicts()) + diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 50633b85b5..f1bc131ea9 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -112,3 +112,8 @@ def get_by_name(cls, kb_name, tenant_id): if kb: return True, kb[0] return False, None + + @classmethod + @DB.connection_context() + def get_all_ids(cls): + return [m["id"] for m in cls.model.select(cls.model.id).dicts()]