Skip to content

Commit

Permalink
langchain[patch]: Cerebrium model_api_request deprecation (langchain-…
Browse files Browse the repository at this point in the history
…ai#12704)

- **Description:** As part of my conversation with Cerebrium team,
`model_api_request` will be no longer available in cerebrium lib so it
needs to be replaced.
  - **Issue:** langchain-ai#12705 12705,
  - **Dependencies:** Cerebrium team (agreed)
  - **Tag maintainer:** @eyurtsev 
  - **Twitter handle:** No official Twitter account sorry :D

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
2 people authored and aymeric-roucher committed Dec 11, 2023
1 parent 983c6c1 commit 1343efb
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions libs/langchain/langchain/llms/cerebriumai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
Expand Down Expand Up @@ -89,24 +90,21 @@ def _call(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to CerebriumAI endpoint."""
try:
from cerebrium import model_api_request
except ImportError:
raise ValueError(
"Could not import cerebrium python package. "
"Please install it with `pip install cerebrium`."
)

headers: Dict = {
"Authorization": self.cerebriumai_api_key,
"Content-Type": "application/json",
}
params = self.model_kwargs or {}
response = model_api_request(
self.endpoint_url,
{"prompt": prompt, **params, **kwargs},
self.cerebriumai_api_key,
)
text = response["data"]["result"]
if stop is not None:
# I believe this is required since the stop tokens
# are not enforced by the model parameters
text = enforce_stop_tokens(text, stop)
return text
payload = {"prompt": prompt, **params, **kwargs}
response = requests.post(self.endpoint_url, json=payload, headers=headers)
if response.status_code == 200:
data = response.json()
text = data["result"]
if stop is not None:
# I believe this is required since the stop tokens
# are not enforced by the model parameters
text = enforce_stop_tokens(text, stop)
return text
else:
response.raise_for_status()
return ""

0 comments on commit 1343efb

Please sign in to comment.