From 35a280e68e21c1a8ef874095d19ff799db31b6d8 Mon Sep 17 00:00:00 2001
From: Ja YOUNG Lee <43683780+jayolee@users.noreply.github.com>
Date: Fri, 21 Jul 2023 17:52:20 +0100
Subject: [PATCH] feat: Remove ModelType enums (#105)
---
GETTING_STARTED.md | 17 +++++-----
.../rst_source/genai.schemas.models.rst | 7 -----
.../docs/source/rst_source/genai.schemas.rst | 1 -
examples/dev/async-flaky-request-handler.py | 4 +--
examples/dev/async-flaky-responses-ordered.py | 6 ++--
examples/dev/generate-all-models.py | 5 +--
examples/dev/logging_example.py | 4 +--
.../watsonx-prompt-output.py | 7 +++--
.../watsonx-prompt-pattern-ux-async.py | 7 +++--
.../watsonx-prompt-pattern-ux.py | 7 +++--
src/genai/extensions/langchain/llm.py | 10 +++---
src/genai/model.py | 6 ++--
src/genai/schemas/__init__.py | 2 --
src/genai/schemas/models.py | 31 -------------------
src/genai/schemas/responses.py | 5 ++-
src/genai/services/async_generator.py | 2 +-
tests/extensions/test_langchain.py | 6 ++--
tests/test_logging.py | 10 +++---
tests/test_model.py | 14 ++++-----
tests/test_model_async.py | 10 +++---
20 files changed, 60 insertions(+), 101 deletions(-)
delete mode 100644 documentation/docs/source/rst_source/genai.schemas.models.rst
delete mode 100644 src/genai/schemas/models.py
diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 2cc79129..09e9c3d2 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -2,7 +2,6 @@
## Table of Contents
-
* [Table of Contents](#table-of-contents)
* [Installation](#installation)
* [Gen AI Endpoint](#gen-ai-endpoint)
@@ -23,10 +22,13 @@
```bash
pip install ibm-generative-ai
```
+
#### Known Issue Fixes:
+
- **[SSL Issue]** If you run into "SSL_CERTIFICATE_VERIFY_FAILED" please run the following code snippet here: [support](SUPPORT.md).
### Prerequisites
+
Python version >= 3.9
Pip version >= 22.0.1
@@ -71,13 +73,11 @@ creds = Credentials(api_key=my_api_key, api_endpoint=my_api_endpoint)
```
-
## Examples
There are a number of examples you can try in the [`examples/user`](examples/user) directory.
Login to [workbench.res.ibm.com](https://workbench.res.ibm.com/) and get your GenAI API key. Then, create a `.env` file and assign the `GENAI_KEY` value as below example. [More information](#gen-ai-endpoint)
-
```ini
GENAI_KEY=YOUR_GENAI_API_KEY
# GENAI_API=GENAI_API_ENDPOINT << for a different endpoint
@@ -258,6 +258,7 @@ To learn more about logging in python, you can follow the tutorial [here](https:
Since generating responses for a large number of prompts can be time-consuming and there could be unforeseen circumstances such as internet connectivity issues, here are some strategies
to work with:
+
- Start with a small number of prompts to prototype the code. You can enable logging as described above for debugging during prototyping.
- Include exception handling in sensitive sections such as callbacks.
- Checkpoint/save prompts and received responses periodically.
@@ -292,10 +293,13 @@ us if you want support for some framework as an extension or want to design an e
### LangChain Extension
Install the langchain extension as follows:
+
```bash
pip install "ibm-generative-ai[langchain]"
```
+
Currently the langchain extension allows IBM Generative AI models to be wrapped as Langchain LLMs and translation between genai PromptPatterns and LangChain PromptTemplates. Below are sample snippets
+
```python
import os
from dotenv import load_dotenv
@@ -327,13 +331,6 @@ print(langchain_model(template.format(question="What is life?")))
print(genai_model.generate([pattern.sub("question", "What is life?")])[0].generated_text)
```
-## [Deprecated] Model Types
-
-Model types can be imported from the [ModelType class](src/genai/schemas/models.py). If you want to use a model that is not included in this class, you can pass it as a string as exemplified [here](src/genai/schemas/models.py).
-
-Models can be selected by passing their string id to the Model class as exemplified [here](src/genai/schemas/models.py).
-
-
## Support
Need help? Check out how to get [support](SUPPORT.md)
diff --git a/documentation/docs/source/rst_source/genai.schemas.models.rst b/documentation/docs/source/rst_source/genai.schemas.models.rst
deleted file mode 100644
index b349309e..00000000
--- a/documentation/docs/source/rst_source/genai.schemas.models.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-Models
-===========================
-
-.. automodule:: genai.schemas.models
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/documentation/docs/source/rst_source/genai.schemas.rst b/documentation/docs/source/rst_source/genai.schemas.rst
index 57ca8f78..b9239a25 100644
--- a/documentation/docs/source/rst_source/genai.schemas.rst
+++ b/documentation/docs/source/rst_source/genai.schemas.rst
@@ -10,7 +10,6 @@ Submodules
genai.schemas.descriptions
genai.schemas.generate_params
genai.schemas.history_params
- genai.schemas.models
genai.schemas.responses
genai.schemas.token_params
genai.schemas.tunes_params
diff --git a/examples/dev/async-flaky-request-handler.py b/examples/dev/async-flaky-request-handler.py
index 1a65372e..5ce9a8a4 100644
--- a/examples/dev/async-flaky-request-handler.py
+++ b/examples/dev/async-flaky-request-handler.py
@@ -7,7 +7,7 @@
from dotenv import load_dotenv
from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
from genai.services.connection_manager import ConnectionManager
from genai.services.request_handler import RequestHandler
@@ -80,7 +80,7 @@ async def flaky_async_generate(
tokenize_params = TokenParams(return_tokens=True)
-flan_ul2 = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+flan_ul2 = Model("google/flan-ul2", params=generate_params, credentials=creds)
prompts = ["Generate a random number > {}: ".format(i) for i in range(25)]
for response in flan_ul2.generate_async(prompts, ordered=True):
pass
diff --git a/examples/dev/async-flaky-responses-ordered.py b/examples/dev/async-flaky-responses-ordered.py
index c6a735b4..34968bfb 100644
--- a/examples/dev/async-flaky-responses-ordered.py
+++ b/examples/dev/async-flaky-responses-ordered.py
@@ -6,7 +6,7 @@
from dotenv import load_dotenv
from genai.model import Credentials, GenAiException, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
from genai.services.async_generator import AsyncResponseGenerator
num_requests = 0
@@ -83,7 +83,7 @@ def tokenize_async(self, prompts, ordered=False, callback=None, options=None):
tokenize_params = TokenParams(return_tokens=True)
-flan_ul2 = FlakyModel(ModelType.FLAN_UL2_20B, params=generate_params, credentials=creds)
+flan_ul2 = FlakyModel("google/flan-ul2", params=generate_params, credentials=creds)
prompts = ["Generate a random number > {}: ".format(i) for i in range(17)]
print("======== Async Generate with ordered=True ======== ")
counter = 0
@@ -97,7 +97,7 @@ def tokenize_async(self, prompts, ordered=False, callback=None, options=None):
num_requests = 0
# Instantiate a model proxy object to send your requests
-flan_ul2 = FlakyModel(ModelType.FLAN_UL2_20B, params=tokenize_params, credentials=creds)
+flan_ul2 = FlakyModel("google/flan-ul2", params=tokenize_params, credentials=creds)
prompts = ["Generate a random number > {}: ".format(i) for i in range(23)]
print("======== Async Tokenize with ordered=True ======== ")
counter = 0
diff --git a/examples/dev/generate-all-models.py b/examples/dev/generate-all-models.py
index 4e1b97c0..78dd8ace 100644
--- a/examples/dev/generate-all-models.py
+++ b/examples/dev/generate-all-models.py
@@ -3,7 +3,7 @@
from dotenv import load_dotenv
from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
# make sure you have a .env file under genai root with
# GENAI_KEY=
@@ -24,7 +24,8 @@
" during iteration it will do symb1 symb1 symb1 due to how it"
" maps internally. ===="
)
-for key, modelid in ModelType.__members__.items():
+for model_card in Model.models(credentials=creds):
+ modelid = model_card.id
model = Model(modelid, params=params, credentials=creds)
responses = [response.generated_text for response in model.generate(prompts)]
print(modelid, ":", responses)
diff --git a/examples/dev/logging_example.py b/examples/dev/logging_example.py
index c6230824..58acc38f 100644
--- a/examples/dev/logging_example.py
+++ b/examples/dev/logging_example.py
@@ -4,7 +4,7 @@
from dotenv import load_dotenv
from genai.model import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
logging.basicConfig(level=logging.INFO)
@@ -22,7 +22,7 @@
params = GenerateParams(decoding_method="sample", max_new_tokens=10)
# Instantiate a model proxy object to send your requests
-flan_ul2 = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+flan_ul2 = Model("google/flan-ul2", params=params, credentials=creds)
prompts = ["Hello! How are you?", "How's the weather?"]
for response in flan_ul2.generate_async(prompts):
diff --git a/examples/user/prompt_templating/watsonx-prompt-output.py b/examples/user/prompt_templating/watsonx-prompt-output.py
index 97962fed..d417734c 100644
--- a/examples/user/prompt_templating/watsonx-prompt-output.py
+++ b/examples/user/prompt_templating/watsonx-prompt-output.py
@@ -2,9 +2,10 @@
from dotenv import load_dotenv
-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
# make sure you have a .env file under genai root with
# GENAI_KEY=
@@ -15,7 +16,7 @@
creds = Credentials(api_key, api_endpoint=api_url)
params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)
_template = """
diff --git a/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py b/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
index 6ec70baf..82b44f90 100644
--- a/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
+++ b/examples/user/prompt_templating/watsonx-prompt-pattern-ux-async.py
@@ -3,10 +3,11 @@
from dotenv import load_dotenv
-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
from genai.options import Options
from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
# make sure you have a .env file under genai root with
# GENAI_KEY=
@@ -18,7 +19,7 @@
creds = Credentials(api_key, api_endpoint=api_url)
params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)
_template = """
diff --git a/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py b/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
index e0b56b4b..6453ee91 100644
--- a/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
+++ b/examples/user/prompt_templating/watsonx-prompt-pattern-ux.py
@@ -3,10 +3,11 @@
from dotenv import load_dotenv
-from genai.model import Credentials, Model
+from genai.credentials import Credentials
+from genai.model import Model
from genai.options import Options
from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
# make sure you have a .env file under genai root with
# GENAI_KEY=
@@ -18,7 +19,7 @@
creds = Credentials(api_key, api_endpoint=api_url)
params = GenerateParams(temperature=0.5)
-model = Model(ModelType.FLAN_UL2, params=params, credentials=creds)
+model = Model("google/flan-ul2", params=params, credentials=creds)
_template = """
diff --git a/src/genai/extensions/langchain/llm.py b/src/genai/extensions/langchain/llm.py
index a7450f24..68a553eb 100644
--- a/src/genai/extensions/langchain/llm.py
+++ b/src/genai/extensions/langchain/llm.py
@@ -1,6 +1,6 @@
"""Wrapper around IBM GENAI APIs for use in langchain"""
import logging
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra
@@ -11,7 +11,7 @@
raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")
from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
logger = logging.getLogger(__name__)
@@ -28,11 +28,11 @@ class LangChainInterface(LLM, BaseModel):
parameter, which is an instance of GenerateParams.
Example:
.. code-block:: python
- llm = LangChainInterface(model=ModelType.FLAN_UL2, credentials=creds)
+ llm = LangChainInterface(model="google/flan-ul2", credentials=creds)
"""
credentials: Credentials = None
- model: Optional[Union[ModelType, str]] = None
+ model: Optional[str] = None
params: Optional[GenerateParams] = None
class Config:
@@ -63,7 +63,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
The string generated by the model.
Example:
.. code-block:: python
- llm = LangChainInterface(model_id=ModelType.FLAN_UL2, credentials=creds)
+ llm = LangChainInterface(model_id="google/flan-ul2", credentials=creds)
response = llm("What is a molecule")
"""
params = self.params or GenerateParams()
diff --git a/src/genai/model.py b/src/genai/model.py
index 0b8fc1a9..96b4ab99 100644
--- a/src/genai/model.py
+++ b/src/genai/model.py
@@ -10,7 +10,7 @@
from genai.metadata import Metadata
from genai.options import Options
from genai.prompt_pattern import PromptPattern
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
from genai.schemas.responses import (
GenerateResponse,
GenerateResult,
@@ -38,14 +38,14 @@ class Model:
def __init__(
self,
- model: Union[ModelType, str],
+ model: str,
params: Union[GenerateParams, TokenParams, Any] = None,
credentials: Credentials = None,
):
"""Instantiates the Model Interface
Args:
- model (Union[ModelType, str]): The type of model to use
+ model (str): The id of the model to use
params (Union[GenerateParams, TokenParams]): Parameters to use during generate requests
credentials (Credentials): The API Credentials
"""
diff --git a/src/genai/schemas/__init__.py b/src/genai/schemas/__init__.py
index c7b22bfc..f1c25f13 100644
--- a/src/genai/schemas/__init__.py
+++ b/src/genai/schemas/__init__.py
@@ -7,7 +7,6 @@
ReturnOptions,
)
from genai.schemas.history_params import HistoryParams
-from genai.schemas.models import ModelType
from genai.schemas.responses import GenerateResult, TokenizeResult
from genai.schemas.token_params import TokenParams
from genai.schemas.tunes_params import (
@@ -24,7 +23,6 @@
"ReturnOptions",
"TokenParams",
"HistoryParams",
- "ModelType",
"GenerateResult",
"TokenizeResult",
"FileListParams",
diff --git a/src/genai/schemas/models.py b/src/genai/schemas/models.py
deleted file mode 100644
index fdd3aad8..00000000
--- a/src/genai/schemas/models.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import warnings
-from enum import Enum
-
-warnings.simplefilter("always", DeprecationWarning)
-warnings.warn(
- """\x1b[33;20m
-The class ModelType is being deprecated.
-Please replace any reference to ModelType by its model id string equivalent.
-Example :
- ModelType.FLAN_T5 becomes "google/flan-t5-xxl"\x1b[0m
-""",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class ModelType(str, Enum):
- CODEGEN_MONO_16B = "salesforce/codegen-16b-mono"
- DIAL_FLAN_T5 = "prakharz/dial-flant5-xl"
- DIAL_FLAN_T5_3B = "prakharz/dial-flant5-xl"
- FLAN_T5 = "google/flan-t5-xxl"
- FLAN_T5_11B = "google/flan-t5-xxl"
- FLAN_T5_3B = "google/flan-t5-xl"
- FLAN_UL2 = "google/flan-ul2"
- FLAN_UL2_20B = "google/flan-ul2"
- GPT_JT_6B_V1 = "togethercomputer/gpt-jt-6b-v1"
- GPT_NEOX_20B = "eleutherai/gpt-neox-20b"
- MT0 = "bigscience/mt0-xxl"
- MT0_13B = "bigscience/mt0-xxl"
- UL2 = "google/ul2"
- UL2_20B = "google/ul2"
diff --git a/src/genai/schemas/responses.py b/src/genai/schemas/responses.py
index e0c6c871..228cd007 100644
--- a/src/genai/schemas/responses.py
+++ b/src/genai/schemas/responses.py
@@ -6,7 +6,6 @@
from pydantic import BaseModel, Extra, root_validator
from genai.schemas.generate_params import GenerateParams
-from genai.schemas.models import ModelType
logger = logging.getLogger(__name__)
@@ -76,7 +75,7 @@ class GenerateResult(GenAiResponseModel):
class GenerateResponse(GenAiResponseModel):
- model_id: Union[ModelType, str]
+ model_id: str
created_at: datetime
results: List[GenerateResult]
@@ -98,7 +97,7 @@ class TokenizeResult(GenAiResponseModel):
class TokenizeResponse(GenAiResponseModel):
- model_id: Union[ModelType, str]
+ model_id: str
created_at: datetime
results: List[TokenizeResult]
diff --git a/src/genai/services/async_generator.py b/src/genai/services/async_generator.py
index 31e0e1d1..cd4f4ab9 100644
--- a/src/genai/services/async_generator.py
+++ b/src/genai/services/async_generator.py
@@ -21,7 +21,7 @@ def __init__(
"""Instantiates the ConcurrentWrapper Interface.
Args:
- model_id (ModelType): The type of model to use
+ model_id (str): The id of the model to use
prompts (list): List of prompts
params (GenerateParams): Parameters to use during generate requests
service (ServiceInterface): The service interface
diff --git a/tests/extensions/test_langchain.py b/tests/extensions/test_langchain.py
index 228b05dd..96a9da0b 100644
--- a/tests/extensions/test_langchain.py
+++ b/tests/extensions/test_langchain.py
@@ -3,7 +3,7 @@
import pytest
from genai import Credentials
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
from genai.schemas.responses import GenerateResponse
from genai.services import ServiceInterface
from tests.assets.response_helper import SimpleResponse
@@ -32,14 +32,14 @@ def prompts(self):
def test_langchain_interface(self, mocked_post_request, credentials, params, prompts):
from genai.extensions.langchain import LangChainInterface
- GENERATE_RESPONSE = SimpleResponse.generate(model=ModelType.FLAN_UL2, inputs=prompts, params=params)
+ GENERATE_RESPONSE = SimpleResponse.generate(model="google/flan-ul2", inputs=prompts, params=params)
expected_generated_response = GenerateResponse(**GENERATE_RESPONSE)
response = MagicMock(status_code=200)
response.json.return_value = GENERATE_RESPONSE
mocked_post_request.return_value = response
- model = LangChainInterface(model=ModelType.FLAN_UL2, params=params, credentials=credentials)
+ model = LangChainInterface(model="google/flan-ul2", params=params, credentials=credentials)
observed = model(prompts[0])
assert observed == expected_generated_response.results[0].generated_text
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 7f36bc48..2860bcd2 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -3,7 +3,7 @@
import pytest
from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
@@ -14,7 +14,7 @@ class TestLogging:
def test_no_leaked_logs(self, caplog):
credentials = Credentials("GENAI_API_KEY")
params = GenerateParams()
- Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+ Model("google/flan-ul2", params=params, credentials=credentials)
assert len(caplog.records) == 0
@@ -23,6 +23,6 @@ def test_basic_logs(self, caplog):
credentials = Credentials("GENAI_API_KEY")
params = GenerateParams()
- Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
- # Enums are converted to strings slightly differently across python 3.9, 3.10 and 3.11
- assert any(x in caplog.text for x in [ModelType.FLAN_UL2, ModelType.FLAN_UL2.value, "ModelType.FLAN_UL2"])
+ Model("google/flan-ul2", params=params, credentials=credentials)
+
+ assert "google/flan-ul2" in caplog.text
diff --git a/tests/test_model.py b/tests/test_model.py
index cc389916..b90be5d2 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -4,7 +4,7 @@
from genai import Credentials, Model
from genai.exceptions import GenAiException
-from genai.schemas import GenerateParams, ModelType
+from genai.schemas import GenerateParams
from genai.schemas.responses import (
GenerateResponse,
ModelCard,
@@ -48,14 +48,14 @@ def prompts(self):
def test_generate(self, mocked_post_request, credentials, params, prompts):
"""Tests that we can call the generate endpoint"""
- GENERATE_RESPONSE = SimpleResponse.generate(model=ModelType.FLAN_UL2, inputs=prompts, params=params)
+ GENERATE_RESPONSE = SimpleResponse.generate(model="google/flan-ul2", inputs=prompts, params=params)
expected_generated_response = GenerateResponse(**GENERATE_RESPONSE)
response = MagicMock(status_code=200)
response.json.return_value = GENERATE_RESPONSE
mocked_post_request.return_value = response
- model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+ model = Model("google/flan-ul2", params=params, credentials=credentials)
responses = model.generate_as_completed(prompts=prompts)
responses_list = list(responses)
@@ -70,7 +70,7 @@ def test_generate_throws_exception_for_non_200(self, mock_service_generate, cred
mock_service_generate.return_value = MagicMock(status_code=500)
- model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+ model = Model("google/flan-ul2", params=params, credentials=credentials)
with pytest.raises(GenAiException):
model.generate(prompts=prompts)
@@ -78,7 +78,7 @@ def test_generate_throws_exception_for_non_200(self, mock_service_generate, cred
@patch("genai.services.RequestHandler.post", side_effect=Exception("some general error"))
def test_generate_throws_exception_for_generic_exception(self, credentials, params, prompts):
"""Tests that the GenAiException is thrown if a generic Exception is raised"""
- model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+ model = Model("google/flan-ul2", params=params, credentials=credentials)
with pytest.raises(GenAiException, match="some general error"):
model.generate(prompts=prompts)
@@ -87,14 +87,14 @@ def test_generate_throws_exception_for_generic_exception(self, credentials, para
def test_tokenize(self, mocked_post_request, credentials, params):
"""Tests that we can call the tokenize endpoint"""
- TOKENIZE_RESPONSE = SimpleResponse.tokenize(model=ModelType.FLAN_UL2, inputs=["a", "b", "c"])
+ TOKENIZE_RESPONSE = SimpleResponse.tokenize(model="google/flan-ul2", inputs=["a", "b", "c"])
expected_token_response = TokenizeResponse(**TOKENIZE_RESPONSE)
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = TOKENIZE_RESPONSE
mocked_post_request.return_value = mock_response
- model = Model(ModelType.FLAN_UL2, params=params, credentials=credentials)
+ model = Model("google/flan-ul2", params=params, credentials=credentials)
responses = model.tokenize(["a", "b", "c"], False)
diff --git a/tests/test_model_async.py b/tests/test_model_async.py
index c27d049b..712e0408 100644
--- a/tests/test_model_async.py
+++ b/tests/test_model_async.py
@@ -3,7 +3,7 @@
import pytest
from genai import Credentials, Model
-from genai.schemas import GenerateParams, ModelType, TokenParams
+from genai.schemas import GenerateParams, TokenParams
from genai.schemas.responses import GenerateResponse, TokenizeResponse
from genai.services import ServiceInterface
from tests.assets.response_helper import SimpleResponse
@@ -46,7 +46,7 @@ async def test_generate_async(self, mock_generate_json, generate_params):
creds = Credentials("TEST_API_KEY")
mock_generate_json.side_effect = expected
- model = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+ model = Model("google/flan-ul2", params=generate_params, credentials=creds)
counter = 0
responses = list(model.generate_async(prompts))
@@ -67,7 +67,7 @@ async def test_tokenize_async(self, mock_tokenize_json, tokenize_params):
creds = Credentials("TEST_API_KEY")
mock_tokenize_json.side_effect = expected
- model = Model(ModelType.FLAN_UL2, params=tokenize_params, credentials=creds)
+ model = Model("google/flan-ul2", params=tokenize_params, credentials=creds)
counter = 0
responses = list(model.tokenize_async(prompts))
@@ -92,7 +92,7 @@ def tasks_completed(result):
message += result.generated_text
prompts = ["TEST_PROMPT"] * num_prompts
- model = Model(ModelType.FLAN_UL2, params=generate_params, credentials=creds)
+ model = Model("google/flan-ul2", params=generate_params, credentials=creds)
for result in model.generate_async(prompts, callback=tasks_completed):
pass
@@ -114,7 +114,7 @@ def tasks_completed(result):
message += result.tokens
prompts = ["TEST_PROMPT"] * num_prompts
- model = Model(ModelType.FLAN_UL2, params=tokenize_params, credentials=creds)
+ model = Model("google/flan-ul2", params=tokenize_params, credentials=creds)
for result in model.tokenize_async(prompts, callback=tasks_completed):
pass