Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,27 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding

class BaiLianEmbeddingModelParams(BaseForm):
    """Parameter form offered for Aliyun Bailian embedding models.

    Exposes one required ``dimensions`` selector with the choices 1024, 768
    and 512, defaulting to 1024.
    """

    dimensions = forms.SingleSelect(
        TooltipLabel(
            _('Dimensions'),
            # No tooltip text. The empty string is passed as-is instead of
            # going through ``_``: gettext reserves the empty msgid for the
            # catalog header, so translating it is never meaningful.
            ''
        ),
        required=True,
        default_value=1024,
        value_field='value',
        text_field='label',
        # Option values are integers so that ``default_value`` actually
        # matches one of them (it did not while the values were strings).
        # NOTE(review): presumably this value is forwarded as the integer
        # ``dimensions`` request parameter -- confirm against the caller.
        option_list=[
            {'label': '1024', 'value': 1024},
            {'label': '768', 'value': 768},
            {'label': '512', 'value': 512},
        ]
    )


class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):

Expand Down Expand Up @@ -71,4 +88,8 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
api_key = model.get('dashscope_api_key', '')
return {**model, 'dashscope_api_key': super().encryption(api_key)}


def get_model_params_setting_form(self, model_name):
    """Build the parameter-setting form for an embedding model.

    ``model_name`` is accepted but not consulted: a fresh
    ``BaiLianEmbeddingModelParams`` instance is returned for every model.
    """
    params_form = BaiLianEmbeddingModelParams()
    return params_form

dashscope_api_key = forms.PasswordInputField('API Key', required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,43 @@
@date:2024/10/16 16:34
@desc:
"""
from functools import reduce
from typing import Dict, List

from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.embeddings.dashscope import embed_with_retry
from openai import OpenAI

from models_provider.base_model_provider import MaxKBBaseModel


def proxy_embed_documents(texts: List[str], step_size, embed_documents):
    """Embed ``texts`` in consecutive batches of at most ``step_size`` items.

    Args:
        texts: The texts to embed.
        step_size: Maximum number of texts handed to ``embed_documents`` per
            call. Must be a positive integer.
        embed_documents: Callable that takes a list of texts and returns one
            result per text, in the same order.

    Returns:
        A single flat list with the results of every batch, in input order.

    Raises:
        ValueError: If ``step_size`` is smaller than 1.
    """
    # A zero step already made ``range`` raise ValueError; a negative one
    # produced an empty range and silently dropped every text.
    if step_size < 1:
        raise ValueError('step_size must be a positive integer')
    # Flatten while iterating: linear in the number of results, unlike
    # repeatedly rebuilding the accumulated list for each batch.
    return [
        embedding
        for start_index in range(0, len(texts), step_size)
        for embedding in embed_documents(texts[start_index:start_index + step_size])
    ]
class AliyunBaiLianEmbedding(MaxKBBaseModel):
model_name: str
optional_params: dict

def __init__(self, api_key, model_name: str, optional_params: dict):
self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
self.model_name = model_name
self.optional_params = optional_params

class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AliyunBaiLianEmbedding(
model=model_name,
dashscope_api_key=model_credential.get('dashscope_api_key')
api_key=model_credential.get('dashscope_api_key'),
model_name=model_name,
optional_params=optional_params
)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
if self.model == 'text-embedding-v3':
return proxy_embed_documents(texts, 6, self._embed_documents)
return self._embed_documents(texts)

def _embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to DashScope's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(
self, input=texts, text_type="document", model=self.model
)
embedding_list = [item["embedding"] for item in embeddings]
return embedding_list

def embed_query(self, text: str) -> List[float]:
"""Call out to DashScope's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = embed_with_retry(
self, input=[text], text_type="document", model=self.model
)[0]["embedding"]
return embedding
def embed_query(self, text: str):
res = self.embed_documents([text])
return res[0]

def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
if len(self.optional_params) > 0:
res = self.client.create(
input=texts, model=self.model_name, encoding_format="float",
**self.optional_params
)
else:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
return [e.embedding for e in res.data]
14 changes: 12 additions & 2 deletions apps/models_provider/impl/openai_model_provider/model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@

class OpenAIEmbeddingModel(MaxKBBaseModel):
model_name: str
optional_params: dict

def __init__(self, api_key, base_url, model_name: str):
def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
self.model_name = model_name
self.optional_params = optional_params

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OpenAIEmbeddingModel(
api_key=model_credential.get('api_key'),
model_name=model_name,
base_url=model_credential.get('api_base'),
optional_params=optional_params
)

def embed_query(self, text: str):
Expand All @@ -35,5 +39,11 @@ def embed_query(self, text: str):
def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
if len(self.optional_params) > 0:
res = self.client.create(
input=texts, model=self.model_name, encoding_format="float",
**self.optional_params
)
else:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
return [e.embedding for e in res.data]
5 changes: 2 additions & 3 deletions ui/src/views/model/component/CreateModelDialog.vue
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@
/>
<el-empty
v-else-if="
base_form_data.model_type === 'RERANKER' ||
base_form_data.model_type === 'EMBEDDING'
base_form_data.model_type === 'RERANKER'
"
:description="$t('views.model.tip.emptyMessage2')"
/>
Expand All @@ -150,7 +149,7 @@
<el-button
type="text"
@click.stop="openAddDrawer()"
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT'].includes(base_form_data.model_type)"
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT', 'EMBEDDING'].includes(base_form_data.model_type)"
>
<AppIcon iconName="app-add-outlined" class="mr-4"/> {{ $t('common.add') }}
</el-button>
Expand Down
1 change: 1 addition & 0 deletions ui/src/views/model/component/ModelCard.vue
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
currentModel.model_type === 'IMAGE' ||
currentModel.model_type === 'TTI' ||
currentModel.model_type === 'ITV' ||
currentModel.model_type === 'EMBEDDING' ||
currentModel.model_type === 'TTV') &&
permissionPrecise.paramSetting(model.id)
"
Expand Down
Loading