Add Mistral AI as a new provider in LLM Deployment (mlflow#11020)
Signed-off-by: Ngo Minh Thang Nguyen <thangnguyen.dn.1991@gmail.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
Co-authored-by: Ngo Minh Thang Nguyen <thangnguyen.dn.1991@gmail.com>
Co-authored-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
1 parent ecbdb8c · commit c20511f
Showing 11 changed files with 637 additions and 1 deletion.
13 changes: 13 additions & 0 deletions examples/deployments/deployments_server/mistral/README.md

@@ -0,0 +1,13 @@
## Example endpoint configuration for Mistral

For an example of specifying both the completions and the embeddings endpoints for Mistral, see [the configuration YAML file](config.yaml).

This configuration file defines two endpoints, 'completions' and 'embeddings', backed by Mistral's 'mistral-tiny' and 'mistral-embed' models, respectively.

## Setting a Mistral API Key

This example requires a [Mistral API key](https://docs.mistral.ai/):

```sh
export MISTRAL_API_KEY=...
```
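
With the key exported, the deployments server can then be started against this configuration; a minimal sketch, assuming the `mlflow deployments start-server` CLI from recent MLflow releases and port 7000 to match the accompanying example.py:

```sh
mlflow deployments start-server --config-path config.yaml --port 7000
```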
16 changes: 16 additions & 0 deletions examples/deployments/deployments_server/mistral/config.yaml

@@ -0,0 +1,16 @@

```yaml
endpoints:
  - name: completions
    endpoint_type: llm/v1/completions
    model:
      provider: mistral
      name: mistral-tiny
      config:
        mistral_api_key: $MISTRAL_API_KEY

  - name: embeddings
    endpoint_type: llm/v1/embeddings
    model:
      provider: mistral
      name: mistral-embed
      config:
        mistral_api_key: $MISTRAL_API_KEY
```
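
Note that `$MISTRAL_API_KEY` is resolved from the server's environment at startup, which is why the README above asks you to export the key before starting the server.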
34 changes: 34 additions & 0 deletions examples/deployments/deployments_server/mistral/example.py

@@ -0,0 +1,34 @@

```python
from mlflow.deployments import get_deploy_client


def main():
    client = get_deploy_client("http://localhost:7000")

    print(f"Mistral endpoints: {client.list_endpoints()}\n")
    print(f"Mistral completions endpoint info: {client.get_endpoint(endpoint='completions')}\n")

    # Completions request
    response_completions = client.predict(
        endpoint="completions",
        inputs={
            "prompt": "How many average size European ferrets can fit inside a standard olympic swimming pool?",
            "temperature": 0.1,
        },
    )
    print(f"Mistral response for completions: {response_completions}")

    # Embeddings request
    response_embeddings = client.predict(
        endpoint="embeddings",
        inputs={
            "input": [
                "How does your culture celebrate the New Year, and how does it differ from other countries’ "
                "celebrations?"
            ]
        },
    )
    print(f"Mistral response for embeddings: {response_embeddings}")


if __name__ == "__main__":
    main()
```
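
This script assumes the deployments server configured above is already running locally on port 7000; see the start-server command after the README section.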
164 changes: 164 additions & 0 deletions mlflow/gateway/providers/mistral.py

@@ -0,0 +1,164 @@
```python
import time
from typing import Any, Dict

from fastapi.encoders import jsonable_encoder

from mlflow.gateway.config import MistralConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import send_request
from mlflow.gateway.schemas import completions, embeddings


class MistralAdapter(ProviderAdapter):
    @classmethod
    def model_to_completions(cls, resp, config):
        # Response example (https://docs.mistral.ai/api/#operation/createChatCompletion)
        # ```
        # {
        #   "id": "string",
        #   "object": "string",
        #   "created": "integer",
        #   "model": "string",
        #   "choices": [
        #     {
        #       "index": "integer",
        #       "message": {
        #         "role": "string",
        #         "content": "string"
        #       },
        #       "finish_reason": "string",
        #     }
        #   ],
        #   "usage": {
        #     "prompt_tokens": "integer",
        #     "completion_tokens": "integer",
        #     "total_tokens": "integer",
        #   }
        # }
        # ```
        return completions.ResponsePayload(
            created=int(time.time()),
            object="text_completion",
            model=config.model.name,
            choices=[
                completions.Choice(
                    index=idx,
                    text=c["message"]["content"],
                    finish_reason=c["finish_reason"],
                )
                for idx, c in enumerate(resp["choices"])
            ],
            usage=completions.CompletionsUsage(
                prompt_tokens=resp["usage"]["prompt_tokens"],
                completion_tokens=resp["usage"]["completion_tokens"],
                total_tokens=resp["usage"]["total_tokens"],
            ),
        )

    @classmethod
    def model_to_embeddings(cls, resp, config):
        # Response example (https://docs.mistral.ai/api/#operation/createEmbedding):
        # ```
        # {
        #   "id": "string",
        #   "object": "string",
        #   "data": [
        #     {
        #       "object": "string",
        #       "embedding": [
        #         float,
        #         float
        #       ],
        #       "index": "integer",
        #     }
        #   ],
        #   "model": "string",
        #   "usage": {
        #     "prompt_tokens": "integer",
        #     "total_tokens": "integer",
        #   }
        # }
        # ```
        return embeddings.ResponsePayload(
            data=[
                embeddings.EmbeddingObject(
                    embedding=data["embedding"],
                    index=data["index"],
                )
                for data in resp["data"]
            ],
            model=config.model.name,
            usage=embeddings.EmbeddingsUsage(
                prompt_tokens=resp["usage"]["prompt_tokens"],
                total_tokens=resp["usage"]["total_tokens"],
            ),
        )

    @classmethod
    def completions_to_model(cls, payload, config):
        payload.pop("stop", None)
        payload.pop("n", None)
        payload["messages"] = [{"role": "user", "content": payload.pop("prompt")}]

        # The range of Mistral's temperature is 0-1, but ours is 0-2, so we scale it.
        if "temperature" in payload:
            payload["temperature"] = 0.5 * payload["temperature"]

        return payload

    @classmethod
    def embeddings_to_model(cls, payload, config):
        return payload


class MistralProvider(BaseProvider):
    NAME = "Mistral"

    def __init__(self, config: RouteConfig) -> None:
        super().__init__(config)
        if config.model.config is None or not isinstance(config.model.config, MistralConfig):
            raise TypeError(f"Unexpected config type {config.model.config}")
        self.mistral_config: MistralConfig = config.model.config

    @property
    def auth_headers(self) -> Dict[str, str]:
        return {"Authorization": f"Bearer {self.mistral_config.mistral_api_key}"}

    @property
    def base_url(self) -> str:
        return "https://api.mistral.ai/v1/"

    async def _request(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await send_request(
            headers=self.auth_headers,
            base_url=self.base_url,
            path=path,
            payload=payload,
        )

    async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
        payload = jsonable_encoder(payload, exclude_none=True)
        self.check_for_model_field(payload)
        resp = await self._request(
            "chat/completions",
            {
                "model": self.config.model.name,
                **MistralAdapter.completions_to_model(payload, self.config),
            },
        )
        return MistralAdapter.model_to_completions(resp, self.config)

    async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
        payload = jsonable_encoder(payload, exclude_none=True)
        self.check_for_model_field(payload)
        resp = await self._request(
            "embeddings",
            {
                "model": self.config.model.name,
                **MistralAdapter.embeddings_to_model(payload, self.config),
            },
        )
        return MistralAdapter.model_to_embeddings(resp, self.config)
```
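
To make the adapter's two mapping directions concrete, here is a standalone sketch; the module path `mlflow.gateway.providers.mistral` is assumed from the imports above, and the `SimpleNamespace` config is a stand-in for a real `RouteConfig`, so treat this as illustrative rather than an official test:

```python
from types import SimpleNamespace

from mlflow.gateway.providers.mistral import MistralAdapter  # assumed module path

# Stand-in for the route config; the adapter only reads config.model.name.
config = SimpleNamespace(model=SimpleNamespace(name="mistral-tiny"))

# Gateway payload -> Mistral chat payload: the bare prompt becomes a user
# message, and the gateway's 0-2 temperature is halved into Mistral's 0-1 range.
request = MistralAdapter.completions_to_model({"prompt": "Hello", "temperature": 1.0}, config)
assert request["messages"] == [{"role": "user", "content": "Hello"}]
assert request["temperature"] == 0.5

# Mistral chat response -> gateway completions payload.
resp = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
}
result = MistralAdapter.model_to_completions(resp, config)
assert result.choices[0].text == "Hi!"
assert result.usage.total_tokens == 7
```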