Skip to content

Commit

Permalink
Add AnyscaleLLM (#447)
Browse files Browse the repository at this point in the history
* Add `AnyscaleLLM`
Add `tests`

* Update `base_url` behavior

* Update tests/unit/llm/test_anyscale.py

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>

* Update src/distilabel/llm/anyscale.py

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>

* Update tests/unit/llm/test_anyscale.py

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>

* Update src/distilabel/llm/anyscale.py

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>

* Removed default `model` from `LLMs`

* Update `tests`

* Added reference to `google/gemma-7b-it`

* Fix pass `OpenAILLM` `model` explicitely

* Remove `load` method

---------

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
  • Loading branch information
davidberenstein1957 and alvarobartt committed Mar 20, 2024
1 parent 7513529 commit 87afbc5
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
29 changes: 29 additions & 0 deletions src/distilabel/llm/anyscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from distilabel.llm.openai import OpenAILLM


class AnyscaleLLM(OpenAILLM):
"""
Anyscale LLM implementation running the async API client of OpenAI because of duplicate API behavior.
Attributes:
model: the model name to use for the LLM, e.g., `google/gemma-7b-it`. [Supported models](https://docs.endpoints.anyscale.com/text-generation/supported-models/google-gemma-7b-it).
base_url: the base URL to use for the Anyscale API can be set with `OPENAI_BASE_URL`. Default is "https://api.endpoints.anyscale.com/v1".
api_key: the API key to authenticate the requests to the Anyscale API. Can be set with `OPENAI_API_KEY`. Default is `None`.
"""

base_url: str = "https://api.endpoints.anyscale.com/v1"
2 changes: 1 addition & 1 deletion src/distilabel/llm/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MistralLLM(AsyncLLM):
max_concurrent_requests: the maximum number of concurrent requests to send. Defaults to 64.
"""

model: str = "mistral-medium"
model: str
endpoint: str = "https://api.mistral.ai"
api_key: Optional[SecretStr] = os.getenv("MISTRAL_API_KEY", None) # type: ignore
max_retries: int = 5
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OpenAILLM(AsyncLLM):
api_key: the API key to authenticate the requests to the OpenAI API.
"""

model: str = "gpt-3.5-turbo"
model: str
base_url: Optional[str] = None
api_key: Optional[SecretStr] = os.getenv("OPENAI_API_KEY", None) # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_pipe_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_pipeline_with_llms_serde():
os.environ["OPENAI_API_KEY"] = "sk-***"
generate_response = TextGeneration(
name="generate_response",
llm=OpenAILLM(),
llm=OpenAILLM(model="gpt-3.5-turbo"),
output_mappings={"generation": "output"},
)
rename_columns.connect(generate_response)
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/llm/test_anyscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from distilabel.llm.anyscale import AnyscaleLLM

MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"


class TestAnyscaleLLM:
def test_anyscale_llm(self) -> None:
llm = AnyscaleLLM(model=MODEL, api_key="api.key")
assert isinstance(llm, AnyscaleLLM)
assert llm.model_name == MODEL

def test_serialization(self) -> None:
os.environ["OPENAI_API_KEY"] = "api.key"
llm = AnyscaleLLM(model=MODEL)

_dump = {
"model": MODEL,
"base_url": "https://api.endpoints.anyscale.com/v1",
"type_info": {
"module": "distilabel.llm.anyscale",
"name": "AnyscaleLLM",
},
}

assert llm.dump() == _dump
assert isinstance(AnyscaleLLM.from_dict(_dump), AnyscaleLLM)

0 comments on commit 87afbc5

Please sign in to comment.