-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(langchain): make models serializable (#224)
Closes: #222
- Loading branch information
Showing 7 changed files with 125 additions and 42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import tempfile | ||
|
||
from dotenv import load_dotenv | ||
|
||
from genai.credentials import Credentials | ||
from genai.extensions.langchain import LangChainInterface | ||
from genai.schemas import GenerateParams, ReturnOptions | ||
|
||
# make sure you have a .env file under genai root with | ||
# GENAI_KEY=<your-genai-key> | ||
load_dotenv() | ||
api_key = os.getenv("GENAI_KEY", None) | ||
api_endpoint = os.getenv("GENAI_API", None) | ||
credentials = Credentials(api_key, api_endpoint) | ||
|
||
llm = LangChainInterface( | ||
model="google/flan-ul2", | ||
credentials=credentials, | ||
params=GenerateParams( | ||
decoding_method="sample", | ||
max_new_tokens=10, | ||
min_new_tokens=1, | ||
stream=True, | ||
temperature=0.5, | ||
top_k=50, | ||
top_p=1, | ||
return_options=ReturnOptions(generated_tokens=True, token_logprobs=True, input_tokens=True), | ||
), | ||
) | ||
|
||
with tempfile.NamedTemporaryFile(suffix=".json") as tmp: | ||
print(f"Serializing LLM instance into '{tmp.name}'") | ||
llm.save(tmp.name) | ||
print(f"Loading serialized instance from '{tmp.name}'") | ||
llm_new = LangChainInterface.load_from_file(file=tmp.name, credentials=credentials) | ||
print("Comparing old instance with the new instance") | ||
assert llm == llm_new | ||
print(f"Done, removing '{tmp.name}'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,28 @@ | ||
import re | ||
from typing import Optional | ||
from warnings import warn | ||
|
||
from pydantic import BaseModel, Field, field_validator | ||
|
||
class Credentials: | ||
DEFAULT_API = "https://workbench-api.res.ibm.com" | ||
|
||
def __init__( | ||
self, | ||
api_key: str, | ||
api_endpoint: str = DEFAULT_API, | ||
): | ||
""" | ||
Instantiate the credentials object | ||
class Credentials(BaseModel): | ||
api_key: str = Field(..., description="The GENAI API Key") | ||
api_endpoint: str = Field(..., description="GENAI API Endpoint") | ||
|
||
Args: | ||
api_key (str): The GENAI API Key | ||
api_endpoint (str, optional): GENAI API Endpoint. Defaults to DEFAULT_API. | ||
""" | ||
def __init__(self, api_key: str, api_endpoint: Optional[str] = "https://workbench-api.res.ibm.com", **kwargs): | ||
if api_key is None: | ||
raise ValueError("api_key must be provided") | ||
self.api_key = api_key | ||
if api_endpoint is None: | ||
raise ValueError("api_endpoint must be provided") | ||
self.api_endpoint = api_endpoint.rstrip("/") | ||
self._remove_api_endpoint_version() | ||
super().__init__(api_key=api_key, api_endpoint=api_endpoint, **kwargs) | ||
|
||
def _remove_api_endpoint_version(self) -> None: | ||
[api, *version] = re.split(r"(/v\d+$)", self.api_endpoint, maxsplit=1) | ||
@field_validator("api_endpoint", mode="after") | ||
def format_api_endpoint(cls, value: str): | ||
[api, *version] = re.split(r"(/v\d+$)", value.rstrip("/"), maxsplit=1) | ||
if version: | ||
warn( | ||
DeprecationWarning( | ||
f"The 'api_endpoint' property should not contain any explicit API version" | ||
f"(rename it from '{self.api_endpoint}' to just '{api}')" | ||
) | ||
f"The 'api_endpoint' property should not contain any explicit API version" | ||
f"(rename it from '{value}' to just '{api}')", | ||
DeprecationWarning, | ||
) | ||
self.api_endpoint = api | ||
return api |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters