From f145151d64aa42563ef68c90292a362b3b8ee2ec Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Wed, 6 Nov 2024 18:59:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dxinference=E5=90=91?= =?UTF-8?q?=E9=87=8F=E6=A8=A1=E5=9E=8B=E6=B7=BB=E5=8A=A0=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E7=9A=84=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../credential/embedding.py | 4 +- .../model/embedding.py | 72 ++++++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py index 200183e6c03..7cddb4f09da 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py @@ -15,7 +15,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') try: - model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding') + model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'), + 'embedding') except Exception as e: raise AppApiException(ValidCode.valid_error.value, "API 域名无效") exist = provider.get_model_info_by_name(model_list, model_name) @@ -36,3 +37,4 @@ def build_model(self, model_info: Dict[str, object]): return self api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py index 1cf34aaf875..935f4d23919 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py @@ -1,18 +1,26 @@ # coding=utf-8 import threading -from typing import Dict +from typing import Dict, Optional, List, Any from langchain_community.embeddings import XinferenceEmbeddings +from langchain_core.embeddings import Embeddings from setting.models_provider.base_model_provider import MaxKBBaseModel -class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings): +class XinferenceEmbedding(MaxKBBaseModel, Embeddings): + client: Any + server_url: Optional[str] + """URL of the xinference server""" + model_uid: Optional[str] + """UID of the launched model""" + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return XinferenceEmbedding( model_uid=model_name, server_url=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), ) def down_model(self): @@ -22,3 +30,63 @@ def start_down_model_thread(self): thread = threading.Thread(target=self.down_model) thread.daemon = True thread.start() + + def __init__( + self, server_url: Optional[str] = None, model_uid: Optional[str] = None, + api_key: Optional[str] = None + ): + try: + from xinference.client import RESTfulClient + except ImportError: + try: + from xinference_client import RESTfulClient + except ImportError as e: + raise ImportError( + "Could not import RESTfulClient from xinference. Please install it" + " with `pip install xinference` or `pip install xinference_client`." + ) from e + + if server_url is None: + raise ValueError("Please provide server URL") + + if model_uid is None: + raise ValueError("Please provide the model UID") + + self.server_url = server_url + + self.model_uid = model_uid + + self.api_key = api_key + + self.client = RESTfulClient(server_url, api_key) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed a list of documents using Xinference. + Args: + texts: The list of texts to embed. + Returns: + List of embeddings, one for each text. + """ + + model = self.client.get_model(self.model_uid) + + embeddings = [ + model.create_embedding(text)["data"][0]["embedding"] for text in texts + ] + return [list(map(float, e)) for e in embeddings] + + def embed_query(self, text: str) -> List[float]: + """Embed a query of documents using Xinference. + Args: + text: The text to embed. + Returns: + Embeddings for the text. + """ + + model = self.client.get_model(self.model_uid) + + embedding_res = model.create_embedding(text) + + embedding = embedding_res["data"][0]["embedding"] + + return list(map(float, embedding))