diff --git a/api/db/init_data.py b/api/db/init_data.py index 7a449619ee..fca7307739 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -16,6 +16,7 @@ import os import time import uuid +from copy import deepcopy from api.db import LLMType, UserTenantRole from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM @@ -166,6 +167,18 @@ def init_llm_factory(): "tags": "TEXT EMBEDDING,8K", "max_tokens": 8191, "model_type": LLMType.EMBEDDING.value + }, { + "fid": factory_infos[0]["name"], + "llm_name": "text-embedding-3-small", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": LLMType.EMBEDDING.value + }, { + "fid": factory_infos[0]["name"], + "llm_name": "text-embedding-3-large", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": LLMType.EMBEDDING.value }, { "fid": factory_infos[0]["name"], "llm_name": "whisper-1", @@ -376,6 +389,23 @@ def init_llm_factory(): LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"]) LLMService.filter_delete([LLMService.model.fid == "QAnything"]) TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) + ## insert openai two embedding models to the current openai user. 
+    print("Start to insert 2 OpenAI embedding models...")
+    tenant_ids = {row["tenant_id"] for row in TenantLLMService.get_openai_models()}
+    for tid in tenant_ids:
+        for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
+            row = dict(row)  # get_openai_models() returns plain dicts; copy before editing
+            row["model_type"] = LLMType.EMBEDDING.value
+            row["llm_name"] = "text-embedding-3-small"
+            row["used_tokens"] = 0
+            try:
+                TenantLLMService.save(**row)
+                row = deepcopy(row)
+                row["llm_name"] = "text-embedding-3-large"
+                TenantLLMService.save(**row)
+            except Exception:
+                pass  # best-effort migration: a failed insert must not abort start-up
+            break  # only the first matching row per tenant is used as the template
     """
     drop table llm;
     drop table llm_factories;
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 5129bb798f..4776544fc7 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -135,6 +135,30 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
             .execute()
         return num
 
+    @classmethod
+    @DB.connection_context()
+    def get_openai_models(cls, llm_factory="OpenAI", tenant_id=None):
+        """Return tenant LLM rows of a factory, minus the two new embedding models.
+
+        Rows whose ``llm_name`` is ``text-embedding-3-small`` or
+        ``text-embedding-3-large`` are always excluded.
+
+        Args:
+            llm_factory: Factory name to match; defaults to ``"OpenAI"``.
+            tenant_id: When given, only rows of this tenant are returned.
+
+        Returns:
+            A list of plain dicts (not model instances), one per matching row.
+        """
+        conditions = [
+            cls.model.llm_factory == llm_factory,
+            ~(cls.model.llm_name == "text-embedding-3-small"),
+            ~(cls.model.llm_name == "text-embedding-3-large"),
+        ]
+        if tenant_id is not None:
+            conditions.append(cls.model.tenant_id == tenant_id)
+        return list(cls.model.select().where(*conditions).dicts())
+
 
 class LLMBundle(object):
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):