From ae2e08d90df910c8a6291cd3f5a6d201454a0208 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Sun, 28 Sep 2025 12:03:03 +0800 Subject: [PATCH] feat: add optional parameters to OpenAIEmbeddingModel for enhanced embedding functionality --- .../credential/embedding.py | 23 +++++- .../model/embedding.py | 72 +++++++------------ .../openai_model_provider/model/embedding.py | 14 +++- .../model/component/CreateModelDialog.vue | 5 +- ui/src/views/model/component/ModelCard.vue | 1 + 5 files changed, 64 insertions(+), 51 deletions(-) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py index 95da1e03b0f..0a3c3910771 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -13,10 +13,27 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding +class BaiLianEmbeddingModelParams(BaseForm): + dimensions = forms.SingleSelect( + TooltipLabel( + _('Dimensions'), + _('') + ), + required=True, + default_value=1024, + value_field='value', + text_field='label', + option_list=[ + {'label': '1024', 'value': '1024'}, + {'label': '768', 'value': '768'}, + {'label': '512', 'value': '512'}, + ] + ) + class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential): @@ -71,4 +88,8 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]: api_key = model.get('dashscope_api_key', '') return {**model, 'dashscope_api_key': super().encryption(api_key)} + + def get_model_params_setting_form(self, model_name): + return BaiLianEmbeddingModelParams() + dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py index 0316782dd10..786469e0283 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py @@ -6,61 +6,43 @@ @date:2024/10/16 16:34 @desc: """ -from functools import reduce from typing import Dict, List -from langchain_community.embeddings import DashScopeEmbeddings -from langchain_community.embeddings.dashscope import embed_with_retry +from openai import OpenAI from models_provider.base_model_provider import MaxKBBaseModel -def proxy_embed_documents(texts: List[str], step_size, embed_documents): - value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in - range(0, len(texts), step_size)] - return reduce(lambda x, y: [*x, *y], value, []) +class AliyunBaiLianEmbedding(MaxKBBaseModel): + model_name: str + optional_params: dict + def __init__(self, api_key, model_name: str, optional_params: dict): + self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings + self.model_name = model_name + self.optional_params = optional_params -class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) return AliyunBaiLianEmbedding( - model=model_name, - dashscope_api_key=model_credential.get('dashscope_api_key') + api_key=model_credential.get('dashscope_api_key'), + model_name=model_name, + optional_params=optional_params ) - def embed_documents(self, texts: List[str]) -> List[List[float]]: - if self.model == 'text-embedding-v3': - return proxy_embed_documents(texts, 6, self._embed_documents) - return self._embed_documents(texts) - - def _embed_documents(self, texts: List[str]) -> List[List[float]]: - """Call out to DashScope's embedding endpoint for embedding search docs. - - Args: - texts: The list of texts to embed. - chunk_size: The chunk size of embeddings. If None, will use the chunk size - specified by the class. - - Returns: - List of embeddings, one for each text. - """ - embeddings = embed_with_retry( - self, input=texts, text_type="document", model=self.model - ) - embedding_list = [item["embedding"] for item in embeddings] - return embedding_list - - def embed_query(self, text: str) -> List[float]: - """Call out to DashScope's embedding endpoint for embedding query text. - - Args: - text: The text to embed. - - Returns: - Embedding for the text. - """ - embedding = embed_with_retry( - self, input=[text], text_type="document", model=self.model - )[0]["embedding"] - return embedding + def embed_query(self, text: str): + res = self.embed_documents([text]) + return res[0] + + def embed_documents( + self, texts: List[str], chunk_size: int | None = None + ) -> List[List[float]]: + if len(self.optional_params) > 0: + res = self.client.create( + input=texts, model=self.model_name, encoding_format="float", + **self.optional_params + ) + else: + res = self.client.create(input=texts, model=self.model_name, encoding_format="float") + return [e.embedding for e in res.data] diff --git a/apps/models_provider/impl/openai_model_provider/model/embedding.py b/apps/models_provider/impl/openai_model_provider/model/embedding.py index 3a0aaeeb1a4..7362e8fcd32 100644 --- a/apps/models_provider/impl/openai_model_provider/model/embedding.py +++ b/apps/models_provider/impl/openai_model_provider/model/embedding.py @@ -15,17 +15,21 @@ class OpenAIEmbeddingModel(MaxKBBaseModel): model_name: str + optional_params: dict - def __init__(self, api_key, base_url, model_name: str): + def __init__(self, api_key, base_url, model_name: str, optional_params: dict): self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings self.model_name = model_name + self.optional_params = optional_params @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) return OpenAIEmbeddingModel( api_key=model_credential.get('api_key'), model_name=model_name, base_url=model_credential.get('api_base'), + optional_params=optional_params ) def embed_query(self, text: str): @@ -35,5 +39,11 @@ def embed_query(self, text: str): def embed_documents( self, texts: List[str], chunk_size: int | None = None ) -> List[List[float]]: - res = self.client.create(input=texts, model=self.model_name, encoding_format="float") + if len(self.optional_params) > 0: + res = self.client.create( + input=texts, model=self.model_name, encoding_format="float", + **self.optional_params + ) + else: + res = self.client.create(input=texts, model=self.model_name, encoding_format="float") return [e.embedding for e in res.data] diff --git a/ui/src/views/model/component/CreateModelDialog.vue b/ui/src/views/model/component/CreateModelDialog.vue index d3791947d27..78132e173b1 100644 --- a/ui/src/views/model/component/CreateModelDialog.vue +++ b/ui/src/views/model/component/CreateModelDialog.vue @@ -140,8 +140,7 @@ /> @@ -150,7 +149,7 @@ {{ $t('common.add') }} diff --git a/ui/src/views/model/component/ModelCard.vue b/ui/src/views/model/component/ModelCard.vue index dffca98c755..006475a6525 100644 --- a/ui/src/views/model/component/ModelCard.vue +++ b/ui/src/views/model/component/ModelCard.vue @@ -95,6 +95,7 @@ currentModel.model_type === 'IMAGE' || currentModel.model_type === 'TTI' || currentModel.model_type === 'ITV' || + currentModel.model_type === 'EMBEDDING' || currentModel.model_type === 'TTV') && permissionPrecise.paramSetting(model.id) "