From 35a280e68e21c1a8ef874095d19ff799db31b6d8 Mon Sep 17 00:00:00 2001
From: Ja YOUNG Lee <43683780+jayolee@users.noreply.github.com>
Date: Fri, 21 Jul 2023 17:52:20 +0100
Subject: [PATCH] feat: Remove ModelType enums (#105)

---
 GETTING_STARTED.md                            | 17 +++++-----
 .../rst_source/genai.schemas.models.rst       |  7 -----
 .../docs/source/rst_source/genai.schemas.rst  |  1 -
 examples/dev/async-flaky-request-handler.py   |  4 +--
 examples/dev/async-flaky-responses-ordered.py |  6 ++--
 examples/dev/generate-all-models.py           |  5 +--
 examples/dev/logging_example.py               |  4 +--
 .../watsonx-prompt-output.py                  |  7 +++--
 .../watsonx-prompt-pattern-ux-async.py        |  7 +++--
 .../watsonx-prompt-pattern-ux.py              |  7 +++--
 src/genai/extensions/langchain/llm.py         | 10 +++---
 src/genai/model.py                            |  6 ++--
 src/genai/schemas/__init__.py                 |  2 --
 src/genai/schemas/models.py                   | 31 -------------------
 src/genai/schemas/responses.py                |  5 ++-
 src/genai/services/async_generator.py         |  2 +-
 tests/extensions/test_langchain.py            |  6 ++--
 tests/test_logging.py                         | 10 +++---
 tests/test_model.py                           | 14 ++++-----
 tests/test_model_async.py                     | 10 +++---
 20 files changed, 60 insertions(+), 101 deletions(-)
 delete mode 100644 documentation/docs/source/rst_source/genai.schemas.models.rst
 delete mode 100644 src/genai/schemas/models.py

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 2cc79129..09e9c3d2 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -2,7 +2,6 @@
 ## Table of Contents
-
 * [Table of Contents](#table-of-contents)
 * [Installation](#installation)
 * [Gen AI Endpoint](#gen-ai-endpoint)
@@ -23,10 +22,13 @@
 ```bash
 pip install ibm-generative-ai
 ```
+
 #### Known Issue Fixes:
+
 - **[SSL Issue]** If you run into "SSL_CERTIFICATE_VERIFY_FAILED", run the code snippet provided in [support](SUPPORT.md).

 ### Prerequisites
+
 Python version >= 3.9
 Pip version >= 22.0.1
@@ -71,13 +73,11 @@
 creds = Credentials(api_key=my_api_key, api_endpoint=my_api_endpoint)
 ```
-
 ## Examples

 There are a number of examples you can try in the [`examples/user`](examples/user) directory. Log in to [workbench.res.ibm.com](https://workbench.res.ibm.com/) and get your GenAI API key. Then, create a `.env` file and assign the `GENAI_KEY` value as in the example below. [More information](#gen-ai-endpoint)
-
 ```ini
 GENAI_KEY=YOUR_GENAI_API_KEY
 # GENAI_API=GENAI_API_ENDPOINT << for a different endpoint
@@ -258,6 +258,7 @@ To learn more about logging in python, you can follow the tutorial [here](https:

 Since generating responses for a large number of prompts can be time-consuming and there could be unforeseen circumstances such as internet connectivity issues, here are some strategies to work with:
+
 - Start with a small number of prompts to prototype the code. You can enable logging as described above for debugging during prototyping.
 - Include exception handling in sensitive sections such as callbacks.
 - Checkpoint/save prompts and received responses periodically.
@@ -292,10 +293,13 @@ us if you want support for some framework as an extension or want to design an e
 ### LangChain Extension

 Install the langchain extension as follows:
+
 ```bash
 pip install "ibm-generative-ai[langchain]"
 ```
+
 Currently the langchain extension allows IBM Generative AI models to be wrapped as LangChain LLMs and supports translation between genai PromptPatterns and LangChain PromptTemplates. Below are sample snippets:
+
 ```python
 import os
 from dotenv import load_dotenv
@@ -327,13 +331,6 @@ print(langchain_model(template.format(question="What is life?")))
 print(genai_model.generate([pattern.sub("question", "What is life?")])[0].generated_text)
 ```

-## [Deprecated] Model Types
-
-Model types can be imported from the [ModelType class](src/genai/schemas/models.py). If you want to use a model that is not included in this class, you can pass it as a string as exemplified [here](src/genai/schemas/models.py).
-
-Models can be selected by passing their string id to the Model class as exemplified [here](src/genai/schemas/models.py).
-
-
 ## Support

 Need help? Check out how to get [support](SUPPORT.md)
diff --git a/documentation/docs/source/rst_source/genai.schemas.models.rst b/documentation/docs/source/rst_source/genai.schemas.models.rst
deleted file mode 100644
index b349309e..00000000
--- a/documentation/docs/source/rst_source/genai.schemas.models.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-Models
-===========================
-
-.. automodule:: genai.schemas.models
-   :members:
-   :undoc-members:
-   :show-inheritance:
diff --git a/documentation/docs/source/rst_source/genai.schemas.rst b/documentation/docs/source/rst_source/genai.schemas.rst
index 57ca8f78..b9239a25 100644
--- a/documentation/docs/source/rst_source/genai.schemas.rst
+++ b/documentation/docs/source/rst_source/genai.schemas.rst
@@ -10,7 +10,6 @@ Submodules

    genai.schemas.descriptions
    genai.schemas.generate_params
    genai.schemas.history_params
-   genai.schemas.models
    genai.schemas.responses
    genai.schemas.token_params
    genai.schemas.tunes_params
diff --git a/examples/dev/async-flaky-request-handler.py b/examples/dev/async-flaky-request-handler.py
index 1a65372e..5ce9a8a4 100644
--- a/examples/dev/async-flaky-request-handler.py
+++ b/examples/dev/async-flaky-request-handler.py
@@ -7,7 +7,7 @@
 from dotenv import load_dotenv

 from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
 from genai.services.connection_manager import ConnectionManager
 from genai.services.request_handler import RequestHandler

@@ -80,7 +80,7 @@ async def flaky_async_generate(
 tokenize_params = TokenParams(return_tokens=True)

-flan_ul2 = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+flan_ul2 = Model("google/flan-ul2", params=generate_params, credentials=creds)
 prompts = ["Generate a random number > {}: ".format(i) for i in range(25)]
 for response in flan_ul2.generate_async(prompts, ordered=True):
     pass
diff --git a/examples/dev/async-flaky-responses-ordered.py b/examples/dev/async-flaky-responses-ordered.py
index c6a735b4..34968bfb 100644
--- a/examples/dev/async-flaky-responses-ordered.py
+++ b/examples/dev/async-flaky-responses-ordered.py
@@ -6,7 +6,7 @@
 from dotenv import load_dotenv

 from genai.model import Credentials, GenAiException, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
 from genai.services.async_generator import AsyncResponseGenerator

 num_requests = 0
@@ -83,7 +83,7 @@ def tokenize_async(self, prompts, ordered=False, callback=None, options=None):
 tokenize_params = TokenParams(return_tokens=True)

-flan_ul2 = FlakyModel(ModelType.FLAN_UL2_20B, params=generate_params, credentials=creds)
+flan_ul2 = FlakyModel("google/flan-ul2", params=generate_params, credentials=creds)
 prompts = ["Generate a random number > {}: ".format(i) for i in range(17)]
 print("======== Async Generate with ordered=True ======== ")
 counter = 0
@@ -97,7 +97,7 @@ def tokenize_async(self, prompts, ordered=False, callback=None, options=None):
 num_requests = 0

 # Instantiate a model proxy object to send your requests
-flan_ul2 = FlakyModel(ModelType.FLAN_UL2_20B, params=tokenize_params, credentials=creds)
+flan_ul2 = FlakyModel("google/flan-ul2", params=tokenize_params, credentials=creds)
 prompts = ["Generate a random number > {}: ".format(i) for i in range(23)]
 print("======== Async Tokenize with ordered=True ======== ")
 counter = 0
diff --git a/examples/dev/generate-all-models.py b/examples/dev/generate-all-models.py
index 4e1b97c0..78dd8ace 100644
--- a/examples/dev/generate-all-models.py
+++ b/examples/dev/generate-all-models.py
@@ -3,7 +3,7 @@
 from dotenv import load_dotenv

 from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -24,7 +24,8 @@
     " during iteration it will do symb1 symb1 symb1 due to how it"
     " maps internally. ===="
 )
-for key, modelid in ModelType.__members__.items():
+for model_card in Model.models(credentials=creds):
+    modelid = model_card.id
     model = Model(modelid, params=params, credentials=creds)
     responses = [response.generated_text for response in model.generate(prompts)]
     print(modelid, ":", responses)
diff --git a/examples/dev/logging_example.py b/examples/dev/logging_example.py
index c6230824..58acc38f 100644
--- a/examples/dev/logging_example.py
+++ b/examples/dev/logging_example.py
@@ -4,7 +4,7 @@
 from dotenv import load_dotenv

 from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 logging.basicConfig(level=logging.INFO)

@@ -22,7 +22,7 @@
 params = GenerateParams(decoding_method="sample", max_new_tokens=10)

 # Instantiate a model proxy object to send your requests
-flan_ul2 = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+flan_ul2 = Model("google/flan-ul2", params=params, credentials=creds)
 prompts = ["Hello! How are you?", "How's the weather?"]

 for response in flan_ul2.generate_async(prompts):
diff --git a/examples/user/prompt_templating/watsonx-prompt-output.py b/examples/user/prompt_templating/watsonx-prompt-output.py
index 97962fed..d417734c 100644
--- a/examples/user/prompt_templating/watsonx-prompt-output.py
+++ b/examples/user/prompt_templating/watsonx-prompt-output.py
@@ -2,9 +2,10 @@

 from dotenv import load_dotenv

-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
 from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -15,7 +16,7 @@
 creds = Credentials(api_key, api_endpoint=api_url)

 params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)

 _template = """
diff --git a/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py b/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
index 6ec70baf..82b44f90 100644
--- a/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
+++ b/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
@@ -3,10 +3,11 @@

 from dotenv import load_dotenv

-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
 from genai.options import Options
 from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -18,7 +19,7 @@
 creds = Credentials(api_key, api_endpoint=api_url)

 params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)

 _template = """
diff --git a/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py b/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
index e0b56b4b..6453ee91 100644
--- a/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
+++ b/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
@@ -3,10 +3,11 @@

 from dotenv import load_dotenv

-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
 from genai.options import Options
 from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -18,7 +19,7 @@
 creds = Credentials(api_key, api_endpoint=api_url)

 params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)

 _template = """
diff --git a/src/genai/extensions/langchain/llm.py b/src/genai/extensions/langchain/llm.py
index a7450f24..68a553eb 100644
--- a/src/genai/extensions/langchain/llm.py
+++ b/src/genai/extensions/langchain/llm.py
@@ -1,6 +1,6 @@
 """Wrapper around IBM GENAI APIs for use in langchain"""
 import logging
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any, List, Mapping, Optional

 from pydantic import BaseModel, Extra

@@ -11,7 +11,7 @@
     raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")

 from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 logger = logging.getLogger(__name__)

@@ -28,11 +28,11 @@ class LangChainInterface(LLM, BaseModel):
     parameter, which is an instance of GenerateParams.
     Example:
     .. code-block:: python
-        llm = LangChainInterface(model=ModelType.FLAN_UL2, credentials=creds)
+        llm = LangChainInterface(model="google/flan-ul2", credentials=creds)
     """

     credentials: Credentials = None
-    model: Optional[Union[ModelType, str]] = None
+    model: Optional[str] = None
     params: Optional[GenerateParams] = None

     class Config:
@@ -63,7 +63,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
             The string generated by the model.
         Example:
         .. code-block:: python
-            llm = LangChainInterface(model_id=ModelType.FLAN_UL2, credentials=creds)
+            llm = LangChainInterface(model_id="google/flan-ul2", credentials=creds)
             response = llm("What is a molecule")
         """
         params = self.params or GenerateParams()
diff --git a/src/genai/model.py b/src/genai/model.py
index 0b8fc1a9..96b4ab99 100644
--- a/src/genai/model.py
+++ b/src/genai/model.py
@@ -10,7 +10,7 @@
 from genai.metadata import Metadata
 from genai.options import Options
 from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
 from genai.schemas.responses import (
     GenerateResponse,
     GenerateResult,
@@ -38,14 +38,14 @@ class Model:

     def __init__(
         self,
-        model: Union[ModelType, str],
+        model: str,
         params: Union[GenerateParams, TokenParams, Any] = None,
         credentials: Credentials = None,
     ):
         """Instantiates the Model Interface

         Args:
-            model (Union[ModelType, str]): The type of model to use
+            model (str): The id of the model to use
             params (Union[GenerateParams, TokenParams]): Parameters to use during generate requests
             credentials (Credentials): The API Credentials
         """
diff --git a/src/genai/schemas/__init__.py b/src/genai/schemas/__init__.py
index c7b22bfc..f1c25f13 100644
--- a/src/genai/schemas/__init__.py
+++ b/src/genai/schemas/__init__.py
@@ -7,7 +7,6 @@
     ReturnOptions,
 )
 from genai.schemas.history_params import HistoryParams
-from genai.schemas.models import ModelType
 from genai.schemas.responses import GenerateResult, TokenizeResult
 from genai.schemas.token_params import TokenParams
 from genai.schemas.tunes_params import (
@@ -24,7 +23,6 @@
     "ReturnOptions",
     "TokenParams",
     "HistoryParams",
-    "ModelType",
     "GenerateResult",
     "TokenizeResult",
     "FileListParams",
diff --git a/src/genai/schemas/models.py b/src/genai/schemas/models.py
deleted file mode 100644
index fdd3aad8..00000000
--- a/src/genai/schemas/models.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import warnings
-from enum import Enum
-
-warnings.simplefilter("always", DeprecationWarning)
-warnings.warn(
-    """\x1b[33;20m
-The class ModelType is being deprecated.
-Please replace any reference to ModelType by its model id string equivalent.
-Example :
-    ModelType.FLAN_T5 becomes "google/flan-t5-xxl"\x1b[0m
-""",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class ModelType(str, Enum):
-    CODEGEN_MONO_16B = "salesforce/codegen-16b-mono"
-    DIAL_FLAN_T5 = "prakharz/dial-flant5-xl"
-    DIAL_FLAN_T5_3B = "prakharz/dial-flant5-xl"
-    FLAN_T5 = "google/flan-t5-xxl"
-    FLAN_T5_11B = "google/flan-t5-xxl"
-    FLAN_T5_3B = "google/flan-t5-xl"
-    FLAN_UL2 = "google/flan-ul2"
-    FLAN_UL2_20B = "google/flan-ul2"
-    GPT_JT_6B_V1 = "togethercomputer/gpt-jt-6b-v1"
-    GPT_NEOX_20B = "eleutherai/gpt-neox-20b"
-    MT0 = "bigscience/mt0-xxl"
-    MT0_13B = "bigscience/mt0-xxl"
-    UL2 = "google/ul2"
-    UL2_20B = "google/ul2"
diff --git a/src/genai/schemas/responses.py b/src/genai/schemas/responses.py
index e0c6c871..228cd007 100644
--- a/src/genai/schemas/responses.py
+++ b/src/genai/schemas/responses.py
@@ -6,7 +6,6 @@
 from pydantic import BaseModel, Extra, root_validator

 from genai.schemas.generate_params import GenerateParams
-from genai.schemas.models import ModelType

 logger = logging.getLogger(__name__)

@@ -76,7 +75,7 @@ class GenerateResult(GenAiResponseModel):


 class GenerateResponse(GenAiResponseModel):
-    model_id: Union[ModelType, str]
+    model_id: str
     created_at: datetime
     results: List[GenerateResult]

@@ -98,7 +97,7 @@ class TokenizeResult(GenAiResponseModel):


 class TokenizeResponse(GenAiResponseModel):
-    model_id: Union[ModelType, str]
+    model_id: str
     created_at: datetime
     results: List[TokenizeResult]

diff --git a/src/genai/services/async_generator.py b/src/genai/services/async_generator.py
index 31e0e1d1..cd4f4ab9 100644
--- a/src/genai/services/async_generator.py
+++ b/src/genai/services/async_generator.py
@@ -21,7 +21,7 @@ def __init__(
         """Instantiates the ConcurrentWrapper Interface.

         Args:
-            model_id (ModelType): The type of model to use
+            model_id (str): The id of the model to use
             prompts (list): List of prompts
             params (GenerateParams): Parameters to use during generate requests
             service (ServiceInterface): The service interface
diff --git a/tests/extensions/test_langchain.py b/tests/extensions/test_langchain.py
index 228b05dd..96a9da0b 100644
--- a/tests/extensions/test_langchain.py
+++ b/tests/extensions/test_langchain.py
@@ -3,7 +3,7 @@
 import pytest

 from genai import Credentials
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
 from genai.schemas.responses import GenerateResponse
 from genai.services import ServiceInterface
 from tests.assets.response_helper import SimpleResponse
@@ -32,14 +32,14 @@ def prompts(self):
     def test_langchain_interface(self, mocked_post_request, credentials, params, prompts):
         from genai.extensions.langchain import LangChainInterface

-        GENERATE_RESPONSE = SimpleResponse.generate(model=ModelType.FLAN_UL2, inputs=prompts, params=params)
+        GENERATE_RESPONSE = SimpleResponse.generate(model="google/flan-ul2", inputs=prompts, params=params)
         expected_generated_response = GenerateResponse(**GENERATE_RESPONSE)

         response = MagicMock(status_code=200)
         response.json.return_value = GENERATE_RESPONSE
         mocked_post_request.return_value = response

-        model = LangChainInterface(model=ModelType.FLAN_UL2, params=params, credentials=credentials)
+        model = LangChainInterface(model="google/flan-ul2", params=params, credentials=credentials)
         observed = model(prompts[0])
         assert observed == expected_generated_response.results[0].generated_text
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 7f36bc48..2860bcd2 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -3,7 +3,7 @@
 import pytest

 from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams

 logger = logging.getLogger()
 logger.addHandler(logging.StreamHandler())

@@ -14,7 +14,7 @@ class TestLogging:
     def test_no_leaked_logs(self, caplog):
         credentials = Credentials("GENAI_API_KEY")
         params = GenerateParams()
-        Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+        Model("google/flan-ul2", params=params, credentials=credentials)

         assert len(caplog.records) == 0

@@ -23,6 +23,6 @@ def test_basic_logs(self, caplog):
         credentials = Credentials("GENAI_API_KEY")
         params = GenerateParams()

-        Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
-        # Enums are converted to strings slightly differently across python 3.9, 3.10 and 3.11
-        assert any(x in caplog.text for x in [ModelType.FLAN_UL2, ModelType.FLAN_UL2.value, "ModelType.FLAN_UL2"])
+        Model("google/flan-ul2", params=params, credentials=credentials)
+
+        assert "google/flan-ul2" in caplog.text
diff --git a/tests/test_model.py b/tests/test_model.py
index cc389916..b90be5d2 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -4,7 +4,7 @@

 from genai import Credentials, Model
 from genai.exceptions import GenAiException
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
 from genai.schemas.responses import (
     GenerateResponse,
     ModelCard,
@@ -48,14 +48,14 @@ def prompts(self):
     def test_generate(self, mocked_post_request, credentials, params, prompts):
         """Tests that we can call the generate endpoint"""

-        GENERATE_RESPONSE = SimpleResponse.generate(model=ModelType.FLAN_UL2, inputs=prompts, params=params)
+        GENERATE_RESPONSE = SimpleResponse.generate(model="google/flan-ul2", inputs=prompts, params=params)
         expected_generated_response = GenerateResponse(**GENERATE_RESPONSE)

         response = MagicMock(status_code=200)
         response.json.return_value = GENERATE_RESPONSE
         mocked_post_request.return_value = response

-        model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+        model = Model("google/flan-ul2", params=params, credentials=credentials)
         responses = model.generate_as_completed(prompts=prompts)
         responses_list = list(responses)

@@ -70,7 +70,7 @@ def test_generate_throws_exception_for_non_200(self, mock_service_generate, cred

         mock_service_generate.return_value = MagicMock(status_code=500)

-        model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+        model = Model("google/flan-ul2", params=params, credentials=credentials)

         with pytest.raises(GenAiException):
             model.generate(prompts=prompts)
@@ -78,7 +78,7 @@ def test_generate_throws_exception_for_non_200(self, mock_service_generate, cred
     @patch("genai.services.RequestHandler.post", side_effect=Exception("some general error"))
     def test_generate_throws_exception_for_generic_exception(self, credentials, params, prompts):
         """Tests that the GenAiException is thrown if a generic Exception is raised"""
-        model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+        model = Model("google/flan-ul2", params=params, credentials=credentials)
         with pytest.raises(GenAiException, match="some general error"):
             model.generate(prompts=prompts)

@@ -87,14 +87,14 @@ def test_generate_throws_exception_for_generic_exception(self, credentials, para
     def test_tokenize(self, mocked_post_request, credentials, params):
         """Tests that we can call the tokenize endpoint"""

-        TOKENIZE_RESPONSE = SimpleResponse.tokenize(model=ModelType.FLAN_UL2, inputs=["a", "b", "c"])
+        TOKENIZE_RESPONSE = SimpleResponse.tokenize(model="google/flan-ul2", inputs=["a", "b", "c"])
         expected_token_response = TokenizeResponse(**TOKENIZE_RESPONSE)

         mock_response = MagicMock(status_code=200)
         mock_response.json.return_value = TOKENIZE_RESPONSE
         mocked_post_request.return_value = mock_response

-        model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+        model = Model("google/flan-ul2", params=params, credentials=credentials)

         responses = model.tokenize(["a", "b", "c"], False)

diff --git a/tests/test_model_async.py b/tests/test_model_async.py
index c27d049b..712e0408 100644
--- a/tests/test_model_async.py
+++ b/tests/test_model_async.py
@@ -3,7 +3,7 @@
 import pytest

 from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
 from genai.schemas.responses import GenerateResponse, TokenizeResponse
 from genai.services import ServiceInterface
 from tests.assets.response_helper import SimpleResponse
@@ -46,7 +46,7 @@ async def test_generate_async(self, mock_generate_json, generate_params):
         creds = Credentials("TEST_API_KEY")
         mock_generate_json.side_effect = expected

-        model = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+        model = Model("google/flan-ul2", params=generate_params, credentials=creds)

         counter = 0
         responses = list(model.generate_async(prompts))
@@ -67,7 +67,7 @@ async def test_tokenize_async(self, mock_tokenize_json, tokenize_params):
         creds = Credentials("TEST_API_KEY")
         mock_tokenize_json.side_effect = expected

-        model = Model(ModelType.FLAN_UL2, params=tokenize_params, credentials=creds)
+        model = Model("google/flan-ul2", params=tokenize_params, credentials=creds)

         counter = 0
         responses = list(model.tokenize_async(prompts))
@@ -92,7 +92,7 @@ def tasks_completed(result):
             message += result.generated_text

         prompts = ["TEST_PROMPT"] * num_prompts
-        model = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+        model = Model("google/flan-ul2", params=generate_params, credentials=creds)
         for result in model.generate_async(prompts, callback=tasks_completed):
             pass

@@ -114,7 +114,7 @@ def tasks_completed(result):
             message += result.tokens

         prompts = ["TEST_PROMPT"] * num_prompts
-        model = Model(ModelType.FLAN_UL2, params=tokenize_params, credentials=creds)
+        model = Model("google/flan-ul2", params=tokenize_params, credentials=creds)
         for result in model.tokenize_async(prompts, callback=tasks_completed):
             pass
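---

Every hunk above follows the same migration pattern: a `ModelType` enum member is replaced by its model-id string, and the catalogue of ids moves from the enum to runtime discovery. Below is a minimal sketch of the new calling convention, assembled from the examples updated in this patch; it assumes a `.env` file with `GENAI_KEY` (and optionally `GENAI_API`), as those examples do.

```python
import os

from dotenv import load_dotenv

from genai.credentials import Credentials
from genai.model import Model
from genai.schemas import GenerateParams

# Load GENAI_KEY (and optionally GENAI_API) from a .env file, as in examples/dev.
load_dotenv()
creds = Credentials(os.getenv("GENAI_KEY"), api_endpoint=os.getenv("GENAI_API"))

params = GenerateParams(decoding_method="sample", max_new_tokens=10)

# Before this patch: Model(ModelType.FLAN_UL2, ...); now the id string is passed directly.
flan_ul2 = Model("google/flan-ul2", params=params, credentials=creds)

# The removed enum's catalogue is replaced by runtime listing of model cards,
# as examples/dev/generate-all-models.py now does.
available_ids = [model_card.id for model_card in Model.models(credentials=creds)]
print(available_ids)

for response in flan_ul2.generate(["Hello! How are you?"]):
    print(response.generated_text)
```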