diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py
index 95da1e03b0f..0a3c3910771 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py
@@ -13,10 +13,27 @@
from common import forms
from common.exception.app_exception import AppApiException
-from common.forms import BaseForm
+from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
+class BaiLianEmbeddingModelParams(BaseForm):
+    """Parameter form offered for Bailian embedding models.
+
+    Exposes a single required setting, ``dimensions``, rendered as a
+    single-select field.
+    """
+    dimensions = forms.SingleSelect(
+        TooltipLabel(
+            _('Dimensions'),
+            _('')
+        ),
+        required=True,
+        default_value=1024,
+        value_field='value',
+        text_field='label',
+        # Option values are integers so that they have the same type as
+        # ``default_value``; with string values the integer default (1024)
+        # compared equal to none of the options.
+        option_list=[
+            {'label': '1024', 'value': 1024},
+            {'label': '768', 'value': 768},
+            {'label': '512', 'value': 512},
+        ]
+    )
+
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
@@ -71,4 +88,8 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
api_key = model.get('dashscope_api_key', '')
return {**model, 'dashscope_api_key': super().encryption(api_key)}
+
+ def get_model_params_setting_form(self, model_name):
+ """Return a new BaiLianEmbeddingModelParams form; ``model_name`` is not consulted."""
+ return BaiLianEmbeddingModelParams()
+
dashscope_api_key = forms.PasswordInputField('API Key', required=True)
diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py
index 0316782dd10..786469e0283 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py
@@ -6,61 +6,43 @@
@date：2024/10/16 16:34
@desc:
"""
-from functools import reduce
from typing import Dict, List
-from langchain_community.embeddings import DashScopeEmbeddings
-from langchain_community.embeddings.dashscope import embed_with_retry
+from openai import OpenAI
from models_provider.base_model_provider import MaxKBBaseModel
-def proxy_embed_documents(texts: List[str], step_size, embed_documents):
- value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
- range(0, len(texts), step_size)]
- return reduce(lambda x, y: [*x, *y], value, [])
+class AliyunBaiLianEmbedding(MaxKBBaseModel):
+ """Embedding model that calls the DashScope endpoint through the OpenAI client."""
+ # Model identifier passed as ``model`` on every embeddings request.
+ model_name: str
+ # Extra keyword arguments forwarded to the embeddings request.
+ optional_params: dict
+ def __init__(self, api_key, model_name: str, optional_params: dict):
+ """Create the embeddings client and store the model name and optional parameters."""
+ # Only the ``embeddings`` resource of the client is kept; the base URL is hard-coded.
+ self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
+ self.model_name = model_name
+ self.optional_params = optional_params
-class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+ """Build an AliyunBaiLianEmbedding from a credential dict.
+
+ The API key is read from ``model_credential['dashscope_api_key']``.
+ ``model_kwargs`` is passed through ``MaxKBBaseModel.filter_optional_params``
+ and the result becomes the instance's ``optional_params``.
+ ``model_type`` is not used here.
+ """
+ optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AliyunBaiLianEmbedding(
- model=model_name,
- dashscope_api_key=model_credential.get('dashscope_api_key')
+ api_key=model_credential.get('dashscope_api_key'),
+ model_name=model_name,
+ optional_params=optional_params
)
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
- if self.model == 'text-embedding-v3':
- return proxy_embed_documents(texts, 6, self._embed_documents)
- return self._embed_documents(texts)
-
- def _embed_documents(self, texts: List[str]) -> List[List[float]]:
- """Call out to DashScope's embedding endpoint for embedding search docs.
-
- Args:
- texts: The list of texts to embed.
- chunk_size: The chunk size of embeddings. If None, will use the chunk size
- specified by the class.
-
- Returns:
- List of embeddings, one for each text.
- """
- embeddings = embed_with_retry(
- self, input=texts, text_type="document", model=self.model
- )
- embedding_list = [item["embedding"] for item in embeddings]
- return embedding_list
-
- def embed_query(self, text: str) -> List[float]:
- """Call out to DashScope's embedding endpoint for embedding query text.
-
- Args:
- text: The text to embed.
-
- Returns:
- Embedding for the text.
- """
- embedding = embed_with_retry(
- self, input=[text], text_type="document", model=self.model
- )[0]["embedding"]
- return embedding
+    def embed_query(self, text: str):
+        """Embed a single text.
+
+        Delegates to ``embed_documents`` with a one-element list and returns
+        the first vector of the result.
+        """
+        return self.embed_documents([text])[0]
+
+    def embed_documents(
+            self, texts: List[str], chunk_size: int | None = None
+    ) -> List[List[float]]:
+        """Embed ``texts`` and return one vector per input, in input order.
+
+        Args:
+            texts: The texts to embed. An empty list returns ``[]`` without
+                issuing a request.
+            chunk_size: Maximum number of texts sent per request. When omitted,
+                ``text-embedding-v3`` inputs are sent 6 at a time and every
+                other model gets a single request with all texts.
+
+        Returns:
+            The ``embedding`` of every item in the responses, concatenated in
+            request order.
+        """
+        if not texts:
+            return []
+        if chunk_size is None or chunk_size <= 0:
+            # The implementation this replaces split text-embedding-v3 inputs
+            # into groups of 6; keep that default so large inputs are still
+            # sent in several requests for that model.
+            chunk_size = 6 if self.model_name == 'text-embedding-v3' else len(texts)
+        # ``**{}`` expands to nothing, so no separate branch is needed for the
+        # "no optional parameters" case.
+        extra = self.optional_params or {}
+        embeddings: List[List[float]] = []
+        for start in range(0, len(texts), chunk_size):
+            res = self.client.create(
+                input=texts[start:start + chunk_size], model=self.model_name,
+                encoding_format="float", **extra
+            )
+            embeddings.extend(e.embedding for e in res.data)
+        return embeddings
diff --git a/apps/models_provider/impl/openai_model_provider/model/embedding.py b/apps/models_provider/impl/openai_model_provider/model/embedding.py
index 3a0aaeeb1a4..7362e8fcd32 100644
--- a/apps/models_provider/impl/openai_model_provider/model/embedding.py
+++ b/apps/models_provider/impl/openai_model_provider/model/embedding.py
@@ -15,17 +15,21 @@
class OpenAIEmbeddingModel(MaxKBBaseModel):
model_name: str
+ optional_params: dict
- def __init__(self, api_key, base_url, model_name: str):
+ def __init__(self, api_key, base_url, model_name: str, optional_params: dict | None = None):
+ """Create the embeddings client and store the model name and optional parameters.
+
+ ``optional_params`` defaults to ``None`` so that callers still using the
+ previous three-argument form keep working; it is stored as an empty dict
+ in that case.
+ """
self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
self.model_name = model_name
+ # Always store a dict so it can be expanded with ``**`` when building the request.
+ self.optional_params = optional_params or {}
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+ """Build an OpenAIEmbeddingModel from a credential dict.
+
+ ``api_key`` and ``api_base`` are read from ``model_credential``.
+ ``model_kwargs`` is passed through ``MaxKBBaseModel.filter_optional_params``
+ and the result becomes the instance's ``optional_params``.
+ ``model_type`` is not used here.
+ """
+ optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OpenAIEmbeddingModel(
api_key=model_credential.get('api_key'),
model_name=model_name,
base_url=model_credential.get('api_base'),
+ optional_params=optional_params
)
def embed_query(self, text: str):
@@ -35,5 +39,11 @@ def embed_query(self, text: str):
def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
- res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+ """Embed ``texts`` in one request and return one vector per response item.
+
+ ``chunk_size`` is accepted for signature compatibility but is not used.
+ """
+ # ``**{}`` expands to nothing, so a single call covers both the "with" and
+ # "without" optional-parameter cases; ``or {}`` also tolerates ``None``.
+ res = self.client.create(
+ input=texts, model=self.model_name, encoding_format="float",
+ **(self.optional_params or {})
+ )
return [e.embedding for e in res.data]
diff --git a/ui/src/views/model/component/CreateModelDialog.vue b/ui/src/views/model/component/CreateModelDialog.vue
index d3791947d27..78132e173b1 100644
--- a/ui/src/views/model/component/CreateModelDialog.vue
+++ b/ui/src/views/model/component/CreateModelDialog.vue
@@ -140,8 +140,7 @@
/>
@@ -150,7 +149,7 @@
{{ $t('common.add') }}
diff --git a/ui/src/views/model/component/ModelCard.vue b/ui/src/views/model/component/ModelCard.vue
index dffca98c755..006475a6525 100644
--- a/ui/src/views/model/component/ModelCard.vue
+++ b/ui/src/views/model/component/ModelCard.vue
@@ -95,6 +95,7 @@
currentModel.model_type === 'IMAGE' ||
currentModel.model_type === 'TTI' ||
currentModel.model_type === 'ITV' ||
+ currentModel.model_type === 'EMBEDDING' ||
currentModel.model_type === 'TTV') &&
permissionPrecise.paramSetting(model.id)
"