Skip to content

Commit

Permalink
Add 2 embedding models from OpenAI (infiniflow#812)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#810 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
KevinHuSh committed May 17, 2024
1 parent d54d137 commit e73ce39
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
30 changes: 30 additions & 0 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import time
import uuid
from copy import deepcopy

from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
Expand Down Expand Up @@ -166,6 +167,18 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-small",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-large",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
Expand Down Expand Up @@ -376,6 +389,23 @@ def init_llm_factory():
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
# Insert the two new OpenAI embedding models for tenants that already use OpenAI.
print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
for tid in tenant_ids:
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
row = row.to_dict()
row["model_type"] = LLMType.EMBEDDING.value
row["llm_name"] = "text-embedding-3-small"
row["used_tokens"] = 0
try:
TenantLLMService.save(**row)
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
except Exception as e:
pass
break
"""
drop table llm;
drop table llm_factories;
Expand Down
10 changes: 10 additions & 0 deletions api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
.execute()
return num

@classmethod
@DB.connection_context()
def get_openai_models(cls):
    """Fetch OpenAI rows, leaving out the two text-embedding-3 models.

    Selects every record of ``cls.model`` whose ``llm_factory`` is
    "OpenAI" and whose ``llm_name`` is neither "text-embedding-3-small"
    nor "text-embedding-3-large".

    Returns:
        A list of plain dicts (one per matching row), because the query
        is switched to ``.dicts()`` mode before being materialized.
        Callers must therefore use key access, not attribute access.
    """
    model = cls.model

    # Each predicate is built separately; passing several expressions
    # to ``where`` combines them with AND.
    is_openai = (model.llm_factory == "OpenAI")
    not_small = ~(model.llm_name == "text-embedding-3-small")
    not_large = ~(model.llm_name == "text-embedding-3-large")

    query = model.select().where(is_openai, not_small, not_large)
    return list(query.dicts())


class LLMBundle(object):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
Expand Down

0 comments on commit e73ce39

Please sign in to comment.