diff --git a/examples/user/langchain_chat_generate.py b/examples/user/langchain_chat_generate.py
index 63bcc0ce..f58bb1d6 100644
--- a/examples/user/langchain_chat_generate.py
+++ b/examples/user/langchain_chat_generate.py
@@ -5,8 +5,8 @@
 from genai.credentials import Credentials
 from genai.extensions.langchain.chat_llm import LangChainChatInterface
-from genai.schemas import GenerateParams
-from genai.schemas.generate_params import ChatOptions
+from genai.schemas import ChatOptions, GenerateParams, ReturnOptions
+from genai.schemas.generate_params import HAPOptions, ModerationsOptions
 
 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -25,6 +25,13 @@
         temperature=0.5,
         top_k=50,
         top_p=1,
+        stream=True,
+        return_options=ReturnOptions(input_text=False, input_tokens=True),
+        moderations=ModerationsOptions(
+            # Threshold is set to very low level to flag everything (testing purposes)
+            # or set to True to enable HAP with default settings
+            hap=HAPOptions(input=True, output=False, threshold=0.01)
+        ),
     ),
 )
 
@@ -50,6 +57,8 @@
 conversation_id = result.generations[0][0].generation_info["meta"]["conversation_id"]
 print(f"New conversation with ID '{conversation_id}' has been created!")
 print(f"Response: {result.generations[0][0].text}")
+print(result.llm_output)
+print(result.generations[0][0].generation_info)
 
 prompt = "Show me some simple code example."
 print(f"Request: {prompt}")
diff --git a/examples/user/langchain_evaluator.py b/examples/user/langchain_evaluator.py
new file mode 100644
index 00000000..01865493
--- /dev/null
+++ b/examples/user/langchain_evaluator.py
@@ -0,0 +1,34 @@
+import os
+
+from dotenv import load_dotenv
+from langchain.evaluation import EvaluatorType, load_evaluator
+
+from genai.credentials import Credentials
+from genai.extensions.langchain import LangChainChatInterface
+from genai.schemas import GenerateParams
+
+# make sure you have a .env file under genai root with
+# GENAI_KEY=
+load_dotenv()
+api_key = os.getenv("GENAI_KEY", None)
+api_endpoint = os.getenv("GENAI_API", None)
+credentials = Credentials(api_key, api_endpoint=api_endpoint)
+
+# Load a trajectory (conversation) evaluator
+llm = LangChainChatInterface(
+    model="meta-llama/llama-2-70b-chat",
+    credentials=credentials,
+    params=GenerateParams(
+        decoding_method="sample",
+        min_new_tokens=1,
+        max_new_tokens=100,
+        length_penalty={
+            "decay_factor": 1.5,
+            "start_index": 50,
+        },
+        temperature=1.2,
+        stop_sequences=["<|endoftext|>", "}]"],
+    ),
+)
+evaluator = load_evaluator(evaluator=EvaluatorType.AGENT_TRAJECTORY, llm=llm)
+print(evaluator)
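The new evaluator example stops at loading the evaluator. For illustration only, a minimal sketch of how it could then be invoked, continuing from the example above; the evaluate_agent_trajectory signature and the AgentAction fields follow LangChain's AgentTrajectoryEvaluator interface and may differ between LangChain versions, and the trajectory below is made up:

    from langchain.schema import AgentAction

    # Hypothetical trajectory: one tool call together with the observation it returned.
    trajectory = [
        (
            AgentAction(tool="search", tool_input="capital of France", log="Looking up the capital of France"),
            "Paris is the capital of France.",
        )
    ]

    # Keyword names per LangChain's AgentTrajectoryEvaluator interface (assumption).
    evaluation = evaluator.evaluate_agent_trajectory(
        prediction="The capital of France is Paris.",
        input="What is the capital of France?",
        agent_trajectory=trajectory,
    )
    print(evaluation)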
diff --git a/examples/user/langchain_generate.py b/examples/user/langchain_generate.py
index f7a465bf..6c7a06d9 100644
--- a/examples/user/langchain_generate.py
+++ b/examples/user/langchain_generate.py
@@ -7,7 +7,8 @@
 from genai.credentials import Credentials
 from genai.extensions.langchain import LangChainInterface
-from genai.schemas import GenerateParams, ReturnOptions
+from genai.schemas import GenerateParams
+from genai.schemas.generate_params import HAPOptions, ModerationsOptions
 
 # make sure you have a .env file under genai root with
 # GENAI_KEY=
@@ -29,7 +30,7 @@ def on_llm_new_token(
 
 
 llm = LangChainInterface(
-    model="google/flan-ul2",
+    model="google/flan-t5-xl",
     credentials=Credentials(api_key, api_endpoint),
     params=GenerateParams(
         decoding_method="sample",
@@ -39,7 +40,11 @@ def on_llm_new_token(
         temperature=0.5,
         top_k=50,
         top_p=1,
-        return_options=ReturnOptions(generated_tokens=True, token_logprobs=True, input_tokens=True),
+        moderations=ModerationsOptions(
+            # Threshold is set to very low level to flag everything (testing purposes)
+            # or set to True to enable HAP with default settings
+            hap=HAPOptions(input=True, output=True, threshold=0.01)
+        ),
     ),
 )
 
@@ -47,4 +52,6 @@ def on_llm_new_token(
     prompts=["Tell me about IBM."],
     callbacks=[Callback()],
 )
-print(result)
+print(f"Response: {result.generations[0][0].text}")
+print(result.llm_output)
+print(result.generations[0][0].generation_info)
diff --git a/examples/user/llama_index_stream.py b/examples/user/llama_index_stream.py
new file mode 100644
index 00000000..b86219b2
--- /dev/null
+++ b/examples/user/llama_index_stream.py
@@ -0,0 +1,26 @@
+import os
+
+from dotenv import load_dotenv
+from llama_index.llms import LangChainLLM
+
+from genai import Credentials
+from genai.extensions.langchain import LangChainInterface
+from genai.schemas import GenerateParams
+
+load_dotenv()
+api_key = os.environ.get("GENAI_KEY")
+api_url = os.environ.get("GENAI_API")
+langchain_model = LangChainInterface(
+    model="meta-llama/llama-2-70b-chat",
+    credentials=Credentials(api_key, api_endpoint=api_url),
+    params=GenerateParams(
+        decoding_method="sample",
+        min_new_tokens=1,
+        max_new_tokens=10,
+    ),
+)
+
+llm = LangChainLLM(llm=langchain_model)
+response_gen = llm.stream_chat("What is a molecule?")
+for delta in response_gen:
+    print(delta.delta)
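The same wrapped model also serves llama_index's non-streaming entry points. A minimal sketch continuing from the example above; the complete/stream_complete names and the .text/.delta fields follow llama_index's LLM interface and may differ across llama_index versions:

    # Single-shot completion through the same wrapper.
    completion = llm.complete("What is a molecule?")
    print(completion.text)

    # Token-by-token completion streaming, mirroring the chat stream above.
    for chunk in llm.stream_complete("What is a molecule?"):
        print(chunk.delta, end="")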
diff --git a/src/genai/extensions/langchain/chat_llm.py b/src/genai/extensions/langchain/chat_llm.py
index 9ca87ba6..0c558be8 100644
--- a/src/genai/extensions/langchain/chat_llm.py
+++ b/src/genai/extensions/langchain/chat_llm.py
@@ -6,7 +6,6 @@
 from pydantic import ConfigDict
 
 from genai import Credentials, Model
-from genai.exceptions import GenAiException
 from genai.schemas import GenerateParams
 from genai.schemas.chat import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from genai.schemas.generate_params import ChatOptions
@@ -27,16 +26,14 @@
     from .utils import (
         create_generation_info_from_response,
         create_llm_output,
-        extract_token_usage,
         load_config,
-        update_token_usage,
+        update_token_usage_stream,
     )
 except ImportError:
     raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")
 
 __all__ = ["LangChainChatInterface"]
 
-
 logger = logging.getLogger(__name__)
 
 Message = Union[LCBaseMessage, BaseMessage]
@@ -77,6 +74,7 @@ class LangChainChatInterface(BaseChatModel):
     model: str
     params: Optional[GenerateParams] = None
     model_config = ConfigDict(extra="forbid", protected_namespaces=())
+    streaming: Optional[bool] = None
 
     @classmethod
     def is_lc_serializable(cls) -> bool:
@@ -123,20 +121,33 @@ def _stream(
         model = Model(self.model, params=params, credentials=self.credentials)
         stream = model.chat_stream(messages=convert_messages_to_genai(messages), options=options, **kwargs)
 
+        conversation_id: Optional[str] = None
         for response in stream:
-            result = response.results[0] if response else None
-            if not result:
+            if not response:
                 continue
 
-            generated_text = result.generated_text or ""
-            generation_info = create_generation_info_from_response(response)
-            chunk = ChatGenerationChunk(
-                message=LCAIMessageChunk(content=generated_text, generation_info=generation_info),
-                generation_info=generation_info,
-            )
-            yield chunk
-            if run_manager:
-                run_manager.on_llm_new_token(token=generated_text, chunk=chunk, response=response)
+            def send_chunk(*, text: str = "", generation_info: dict):
+                logger.info("Chunk received: {}".format(text))
+                chunk = ChatGenerationChunk(
+                    message=LCAIMessageChunk(content=text, generation_info=generation_info),
+                    generation_info=generation_info,
+                )
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(token=text, chunk=chunk, response=response)
+
+            if not conversation_id:
+                conversation_id = response.conversation_id
+            else:
+                response.conversation_id = conversation_id
+
+            if response.moderation:
+                generation_info = create_generation_info_from_response(response, result=response.moderation)
+                yield from send_chunk(generation_info=generation_info)
+
+            for result in response.results or []:
+                generation_info = create_generation_info_from_response(response, result=result)
+                yield from send_chunk(text=result.generated_text or "", generation_info=generation_info)
 
     def _generate(
         self,
@@ -149,28 +160,63 @@
     ) -> ChatResult:
         params = to_model_instance(self.params, GenerateParams)
         params.stop_sequences = stop or params.stop_sequences
-
-        model = Model(self.model, params=params, credentials=self.credentials)
-        response = model.chat(messages=convert_messages_to_genai(messages), options=options, **kwargs)
-        result = response.results[0]
-        assert result
-
-        message = LCAIMessage(content=result.generated_text or "")
+        params.stream = params.stream or self.streaming
+
+        def handle_stream():
+            final_generation: Optional[ChatGenerationChunk] = None
+            for result in self._stream(
+                messages=messages,
+                stop=stop,
+                run_manager=run_manager,
+                options=options,
+                **kwargs,
+            ):
+                if final_generation:
+                    token_usage = result.generation_info.pop("token_usage")
+                    final_generation += result
+                    update_token_usage_stream(
+                        target=final_generation.generation_info["token_usage"],
+                        source=token_usage,
+                    )
+                else:
+                    final_generation = result
+
+            assert final_generation and final_generation.generation_info
+            return {
+                "text": final_generation.text,
+                "generation_info": final_generation.generation_info.copy(),
+            }
+
+        def handle_non_stream():
+            model = Model(self.model, params=params, credentials=self.credentials)
+            response = model.chat(messages=convert_messages_to_genai(messages), options=options, **kwargs)
+
+            assert response.results
+            result = response.results[0]
+
+            return {
+                "text": result.generated_text or "",
+                "generation_info": create_generation_info_from_response(response, result=result),
+            }
+
+        result = handle_stream() if params.stream else handle_non_stream()
         return ChatResult(
             generations=[
-                ChatGeneration(message=message, generation_info=create_generation_info_from_response(response))
+                ChatGeneration(
+                    message=LCAIMessage(content=result["text"]),
+                    generation_info=result["generation_info"].copy(),
+                )
             ],
             llm_output=create_llm_output(
                 model=self.model,
-                token_usage=extract_token_usage(result.model_dump()),
+                token_usages=[result["generation_info"]["token_usage"]],
             ),
         )
 
     def get_num_tokens(self, text: str) -> int:
         model = Model(self.model, params=self.params, credentials=self.credentials)
         response = model.tokenize([text], return_tokens=False)[0]
-        if response.token_count is None:
-            raise GenAiException("Invalid tokenize result!")
+        assert response.token_count is not None
         return response.token_count
 
     def get_num_tokens_from_messages(self, messages: list[LCBaseMessage]) -> int:
@@ -179,8 +225,5 @@ def get_num_tokens_from_messages(self, messages: list[LCBaseMessage]) -> int:
         return sum([response.token_count for response in responses if response.token_count])
 
     def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
-        overall_token_usage: dict = extract_token_usage({})
-        update_token_usage(
-            target=overall_token_usage, sources=[output.get("token_usage") for output in llm_outputs if output]
-        )
-        return {"model_name": self.model, "token_usage": overall_token_usage}
+        token_usages = [output.get("token_usage") for output in llm_outputs if output]
+        return create_llm_output(model=self.model, token_usages=token_usages)
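With the streaming field added above, LangChainChatInterface.generate() can now transparently consume the chat stream and merge per-chunk token usage. A minimal sketch, assuming api_key/api_endpoint are loaded as in the examples earlier in this diff and that the chosen model id is available:

    from langchain.schema import HumanMessage

    from genai.credentials import Credentials
    from genai.extensions.langchain.chat_llm import LangChainChatInterface
    from genai.schemas import GenerateParams

    chat = LangChainChatInterface(
        model="meta-llama/llama-2-70b-chat",
        credentials=Credentials(api_key, api_endpoint=api_endpoint),
        params=GenerateParams(max_new_tokens=50),
        streaming=True,  # equivalent to setting GenerateParams(stream=True)
    )
    # generate() streams internally and returns a single merged ChatResult.
    result = chat.generate(messages=[[HumanMessage(content="Hello!")]])
    print(result.generations[0][0].text)
    print(result.llm_output["token_usage"])  # aggregated across streamed chunks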
diff --git a/src/genai/extensions/langchain/llm.py b/src/genai/extensions/langchain/llm.py
index 08308d52..85d6113a 100644
--- a/src/genai/extensions/langchain/llm.py
+++ b/src/genai/extensions/langchain/llm.py
@@ -20,11 +20,12 @@
     from langchain.schema.output import GenerationChunk
 
     from .utils import (
-        create_generation_info_from_result,
+        create_generation_info_from_response,
         create_llm_output,
+        extract_token_usage,
         load_config,
-        update_llm_result,
         update_token_usage,
+        update_token_usage_stream,
     )
 except ImportError:
     raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")
@@ -54,6 +55,7 @@ class LangChainInterface(LLM):
     model: str
     params: Optional[GenerateParams] = None
     model_config = ConfigDict(extra="forbid", protected_namespaces=())
+    streaming: Optional[bool] = None
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -116,41 +118,50 @@ def _generate(
         params = to_model_instance(self.params, GenerateParams)
         params.stop_sequences = stop or params.stop_sequences
+        params.stream = params.stream or self.streaming
 
         if params.stream:
             if len(prompts) != 1:
                 raise GenAiException(ValueError("Streaming works only for a single prompt."))
 
-            generation = GenerationChunk(text="", generation_info={})
+            generation = GenerationChunk(text="", generation_info={"token_usage": extract_token_usage({})})
+
             for result in self._stream(
                 prompt=prompts[0],
                 stop=params.stop_sequences,
                 run_manager=run_manager,
                 **kwargs,
             ):
-                chunk = GenerationChunk(text=result.text)
-                if result.generation_info:
-                    update_token_usage(
-                        target=final_result.llm_output, sources=[result.generation_info.get("token_usage")]
-                    )
-                    chunk.generation_info = result.generation_info.copy()
-                generation += chunk
+                token_usage = result.generation_info.pop("token_usage")
+                generation += result
+                update_token_usage_stream(
+                    target=generation.generation_info["token_usage"],
+                    source=token_usage,
+                )
 
             final_result.generations.append([generation])
+            update_token_usage(
+                target=final_result.llm_output["token_usage"], source=generation.generation_info["token_usage"]
+            )
+
+            return final_result
 
         model = Model(model=self.model, params=params, credentials=self.credentials)
-        for response in model.generate(
+        for response in model.generate_as_completed(
            prompts=prompts,
            **kwargs,
+            raw_response=True,
         ):
-            chunk = GenerationChunk(
-                text=response.generated_text or "",
-            )
-            logger.info("Output of GENAI call: {}".format(chunk.text))
-            chunk.generation_info = create_generation_info_from_result(response)
-            update_llm_result(current=final_result, generation_info=chunk.generation_info)
-            final_result.generations.append([chunk])
+            for result in response.results:
+                generation_info = create_generation_info_from_response(response, result=result)
+
+                chunk = GenerationChunk(
+                    text=result.generated_text or "",
+                    generation_info=generation_info,
+                )
+                logger.info("Output of GENAI call: {}".format(chunk.text))
+                update_token_usage(target=final_result.llm_output["token_usage"], source=generation_info["token_usage"])
+                final_result.generations.append([chunk])
 
         return final_result
@@ -189,12 +200,22 @@ def _stream(
         params.stop_sequences = stop or params.stop_sequences
 
         model = Model(model=self.model, params=params, credentials=self.credentials)
-        for result in model.generate_stream(prompts=[prompt], **kwargs):
-            logger.info("Chunk received: {}".format(result.generated_text))
-            chunk = GenerationChunk(
-                text=result.generated_text or "",
-                generation_info=create_generation_info_from_result(result),
-            )
-            yield chunk
-            if run_manager:
-                run_manager.on_llm_new_token(token=chunk.text, chunk=chunk, response=result)
+        for response in model.generate_stream(prompts=[prompt], raw_response=True, **kwargs):
+
+            def send_chunk(*, text: Optional[str] = None, generation_info: dict):
+                logger.info("Chunk received: {}".format(text))
+                chunk = GenerationChunk(
+                    text=text or "",
+                    generation_info=generation_info.copy(),
+                )
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(token=chunk.text, chunk=chunk, response=response)
+
+            if response.moderation:
+                generation_info = create_generation_info_from_response(response, result=response.moderation)
+                yield from send_chunk(generation_info=generation_info)
+
+            for result in response.results or []:
+                generation_info = create_generation_info_from_response(response, result=result)
+                yield from send_chunk(text=result.generated_text, generation_info=generation_info)
diff --git a/src/genai/extensions/langchain/utils.py b/src/genai/extensions/langchain/utils.py
index 89a6bf67..94afd211 100644
--- a/src/genai/extensions/langchain/utils.py
+++ b/src/genai/extensions/langchain/utils.py
@@ -1,7 +1,6 @@
 from pathlib import Path
 from typing import Any, Optional, Union
 
-from langchain.schema import LLMResult
 from pydantic import BaseModel
 
 from genai.schemas import GenerateParams
@@ -17,54 +16,73 @@ def extract_token_usage(result: dict[str, Any]):
     def get_count_value(key: str) -> int:
         return result.get(key, 0) or 0
 
+    input_token_count = get_count_value("input_token_count")
+    generated_token_count = get_count_value("generated_token_count")
+
     return {
-        "prompt_tokens": get_count_value("input_token_count"),
-        "completion_tokens": get_count_value("generated_token_count"),
-        "total_tokens": get_count_value("input_token_count") + get_count_value("generated_token_count"),
+        "prompt_tokens": input_token_count,
+        "completion_tokens": generated_token_count,
+        "total_tokens": input_token_count + generated_token_count,
         # For backward compatibility
-        "generated_token_count": get_count_value("generated_token_count"),
-        "input_token_count": get_count_value("input_token_count"),
+        "generated_token_count": generated_token_count,
+        "input_token_count": input_token_count,
     }
 
 
-def update_token_usage(*, target: dict[str, Any], sources: list[Optional[dict[str, Any]]]):
-    for source in sources:
-        if not source:
-            continue
+def update_token_usage(*, target: dict[str, Any], source: Optional[dict[str, Any]]):
+    if not source:
+        return
+
+    for key, value in extract_token_usage(source).items():
+        if key in target:
+            target[key] += value
+        else:
+            target[key] = value
 
-        for key, value in extract_token_usage(source).items():
-            if key in target:
-                target[key] += value
-            else:
-                target[key] = value
 
+def update_token_usage_stream(*, target: dict[str, Any], source: Optional[dict]):
+    if not source:
+        return
 
-def update_llm_result(current: LLMResult, generation_info: dict[str, Any]):
-    if current.llm_output is None:
-        current.llm_output = {}
+    def get_value(key: str, override=False) -> int:
+        current = target.get(key, 0) or 0
+        new = source.get(key, 0) or 0
 
-    if not current.llm_output["token_usage"]:
-        current.llm_output["token_usage"] = {}
+        if new != 0 and (current == 0 or override):
+            return new
+        else:
+            return current
 
-    token_usage = current.llm_output["token_usage"]
-    update_token_usage(target=token_usage, sources=[generation_info])
+    completion_tokens = get_value("completion_tokens", override=True)
+    prompt_tokens = get_value("prompt_tokens")
+    target.update(
+        {
+            "prompt_tokens": prompt_tokens,
+            "input_token_count": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": completion_tokens + prompt_tokens,
+            "generated_token_count": completion_tokens,
+        }
+    )
 
 
 def create_generation_info_from_response(
-    response: Union[GenerateResponse, GenerateStreamResponse, ChatResponse, ChatStreamResponse]
+    response: Union[GenerateResponse, GenerateStreamResponse, ChatResponse, ChatStreamResponse],
+    *,
+    result: BaseModel,
 ) -> dict[str, Any]:
-    result = response.results[0]
-    return {"meta": response.model_dump(exclude={"results", "model_id"}), **create_generation_info_from_result(result)}
-
-
-def create_generation_info_from_result(source: Union[BaseModel, dict]) -> dict:
-    iterator = source if isinstance(source, BaseModel) else source.items()
-    return {k: v for k, v in iterator if k not in {"generated_text"} and v is not None}
+    result_meta = result.model_dump(exclude={"generated_text"}, exclude_none=True)
+    return {
+        "meta": response.model_dump(exclude={"results", "model_id", "moderation"}, exclude_none=True),
+        "token_usage": extract_token_usage(result_meta),
+        **result_meta,
+    }
 
 
-def create_llm_output(*, model: str, token_usage: Optional[dict] = None, **kwargs) -> dict[str, Any]:
+def create_llm_output(*, model: str, token_usages: Optional[list[Optional[dict]]] = None, **kwargs) -> dict[str, Any]:
     final_token_usage = extract_token_usage({})
-    update_token_usage(target=final_token_usage, sources=[token_usage])
+    for source in token_usages or []:
+        update_token_usage(target=final_token_usage, source=source)
     return {"model_name": model, "token_usage": final_token_usage, **kwargs}
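To make the token-usage bookkeeping above concrete, a small worked example with made-up counts (assuming the helpers are importable from genai.extensions.langchain.utils):

    from genai.extensions.langchain.utils import (
        extract_token_usage,
        update_token_usage,
        update_token_usage_stream,
    )

    # Non-stream accumulation: counts from each source are summed per key.
    total = extract_token_usage({})  # starts with all counters at zero
    update_token_usage(target=total, source={"input_token_count": 7, "generated_token_count": 3})
    update_token_usage(target=total, source={"input_token_count": 7, "generated_token_count": 5})
    assert total["prompt_tokens"] == 14 and total["completion_tokens"] == 8 and total["total_tokens"] == 22

    # Stream accumulation: the prompt count is kept once it is known, while the
    # latest non-zero completion count overrides the previous one.
    stream_total = extract_token_usage({})
    update_token_usage_stream(target=stream_total, source={"prompt_tokens": 7, "completion_tokens": 1})
    update_token_usage_stream(target=stream_total, source={"prompt_tokens": 0, "completion_tokens": 5})
    assert stream_total == {
        "prompt_tokens": 7,
        "input_token_count": 7,
        "completion_tokens": 5,
        "total_tokens": 12,
        "generated_token_count": 5,
    }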
diff --git a/src/genai/model.py b/src/genai/model.py
index ffc6b0fd..eeb909da 100644
--- a/src/genai/model.py
+++ b/src/genai/model.py
@@ -2,7 +2,7 @@
 import time
 from collections import deque
 from collections.abc import Generator
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Literal, Optional, Union, overload
 
 import httpx
 from tqdm.auto import tqdm
@@ -66,9 +66,33 @@ def __init__(
         self.creds = credentials
         self.service = ServiceInterface(service_url=credentials.api_endpoint, api_key=credentials.api_key)
 
+    @overload
     def generate_stream(
-        self, prompts: Union[list[str], list[PromptPattern]], options: Options = None
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Optional[Literal[False]] = None,
     ) -> Generator[GenerateStreamResult, None, None]:
+        ...
+
+    @overload
+    def generate_stream(
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Literal[True],
+    ) -> Generator[ApiGenerateStreamResponse, None, None]:
+        ...
+
+    def generate_stream(
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Optional[bool] = False,
+    ):
         if len(prompts) > 0 and isinstance(prompts[0], PromptPattern):
             prompts = PromptPattern.list_str(prompts)
 
@@ -84,9 +108,14 @@ def generate_stream(
                 logger=logger,
                 ResponseModel=ApiGenerateStreamResponse,
             ):
+                if raw_response:
+                    yield response
+                    continue
+
                 if response.moderation:
                     yield GenerateStreamResult(moderation=response.moderation)
-                for result in response.results:
+
+                for result in response.results or []:
                     yield result
 
         except Exception as ex:
@@ -126,9 +155,33 @@ def chat_stream(
             )
             yield from generation_stream_handler(response_stream, logger=logger, ResponseModel=ChatStreamResponse)
 
+    @overload
+    def generate_as_completed(
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Optional[Literal[False]] = None,
+    ) -> Generator[GenerateResult, None, None]:
+        ...
+
+    @overload
+    def generate_as_completed(
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Literal[True],
+    ) -> Generator[GenerateResponse, None, None]:
+        ...
+
     def generate_as_completed(
-        self, prompts: Union[list[str], list[PromptPattern]], options: Options = None
-    ) -> Generator[GenerateResult]:
+        self,
+        prompts: Union[list[str], list[PromptPattern]],
+        options: Optional[Options] = None,
+        *,
+        raw_response: Optional[bool] = False,
+    ):
         """The generate endpoint is the centerpiece of the GENAI alpha.
         It provides a simplified and flexible, yet powerful interface to the supported
         models as a service. Given a text prompt as inputs, and required parameters
 
         Args:
             prompts (list[str]): The list of one or more prompt strings.
             options (Options, optional): Additional parameters to pass in the query payload. Defaults to None.
-
-        Yields:
-            Generator[GenerateResult]: A generator of results
+            raw_response (bool, optional): Yields the whole response object instead of its results.
""" if len(prompts) > 0 and isinstance(prompts[0], PromptPattern): prompts = PromptPattern.list_str(prompts) @@ -155,7 +206,7 @@ def get_remaining_limit(): remaining_limit = get_remaining_limit() - def execute(inputs: list, attempt=0) -> List[GenerateResult]: + def execute(inputs: list, attempt=0) -> GenerateResponse: response = self.service.generate( model=self.model, inputs=inputs, @@ -166,8 +217,7 @@ def execute(inputs: list, attempt=0) -> List[GenerateResult]: raw_response = response.json() for i, result in enumerate(raw_response["results"]): result["input_text"] = inputs[i] - generate_response = GenerateResponse(**raw_response) - return generate_response.results + return GenerateResponse(**raw_response) elif ( response.status_code == httpx.codes.TOO_MANY_REQUESTS and attempt < ConnectionManager.MAX_RETRIES_GENERATE @@ -189,12 +239,17 @@ def execute(inputs: list, attempt=0) -> List[GenerateResult]: prompts_to_process = min(remaining_limit, Metadata.DEFAULT_MAX_PROMPTS, len(remaining_prompts)) remaining_limit -= prompts_to_process inputs = [remaining_prompts.popleft() for _ in range(prompts_to_process)] - for result in execute(inputs): - yield result + response = execute(inputs) + if raw_response: + yield response + else: + yield from response.results except Exception as ex: raise to_genai_error(ex) - def generate(self, prompts: Union[list[str], list[PromptPattern]], options: Options = None) -> list[GenerateResult]: + def generate( + self, prompts: Union[list[str], list[PromptPattern]], options: Optional[Options] = None + ) -> list[GenerateResult]: """The generate endpoint is the centerpiece of the GENAI alpha. It provides a simplified and flexible, yet powerful interface to the supported models as a service. Given a text prompt as inputs, and required parameters diff --git a/src/genai/schemas/responses.py b/src/genai/schemas/responses.py index 38e35a33..f086ad5f 100644 --- a/src/genai/schemas/responses.py +++ b/src/genai/schemas/responses.py @@ -305,5 +305,5 @@ class ChatResponse(GenerateResponse): class ChatStreamResponse(GenerateStreamResponse): - conversation_id: str + conversation_id: Optional[str] = None moderation: Optional[ModerationResult] = None diff --git a/src/genai/services/service_interface.py b/src/genai/services/service_interface.py index cc12da46..b827556b 100644 --- a/src/genai/services/service_interface.py +++ b/src/genai/services/service_interface.py @@ -95,9 +95,9 @@ def generate( self, model: str, inputs: list, - params: GenerateParams = None, + params: Optional[GenerateParams] = None, streaming: bool = False, - options: Options = None, + options: Optional[Options] = None, ): """Generate a completion text for the given model, inputs, and params. 
diff --git a/src/genai/schemas/responses.py b/src/genai/schemas/responses.py
index 38e35a33..f086ad5f 100644
--- a/src/genai/schemas/responses.py
+++ b/src/genai/schemas/responses.py
@@ -305,5 +305,5 @@ class ChatResponse(GenerateResponse):
 
 
 class ChatStreamResponse(GenerateStreamResponse):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     moderation: Optional[ModerationResult] = None
diff --git a/src/genai/services/service_interface.py b/src/genai/services/service_interface.py
index cc12da46..b827556b 100644
--- a/src/genai/services/service_interface.py
+++ b/src/genai/services/service_interface.py
@@ -95,9 +95,9 @@ def generate(
         self,
         model: str,
         inputs: list,
-        params: GenerateParams = None,
+        params: Optional[GenerateParams] = None,
         streaming: bool = False,
-        options: Options = None,
+        options: Optional[Options] = None,
     ):
         """Generate a completion text for the given model, inputs, and params.
@@ -114,7 +114,7 @@ def generate(
         try:
             parameters: Optional[dict] = sanitize_params(params)
             endpoint = self.service_url + ServiceInterface.GENERATE
-            res = RequestHandler.post(
+            return RequestHandler.post(
                 endpoint,
                 key=self.key,
                 model_id=model,
@@ -123,7 +123,6 @@
                 streaming=streaming,
                 options=options,
             )
-            return res
         except Exception as e:
             raise to_genai_error(e)
diff --git a/tests/extensions/test_langchain.py b/tests/extensions/test_langchain.py
index 2dffa333..74cd17e7 100644
--- a/tests/extensions/test_langchain.py
+++ b/tests/extensions/test_langchain.py
@@ -75,8 +75,10 @@ async def test_async_langchain_interface(self, credentials, params, multi_prompt
             assert generation.text == expected_response.results[idx].generated_text
 
             expected_result = expected_response.results[idx].model_dump()
-            for key in {"generated_token_count", "input_text", "stop_reason"}:
+            for key in {"input_text", "stop_reason"}:
                 assert generation.generation_info[key] == expected_result[key]
+            for key in {"input_token_count", "generated_token_count"}:
+                assert generation.generation_info["token_usage"][key] == expected_result[key]
 
     def test_langchain_stream(self, credentials, params, prompts, httpx_mock: HTTPXMock):
         GENERATE_STREAM_RESPONSES = SimpleResponse.generate_stream(
@@ -114,7 +116,7 @@ def test_langchain_stream(self, credentials, params, prompts, httpx_mock: HTTPXM
             chunk = retrieved_kwargs["chunk"]
             assert isinstance(chunk, GenerationChunk)
             response = retrieved_kwargs["response"]
-            assert response == result.results[0]
+            assert response == result
 
     def test_prompt_translator(self):
         from langchain.prompts import PromptTemplate
diff --git a/tests/extensions/test_langchain_chat.py b/tests/extensions/test_langchain_chat.py
index a1f3661c..fade9eb7 100644
--- a/tests/extensions/test_langchain_chat.py
+++ b/tests/extensions/test_langchain_chat.py
@@ -66,6 +66,13 @@ def test_generate(self, credentials, params, messages, httpx_mock: HTTPXMock):
             "generated_token_count": expected_result.generated_token_count,
             "input_token_count": expected_result.input_token_count,
             "stop_reason": expected_result.stop_reason,
+            "token_usage": {
+                "prompt_tokens": expected_result.input_token_count,
+                "completion_tokens": expected_result.generated_token_count,
+                "total_tokens": expected_result.generated_token_count + (expected_result.input_token_count or 0),
+                "generated_token_count": expected_result.generated_token_count,
+                "input_token_count": expected_result.input_token_count,
+            },
         }
         assert result.llm_output == {
             "model_name": self.model,
@@ -122,7 +129,9 @@ def test_stream(self, credentials, params, messages, httpx_mock: HTTPXMock):
             assert isinstance(result, AIMessage)
             expected_response = expected_generated_responses[idx]
             assert (result.content or "") == (expected_response.results[0].generated_text or "")
-            assert result.generation_info == create_generation_info_from_response(expected_response)
+            assert result.generation_info == create_generation_info_from_response(
+                expected_response, result=expected_response.results[0]
+            )
 
         # Verify that callbacks were called
         assert callback.on_llm_new_token.call_count == len(expected_generated_responses)
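Taken together, the tests above pin down the shape LangChain callers now see. A short sketch of reading that metadata back; here llm is a LangChainInterface configured as in the earlier examples:

    result = llm.generate(prompts=["Tell me about IBM."])

    generation = result.generations[0][0]
    print(generation.generation_info["stop_reason"])
    print(generation.generation_info["token_usage"]["generated_token_count"])

    # llm_output carries the model name plus token usage aggregated across prompts.
    print(result.llm_output["model_name"])
    print(result.llm_output["token_usage"]["total_tokens"])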