Add Mistral AI as a new provider in LLM Deployment (mlflow#11020)
Signed-off-by: Ngo Minh Thang Nguyen <thangnguyen.dn.1991@gmail.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
Co-authored-by: Ngo Minh Thang Nguyen <thangnguyen.dn.1991@gmail.com>
Co-authored-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
3 people authored and sateeshmannar committed Feb 20, 2024
1 parent ecbdb8c commit c20511f
Showing 11 changed files with 637 additions and 1 deletion.
15 changes: 15 additions & 0 deletions docs/source/llms/deployments/index.rst
@@ -259,6 +259,9 @@ below can be used as a helpful guide when configuring a given endpoint for any n
| AWS Bedrock | - Amazon Titan | N/A | N/A |
| | - Third-party providers | | |
+--------------------------+--------------------------+--------------------------+--------------------------+
| Mistral | - mistral-tiny | N/A | - mistral-embed |
| | - mistral-small | | |
+--------------------------+--------------------------+--------------------------+--------------------------+

§ For full compatibility references for ``OpenAI``, see the `OpenAI Model Compatibility Matrix <https://platform.openai.com/docs/models/model-endpoint-compatibility>`_.

@@ -304,6 +307,7 @@ As of now, the MLflow Deployments Server supports the following providers:
* **huggingface text generation inference**: This is used for models deployed using `Huggingface Text Generation Inference <https://huggingface.co/docs/text-generation-inference/index>`_.
* **ai21labs**: This is used for models offered by `AI21 Labs <https://studio.ai21.com/foundation-models>`_.
* **bedrock**: This is used for models offered by `AWS Bedrock <https://aws.amazon.com/bedrock/>`_.
* **mistral**: This is used for models offered by `Mistral <https://docs.mistral.ai/>`_.

More providers are being added continually. Check the latest version of the MLflow Deployments Server Docs for the
most up-to-date list of supported providers.
@@ -512,6 +516,7 @@ Each endpoint has the following configuration parameters:
- "huggingface-text-generation-inference"
- "ai21labs"
- "bedrock"
- "mistral"

- **name**: This is an optional field to specify the name of the model.
- **config**: This contains provider-specific configuration details.
@@ -683,6 +688,16 @@ To match your user's interaction and security access requirements, adjust the ``
+----------------------------+----------+---------+-----------------------------------------------------------------------------------------------+


Mistral
+++++++

+--------------------------+----------+--------------------------+-------------------------------------------------------+
| Configuration Parameter | Required | Default | Description |
+==========================+==========+==========================+=======================================================+
| **mistral_api_key** | Yes | N/A | This is the API key for the Mistral service. |
+--------------------------+----------+--------------------------+-------------------------------------------------------+
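
An example endpoint configuration for Mistral follows the same shape as the other providers. A minimal sketch, mirroring the ``examples/deployments/deployments_server/mistral/config.yaml`` file added in this commit:

.. code-block:: yaml

    endpoints:
      - name: completions
        endpoint_type: llm/v1/completions
        model:
          provider: mistral
          name: mistral-tiny
          config:
            mistral_api_key: $MISTRAL_API_KEY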


An example configuration for Azure OpenAI is:

.. code-block:: yaml
15 changes: 15 additions & 0 deletions docs/source/llms/gateway/index.rst
@@ -302,6 +302,9 @@ below can be used as a helpful guide when configuring a given route for any newl
| AWS Bedrock | - Amazon Titan | N/A | N/A |
| | - Third-party providers | | |
+--------------------------+--------------------------+--------------------------+--------------------------+
| Mistral | - mistral-tiny | N/A | - mistral-embed |
| | - mistral-small | | |
+--------------------------+--------------------------+--------------------------+--------------------------+


† Llama 2 is licensed under the `LLAMA 2 Community License <https://ai.meta.com/llama/license/>`_, Copyright © Meta Platforms, Inc. All Rights Reserved.
@@ -343,6 +346,7 @@ As of now, the MLflow AI Gateway supports the following providers:
* **huggingface text generation inference**: This is used for models deployed using `Huggingface Text Generation Inference <https://huggingface.co/docs/text-generation-inference/index>`_.
* **ai21labs**: This is used for models offered by `AI21 Labs <https://studio.ai21.com/foundation-models>`_.
* **bedrock**: This is used for models offered by `AWS Bedrock <https://aws.amazon.com/bedrock/>`_.
* **mistral**: This is used for models offered by `Mistral <https://docs.mistral.ai/>`_.

More providers are being added continually. Check the latest version of the MLflow AI Gateway Docs for the
most up-to-date list of supported providers.
@@ -540,6 +544,7 @@ Each route has the following configuration parameters:
- "huggingface-text-generation-inference"
- "ai21labs"
- "bedrock"
- "mistral"

- **name**: This is an optional field to specify the name of the model.
- **config**: This contains provider-specific configuration details.
@@ -639,6 +644,16 @@ Top-level model configuration for AWS Bedrock routes must be one of the followin
+--------------------------+----------+------------------------------+-------------------------------------------------------+


Mistral
+++++++

+--------------------------+----------+--------------------------+-------------------------------------------------------+
| Configuration Parameter | Required | Default | Description |
+==========================+==========+==========================+=======================================================+
| **mistral_api_key** | Yes | N/A | This is the API key for the Mistral service. |
+--------------------------+----------+--------------------------+-------------------------------------------------------+
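
A minimal Mistral route sketch for the gateway, assuming the gateway's ``routes``/``route_type`` configuration schema (the gateway-side analogue of the Deployments Server example added in this commit):

.. code-block:: yaml

    routes:
      - name: completions
        route_type: llm/v1/completions
        model:
          provider: mistral
          name: mistral-tiny
          config:
            mistral_api_key: $MISTRAL_API_KEY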


To use key-based authentication, define an AWS Bedrock route with the required fields below.

.. note::

1 change: 1 addition & 0 deletions examples/deployments/deployments_server/README.md
@@ -47,6 +47,7 @@ For full examples of configurations and supported endpoint types, see:
- [AI21 Labs](ai21labs/config.yaml)
- [PaLM](palm/config.yaml)
- [AzureOpenAI](azure_openai/config.yaml)
- [Mistral](mistral/config.yaml)

## Step 3: Setting Access Keys

13 changes: 13 additions & 0 deletions examples/deployments/deployments_server/mistral/README.md
@@ -0,0 +1,13 @@
## Example endpoint configuration for Mistral

For an example that specifies both a completions and an embeddings endpoint for Mistral, see [the configuration](config.yaml) YAML file.

This configuration file defines two endpoints, 'completions' and 'embeddings', backed by Mistral's 'mistral-tiny' and 'mistral-embed' models, respectively.

## Setting a Mistral API Key

This example requires a [Mistral API key](https://docs.mistral.ai/):

```sh
export MISTRAL_API_KEY=...
```
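
With the key exported, one way to exercise this example end to end (a sketch; the port matches the `http://localhost:7000` URI used in [example.py](example.py)):

```sh
mlflow deployments start-server --config-path config.yaml --port 7000
python example.py
```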
16 changes: 16 additions & 0 deletions examples/deployments/deployments_server/mistral/config.yaml
@@ -0,0 +1,16 @@
endpoints:
- name: completions
endpoint_type: llm/v1/completions
model:
provider: mistral
name: mistral-tiny
config:
mistral_api_key: $MISTRAL_API_KEY

- name: embeddings
endpoint_type: llm/v1/embeddings
model:
provider: mistral
name: mistral-embed
config:
mistral_api_key: $MISTRAL_API_KEY
34 changes: 34 additions & 0 deletions examples/deployments/deployments_server/mistral/example.py
@@ -0,0 +1,34 @@
from mlflow.deployments import get_deploy_client


def main():
client = get_deploy_client("http://localhost:7000")

print(f"Mistral endpoints: {client.list_endpoints()}\n")
print(f"Mistral completions endpoint info: {client.get_endpoint(endpoint='completions')}\n")

# Completions request
response_completions = client.predict(
endpoint="completions",
inputs={
"prompt": "How many average size European ferrets can fit inside a standard olympic?",
"temperature": 0.1,
},
)
print(f"Mistral response for completions: {response_completions}")

# Embeddings request
response_embeddings = client.predict(
endpoint="embeddings",
inputs={
"input": [
"How does your culture celebrate the New Year, and how does it differ from other countries’ "
"celebrations?"
]
},
)
print(f"Mistral response for embeddings: {response_embeddings}")


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions mlflow/gateway/config.py
@@ -46,6 +46,7 @@ class Provider(str, Enum):
# Note: The following providers are only supported on Databricks
DATABRICKS_MODEL_SERVING = "databricks-model-serving"
DATABRICKS = "databricks"
MISTRAL = "mistral"

@classmethod
def values(cls):
@@ -216,6 +217,15 @@ class AWSBedrockConfig(ConfigModel):
aws_config: Union[AWSRole, AWSIdAndKey, AWSBaseConfig]


class MistralConfig(ConfigModel):
mistral_api_key: str
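
    # Assumption based on how the other provider configs use this shared helper:
    # _resolve_api_key_from_input accepts a raw key, a $-prefixed environment
    # variable reference, or a path to a file containing the key.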

# pylint: disable=no-self-argument
@validator("mistral_api_key", pre=True)
def validate_mistral_api_key(cls, value):
return _resolve_api_key_from_input(value)


config_types = {
Provider.COHERE: CohereConfig,
Provider.OPENAI: OpenAIConfig,
@@ -226,6 +236,7 @@ class AWSBedrockConfig(ConfigModel):
Provider.MLFLOW_MODEL_SERVING: MlflowModelServingConfig,
Provider.PALM: PaLMConfig,
Provider.HUGGINGFACE_TEXT_GENERATION_INFERENCE: HuggingFaceTextGenerationInferenceConfig,
Provider.MISTRAL: MistralConfig,
}


@@ -285,6 +296,7 @@ class Model(ConfigModel):
MlflowModelServingConfig,
HuggingFaceTextGenerationInferenceConfig,
PaLMConfig,
MistralConfig,
]
] = None

2 changes: 2 additions & 0 deletions mlflow/gateway/providers/__init__.py
@@ -11,6 +11,7 @@ def get_provider(provider: Provider) -> Type[BaseProvider]:
from mlflow.gateway.providers.bedrock import AWSBedrockProvider
from mlflow.gateway.providers.cohere import CohereProvider
from mlflow.gateway.providers.huggingface import HFTextGenerationInferenceServerProvider
from mlflow.gateway.providers.mistral import MistralProvider
from mlflow.gateway.providers.mlflow import MlflowModelServingProvider
from mlflow.gateway.providers.mosaicml import MosaicMLProvider
from mlflow.gateway.providers.openai import OpenAIProvider
@@ -26,6 +27,7 @@ def get_provider(provider: Provider) -> Type[BaseProvider]:
Provider.MLFLOW_MODEL_SERVING: MlflowModelServingProvider,
Provider.HUGGINGFACE_TEXT_GENERATION_INFERENCE: HFTextGenerationInferenceServerProvider,
Provider.BEDROCK: AWSBedrockProvider,
Provider.MISTRAL: MistralProvider,
}
if prov := provider_to_class.get(provider):
return prov
164 changes: 164 additions & 0 deletions mlflow/gateway/providers/mistral.py
@@ -0,0 +1,164 @@
import time
from typing import Any, Dict

from fastapi.encoders import jsonable_encoder

from mlflow.gateway.config import MistralConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import send_request
from mlflow.gateway.schemas import completions, embeddings


class MistralAdapter(ProviderAdapter):
@classmethod
def model_to_completions(cls, resp, config):
# Response example (https://docs.mistral.ai/api/#operation/createChatCompletion)
# ```
# {
# "id": "string",
# "object": "string",
# "created": "integer",
# "model": "string",
# "choices": [
# {
# "index": "integer",
# "message": {
# "role": "string",
# "content": "string"
# },
# "finish_reason": "string",
# }
# ],
# "usage":
# {
# "prompt_tokens": "integer",
# "completion_tokens": "integer",
# "total_tokens": "integer",
# }
# }
# ```
return completions.ResponsePayload(
created=int(time.time()),
object="text_completion",
model=config.model.name,
choices=[
completions.Choice(
index=idx,
text=c["message"]["content"],
finish_reason=c["finish_reason"],
)
for idx, c in enumerate(resp["choices"])
],
usage=completions.CompletionsUsage(
prompt_tokens=resp["usage"]["prompt_tokens"],
completion_tokens=resp["usage"]["completion_tokens"],
total_tokens=resp["usage"]["total_tokens"],
),
)

@classmethod
def model_to_embeddings(cls, resp, config):
# Response example (https://docs.mistral.ai/api/#operation/createEmbedding):
# ```
# {
# "id": "string",
# "object": "string",
# "data": [
# {
# "object": "string",
# "embedding":
# [
# float,
# float
# ]
# "index": "integer",
# }
# ],
# "model": "string",
# "usage":
# {
# "prompt_tokens": "integer",
# "total_tokens": "integer",
# }
# }
# ```
return embeddings.ResponsePayload(
data=[
embeddings.EmbeddingObject(
embedding=data["embedding"],
index=data["index"],
)
for data in resp["data"]
],
model=config.model.name,
usage=embeddings.EmbeddingsUsage(
prompt_tokens=resp["usage"]["prompt_tokens"],
total_tokens=resp["usage"]["total_tokens"],
),
)

@classmethod
def completions_to_model(cls, payload, config):
payload.pop("stop", None)
payload.pop("n", None)
payload["messages"] = [{"role": "user", "content": payload.pop("prompt")}]

# The range of Mistral's temperature is 0-1, but ours is 0-2, so we scale it.
if "temperature" in payload:
payload["temperature"] = 0.5 * payload["temperature"]

return payload

@classmethod
def embeddings_to_model(cls, payload, config):
return payload


class MistralProvider(BaseProvider):
NAME = "Mistral"

def __init__(self, config: RouteConfig) -> None:
super().__init__(config)
if config.model.config is None or not isinstance(config.model.config, MistralConfig):
raise TypeError(f"Unexpected config type {config.model.config}")
self.mistral_config: MistralConfig = config.model.config

@property
def auth_headers(self) -> Dict[str, str]:
return {"Authorization": f"Bearer {self.mistral_config.mistral_api_key}"}

@property
def base_url(self) -> str:
return "https://api.mistral.ai/v1/"

async def _request(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
return await send_request(
headers=self.auth_headers,
base_url=self.base_url,
path=path,
payload=payload,
)

async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
payload = jsonable_encoder(payload, exclude_none=True)
self.check_for_model_field(payload)
resp = await self._request(
"chat/completions",
{
"model": self.config.model.name,
**MistralAdapter.completions_to_model(payload, self.config),
},
)
return MistralAdapter.model_to_completions(resp, self.config)

async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
payload = jsonable_encoder(payload, exclude_none=True)
self.check_for_model_field(payload)
resp = await self._request(
"embeddings",
{
"model": self.config.model.name,
**MistralAdapter.embeddings_to_model(payload, self.config),
},
)
return MistralAdapter.model_to_embeddings(resp, self.config)
