From 1343efb001072fc202f973dd7f7a7272ee960a0e Mon Sep 17 00:00:00 2001
From: geret1 <75851744+geret1@users.noreply.github.com>
Date: Mon, 4 Dec 2023 18:26:32 +0100
Subject: [PATCH] langchain[patch]: Cerebrium model_api_request deprecation
 (#12704)

- **Description:** As part of my conversation with the Cerebrium team,
  `model_api_request` will no longer be available in the cerebrium lib, so it
  needs to be replaced.
- **Issue:** #12705
- **Dependencies:** Cerebrium team (agreed)
- **Tag maintainer:** @eyurtsev
- **Twitter handle:** No official Twitter account, sorry :D

---------

Co-authored-by: Bagatur
---
 libs/langchain/langchain/llms/cerebriumai.py | 38 ++++++++++----------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py
index 00fe2c1683f162..0a162f5dfeaa4f 100644
--- a/libs/langchain/langchain/llms/cerebriumai.py
+++ b/libs/langchain/langchain/llms/cerebriumai.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Any, Dict, List, Mapping, Optional
 
+import requests
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
@@ -89,24 +90,21 @@ def _call(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        """Call to CerebriumAI endpoint."""
-        try:
-            from cerebrium import model_api_request
-        except ImportError:
-            raise ValueError(
-                "Could not import cerebrium python package. "
-                "Please install it with `pip install cerebrium`."
-            )
-
+        headers: Dict = {
+            "Authorization": self.cerebriumai_api_key,
+            "Content-Type": "application/json",
+        }
         params = self.model_kwargs or {}
-        response = model_api_request(
-            self.endpoint_url,
-            {"prompt": prompt, **params, **kwargs},
-            self.cerebriumai_api_key,
-        )
-        text = response["data"]["result"]
-        if stop is not None:
-            # I believe this is required since the stop tokens
-            # are not enforced by the model parameters
-            text = enforce_stop_tokens(text, stop)
-        return text
+        payload = {"prompt": prompt, **params, **kwargs}
+        response = requests.post(self.endpoint_url, json=payload, headers=headers)
+        if response.status_code == 200:
+            data = response.json()
+            text = data["result"]
+            if stop is not None:
+                # I believe this is required since the stop tokens
+                # are not enforced by the model parameters
+                text = enforce_stop_tokens(text, stop)
+            return text
+        else:
+            response.raise_for_status()
+            return ""
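
For reference, a minimal standalone sketch of the request path the patch switches to, outside of LangChain. The endpoint URL, API key, and the extra `max_length` parameter below are placeholders, and the response is assumed to be JSON with a top-level `result` field, as the new `_call` code expects:

```python
import requests

# Placeholders: substitute your own Cerebrium endpoint URL and API key.
ENDPOINT_URL = "https://<your-cerebrium-endpoint>/predict"
API_KEY = "<your-cerebrium-api-key>"

headers = {
    "Authorization": API_KEY,
    "Content-Type": "application/json",
}
# "max_length" is an illustrative model kwarg; `_call` merges model_kwargs
# and call-time kwargs into the payload the same way.
payload = {"prompt": "Say hello", "max_length": 50}

response = requests.post(ENDPOINT_URL, json=payload, headers=headers)
if response.status_code == 200:
    # Assumes the endpoint returns {"result": "..."} as in the patched code.
    print(response.json()["result"])
else:
    response.raise_for_status()
```

This mirrors the new implementation: authentication moves into a plain `Authorization` header and the deprecated `cerebrium.model_api_request` helper is no longer needed.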