feat(langchain): make models serializable (#224)
Closes: #222
Tomas2D committed Nov 8, 2023
1 parent e03caf1 commit ef9be36
Showing 7 changed files with 125 additions and 42 deletions.
3 changes: 2 additions & 1 deletion examples/user/langchain_generate.py
@@ -7,7 +7,7 @@
 
 from genai.credentials import Credentials
 from genai.extensions.langchain import LangChainInterface
-from genai.schemas import GenerateParams
+from genai.schemas import GenerateParams, ReturnOptions
 
 # make sure you have a .env file under genai root with
 # GENAI_KEY=<your-genai-key>
@@ -39,6 +39,7 @@ def on_llm_new_token(
         temperature=0.5,
         top_k=50,
         top_p=1,
+        return_options=ReturnOptions(generated_tokens=True, token_logprobs=True, input_tokens=True),
     ),
 )
 
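A rough sketch of what the newly requested return_options surfaces downstream (assuming the example's llm object; the exact generation_info keys come from the extension's response-mapping helpers and may differ):

# Hedged sketch: token-level data requested via return_options lands in
# the generation_info of each LangChain generation.
result = llm.generate(["What is a Generative AI?"])
info = result.generations[0][0].generation_info
print(info)  # e.g. generated tokens, their log-probabilities, and input tokens, if returned
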
39 changes: 39 additions & 0 deletions examples/user/langchain_serialization.py
@@ -0,0 +1,39 @@
+import os
+import tempfile
+
+from dotenv import load_dotenv
+
+from genai.credentials import Credentials
+from genai.extensions.langchain import LangChainInterface
+from genai.schemas import GenerateParams, ReturnOptions
+
+# make sure you have a .env file under genai root with
+# GENAI_KEY=<your-genai-key>
+load_dotenv()
+api_key = os.getenv("GENAI_KEY", None)
+api_endpoint = os.getenv("GENAI_API", None)
+credentials = Credentials(api_key, api_endpoint)
+
+llm = LangChainInterface(
+    model="google/flan-ul2",
+    credentials=credentials,
+    params=GenerateParams(
+        decoding_method="sample",
+        max_new_tokens=10,
+        min_new_tokens=1,
+        stream=True,
+        temperature=0.5,
+        top_k=50,
+        top_p=1,
+        return_options=ReturnOptions(generated_tokens=True, token_logprobs=True, input_tokens=True),
+    ),
+)
+
+with tempfile.NamedTemporaryFile(suffix=".json") as tmp:
+    print(f"Serializing LLM instance into '{tmp.name}'")
+    llm.save(tmp.name)
+    print(f"Loading serialized instance from '{tmp.name}'")
+    llm_new = LangChainInterface.load_from_file(file=tmp.name, credentials=credentials)
+    print("Comparing old instance with the new instance")
+    assert llm == llm_new
+    print(f"Done, removing '{tmp.name}'")
37 changes: 14 additions & 23 deletions src/genai/credentials.py
@@ -1,37 +1,28 @@
 import re
+from typing import Optional
 from warnings import warn
 
+from pydantic import BaseModel, Field, field_validator
 
-class Credentials:
-    DEFAULT_API = "https://workbench-api.res.ibm.com"
-
-    def __init__(
-        self,
-        api_key: str,
-        api_endpoint: str = DEFAULT_API,
-    ):
-        """
-        Instantiate the credentials object
+class Credentials(BaseModel):
+    api_key: str = Field(..., description="The GENAI API Key")
+    api_endpoint: str = Field(..., description="GENAI API Endpoint")
 
-        Args:
-            api_key (str): The GENAI API Key
-            api_endpoint (str, optional): GENAI API Endpoint. Defaults to DEFAULT_API.
-        """
+    def __init__(self, api_key: str, api_endpoint: Optional[str] = "https://workbench-api.res.ibm.com", **kwargs):
         if api_key is None:
             raise ValueError("api_key must be provided")
-        self.api_key = api_key
         if api_endpoint is None:
             raise ValueError("api_endpoint must be provided")
-        self.api_endpoint = api_endpoint.rstrip("/")
-        self._remove_api_endpoint_version()
+        super().__init__(api_key=api_key, api_endpoint=api_endpoint, **kwargs)
 
-    def _remove_api_endpoint_version(self) -> None:
-        [api, *version] = re.split(r"(/v\d+$)", self.api_endpoint, maxsplit=1)
+    @field_validator("api_endpoint", mode="after")
+    def format_api_endpoint(cls, value: str):
+        [api, *version] = re.split(r"(/v\d+$)", value.rstrip("/"), maxsplit=1)
         if version:
             warn(
-                DeprecationWarning(
-                    f"The 'api_endpoint' property should not contain any explicit API version"
-                    f"(rename it from '{self.api_endpoint}' to just '{api}')"
-                )
+                f"The 'api_endpoint' property should not contain any explicit API version"
+                f"(rename it from '{value}' to just '{api}')",
+                DeprecationWarning,
             )
-        self.api_endpoint = api
+        return api
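
A quick sketch of the new validator's behavior, inferred from the diff (assumes pydantic v2 semantics, where format_api_endpoint runs while super().__init__ validates the model):

import warnings

from genai.credentials import Credentials

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    creds = Credentials("my-api-key", "https://workbench-api.res.ibm.com/v1")

print(creds.api_endpoint)  # https://workbench-api.res.ibm.com (the /v1 suffix is stripped)
print(caught[0].category)  # <class 'DeprecationWarning'>
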
32 changes: 24 additions & 8 deletions src/genai/extensions/langchain/chat_llm.py
@@ -1,17 +1,12 @@
"""Wrapper around IBM GENAI APIs for use in Langchain"""
import logging
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Union

from pydantic import ConfigDict

from genai import Credentials, Model
from genai.exceptions import GenAiException
from genai.extensions.langchain.utils import (
create_generation_info_from_response,
create_llm_output,
extract_token_usage,
update_token_usage,
)
from genai.schemas import GenerateParams
from genai.schemas.chat import AIMessage, BaseMessage, HumanMessage, SystemMessage
from genai.schemas.generate_params import ChatOptions
Expand All @@ -28,6 +23,14 @@
from langchain.schema.messages import SystemMessage as LCSystemMessage
from langchain.schema.messages import get_buffer_string
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult

from .utils import (
create_generation_info_from_response,
create_llm_output,
extract_token_usage,
load_config,
update_token_usage,
)
except ImportError:
raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")

@@ -75,6 +78,14 @@ class LangChainChatInterface(BaseChatModel):
     params: Optional[GenerateParams] = None
     model_config = ConfigDict(extra="forbid", protected_namespaces=())
 
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"credentials": "CREDENTIALS"}
+
     @property
     def _llm_type(self) -> str:
         return "Chat IBM GENAI"
@@ -83,10 +94,15 @@ def _llm_type(self) -> str:
     def _identifying_params(self) -> Dict[str, Any]:
         _params = to_model_instance(self.params, GenerateParams)
         return {
-            **{"model": self.model},
-            **{"params": _params},
+            "model": self.model,
+            "params": _params.model_dump(),
         }
 
+    @classmethod
+    def load_from_file(cls, file: Union[str, Path], *, credentials: Credentials):
+        config = load_config(file)
+        return cls(**config, credentials=credentials)
+
     class Config:
         """Configuration for this pydantic object."""
 
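A sketch of what is_lc_serializable plus lc_secrets enables through LangChain's generic serializer (dumpd and load are LangChain helpers; passing the whole Credentials object through secrets_map, and the model name, are assumptions):

from langchain.load.dump import dumpd
from langchain.load.load import load

from genai.credentials import Credentials
from genai.extensions.langchain.chat_llm import LangChainChatInterface

credentials = Credentials("my-genai-key")
chat = LangChainChatInterface(model="meta-llama/llama-2-70b-chat", credentials=credentials)

serialized = dumpd(chat)  # "credentials" is stored as a secret stub with id "CREDENTIALS"
restored = load(serialized, secrets_map={"CREDENTIALS": credentials})  # secret re-injected
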
31 changes: 22 additions & 9 deletions src/genai/extensions/langchain/llm.py
@@ -2,25 +2,27 @@
 import asyncio
 import logging
 from functools import partial
-from typing import Any, Iterator, List, Mapping, Optional
+from pathlib import Path
+from typing import Any, Iterator, List, Mapping, Optional, Union
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
 from pydantic import ConfigDict
 
 from genai.exceptions import GenAiException
 from genai.utils.general import to_model_instance
 
 try:
+    from langchain.callbacks.manager import (
+        AsyncCallbackManagerForLLMRun,
+        CallbackManagerForLLMRun,
+    )
     from langchain.llms.base import LLM
     from langchain.schema import LLMResult
     from langchain.schema.output import GenerationChunk
 
     from .utils import (
         create_generation_info_from_result,
         create_llm_output,
+        load_config,
         update_llm_result,
         update_token_usage,
     )
@@ -57,16 +59,26 @@ class LangChainInterface(LLM):
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         _params = to_model_instance(self.params, GenerateParams)
-        return {
-            **{"model": self.model},
-            **{"params": _params},
-        }
+        return {"model": self.model, "params": _params.model_dump()}
+
+    @classmethod
+    def load_from_file(cls, file: Union[str, Path], *, credentials: Credentials):
+        config = load_config(file)
+        return cls(**config, credentials=credentials)
 
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "IBM GENAI"
 
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @property
+    def lc_secrets(self) -> dict[str, str]:
+        return {"credentials": "CREDENTIALS"}
+
     def _call(
         self,
         prompt: str,
@@ -104,6 +116,7 @@ def _generate(

         params = to_model_instance(self.params, GenerateParams)
         params.stop_sequences = stop or params.stop_sequences
+
         if params.stream:
             if len(prompts) != 1:
                 raise GenAiException(ValueError("Streaming works only for a single prompt."))
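As the last hunk shows, streaming is rejected for more than one prompt. A small sketch of the failure mode, reusing an llm built with stream=True (prompts are illustrative):

from genai.exceptions import GenAiException

try:
    llm.generate(["first prompt", "second prompt"])  # two prompts while params.stream is True
except GenAiException as err:
    print(err)  # Streaming works only for a single prompt.
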
23 changes: 23 additions & 0 deletions src/genai/extensions/langchain/utils.py
@@ -1,8 +1,10 @@
+from pathlib import Path
 from typing import Any, Optional, Union
 
 from langchain.schema import LLMResult
 from pydantic import BaseModel
 
+from genai.schemas import GenerateParams
 from genai.schemas.responses import (
     ChatResponse,
     ChatStreamResponse,
@@ -64,3 +66,24 @@ def create_llm_output(*, model: str, token_usage: Optional[dict] = None, **kwarg
     final_token_usage = extract_token_usage({})
     update_token_usage(target=final_token_usage, sources=[token_usage])
     return {"model_name": model, "token_usage": final_token_usage, **kwargs}
+
+
+def load_config(file: Union[str, Path]) -> dict:
+    def parse_config() -> dict:
+        file_path = Path(file) if isinstance(file, str) else file
+        if file_path.suffix == ".json":
+            with open(file_path) as f:
+                import json
+
+                return json.load(f)
+        elif file_path.suffix == ".yaml":
+            with open(file_path, "r") as f:
+                import yaml
+
+                return yaml.safe_load(f)
+        else:
+            raise ValueError("File type must be json or yaml")
+
+    config = parse_config()
+    config["params"] = GenerateParams(**config.get("params", {}))
+    return config
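
A short sketch of load_config's contract; the config shape mirrors what the serialization example above saves, and the parameter values are illustrative:

import json
import tempfile

from genai.extensions.langchain.utils import load_config

config = {"model": "google/flan-ul2", "params": {"max_new_tokens": 10}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump(config, tmp)

loaded = load_config(tmp.name)
print(type(loaded["params"]))  # GenerateParams: the raw dict is parsed into the schema object
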
2 changes: 1 addition & 1 deletion src/genai/extensions/localserver/local_api_server.py
@@ -74,7 +74,7 @@ def __init__(
 
         # Set the API Key
         if api_key is None and insecure_api is False:
-            self.api_key = uuid.uuid4()
+            self.api_key = str(uuid.uuid4())
         elif api_key is None and insecure_api is True:
             self.api_key = "test"
 
