diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py
index b3246670..6da01775 100644
--- a/src/unstract/sdk/__init__.py
+++ b/src/unstract/sdk/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.41.0"
+__version__ = "0.41.1"
 
 
 def get_sdk_version():
diff --git a/src/unstract/sdk/utils/token_counter.py b/src/unstract/sdk/utils/token_counter.py
index 3337f6a0..444b0514 100644
--- a/src/unstract/sdk/utils/token_counter.py
+++ b/src/unstract/sdk/utils/token_counter.py
@@ -2,6 +2,8 @@
 
 from llama_index.core.callbacks.schema import EventPayload
 from llama_index.core.utilities.token_counting import TokenCounter
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
 
 
 class Constants:
@@ -10,6 +12,9 @@ class Constants:
     KEY_EVAL_COUNT = "eval_count"
     KEY_PROMPT_EVAL_COUNT = "prompt_eval_count"
     KEY_RAW_RESPONSE = "_raw_response"
+    KEY_TEXT_TOKEN_COUNT = "inputTextTokenCount"
+    KEY_TOKEN_COUNT = "tokenCount"
+    KEY_RESULTS = "results"
     INPUT_TOKENS = "input_tokens"
     OUTPUT_TOKENS = "output_tokens"
     PROMPT_TOKENS = "prompt_tokens"
@@ -32,62 +37,91 @@ def __init__(self, input_tokens, output_tokens):
 
     @staticmethod
     def get_llm_token_counts(payload: dict[str, Any]) -> TokenCounter:
-        token_counter = TokenCounter(
-            input_tokens=Constants.DEFAULT_TOKEN_COUNT,
-            output_tokens=Constants.DEFAULT_TOKEN_COUNT,
-        )
+        prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
+        completion_tokens = Constants.DEFAULT_TOKEN_COUNT
         if EventPayload.PROMPT in payload:
             completion_raw = payload.get(EventPayload.COMPLETION).raw
             if completion_raw:
-                if hasattr(completion_raw, Constants.KEY_USAGE):
-                    token_counts: dict[
-                        str, int
-                    ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
-                    token_counter = TokenCounter(
-                        input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                        output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                    )
-                elif hasattr(completion_raw, Constants.KEY_RAW_RESPONSE):
-                    if hasattr(
-                        completion_raw._raw_response,
-                        Constants.KEY_USAGE_METADATA,
-                    ):
-                        usage = completion_raw._raw_response.usage_metadata
-                        token_counter = TokenCounter(
-                            input_tokens=usage.prompt_token_count,
-                            output_tokens=usage.candidates_token_count,
-                        )
-                else:
-                    prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-                    if hasattr(completion_raw, Constants.KEY_PROMPT_EVAL_COUNT):
-                        prompt_tokens = completion_raw.prompt_eval_count
-                    if hasattr(completion_raw, Constants.KEY_EVAL_COUNT):
-                        completion_tokens = completion_raw.eval_count
-                    token_counter = TokenCounter(
-                        input_tokens=prompt_tokens,
-                        output_tokens=completion_tokens,
-                    )
+                # For OpenAI models, the token count is part of ChatCompletion
+                if isinstance(completion_raw, ChatCompletion):
+                    if hasattr(completion_raw, Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                # For other models
+                elif isinstance(completion_raw, dict):
+                    # For Gemini models
+                    if completion_raw.get(Constants.KEY_RAW_RESPONSE):
+                        if hasattr(
+                            completion_raw.get(Constants.KEY_RAW_RESPONSE),
+                            Constants.KEY_USAGE_METADATA,
+                        ):
+                            usage = completion_raw.get(
+                                Constants.KEY_RAW_RESPONSE
+                            ).usage_metadata
+                            prompt_tokens = usage.prompt_token_count
+                            completion_tokens = usage.candidates_token_count
+                    elif completion_raw.get(Constants.KEY_USAGE):
+                        token_counts: dict[
+                            str, int
+                        ] = TokenCounter._get_prompt_completion_tokens(completion_raw)
+                        prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                        completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+                    # For Bedrock models
+                    elif Constants.KEY_TEXT_TOKEN_COUNT in completion_raw:
+                        prompt_tokens = completion_raw[Constants.KEY_TEXT_TOKEN_COUNT]
+                        if Constants.KEY_RESULTS in completion_raw:
+                            result_list: list = completion_raw[Constants.KEY_RESULTS]
+                            if len(result_list) > 0:
+                                result: dict = result_list[0]
+                                if Constants.KEY_TOKEN_COUNT in result:
+                                    completion_tokens = result.get(
+                                        Constants.KEY_TOKEN_COUNT
+                                    )
+                    else:
+                        if completion_raw.get(Constants.KEY_PROMPT_EVAL_COUNT):
+                            prompt_tokens = completion_raw.get(
+                                Constants.KEY_PROMPT_EVAL_COUNT
+                            )
+                        if completion_raw.get(Constants.KEY_EVAL_COUNT):
+                            completion_tokens = completion_raw.get(
+                                Constants.KEY_EVAL_COUNT
+                            )
+        # For Anthropic models
         elif EventPayload.MESSAGES in payload:
             response_raw = payload.get(EventPayload.RESPONSE).raw
             if response_raw:
                 token_counts: dict[
                     str, int
                 ] = TokenCounter._get_prompt_completion_tokens(response_raw)
-                token_counter = TokenCounter(
-                    input_tokens=token_counts[Constants.PROMPT_TOKENS],
-                    output_tokens=token_counts[Constants.COMPLETION_TOKENS],
-                )
+                prompt_tokens = token_counts[Constants.PROMPT_TOKENS]
+                completion_tokens = token_counts[Constants.COMPLETION_TOKENS]
+        token_counter = TokenCounter(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+        )
 
         return token_counter
 
     @staticmethod
     def _get_prompt_completion_tokens(response) -> dict[str, int]:
+        usage = None
         prompt_tokens = Constants.DEFAULT_TOKEN_COUNT
         completion_tokens = Constants.DEFAULT_TOKEN_COUNT
-
-        if hasattr(response, Constants.KEY_USAGE):
+        # For OpenAI models, response is a ChatCompletion whose usage is a CompletionUsage
+        if (
+            isinstance(response, ChatCompletion)
+            and hasattr(response, Constants.KEY_USAGE)
+            and isinstance(response.usage, CompletionUsage)
+        ):
             usage = response.usage
+        # For models other than OpenAI, response is a dict
+        elif isinstance(response, dict) and Constants.KEY_USAGE in response:
+            usage = response.get(Constants.KEY_USAGE)
+
+        if usage:
             if hasattr(usage, Constants.INPUT_TOKENS):
                 prompt_tokens = usage.input_tokens
             elif hasattr(usage, Constants.PROMPT_TOKENS):
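Usage sketch, illustrative only and not part of the patch: every branch above now just fills prompt_tokens / completion_tokens and a single TokenCounter is built at the end, so unrecognized response shapes fall back to DEFAULT_TOKEN_COUNT. The snippet below shows how the reworked dispatch could be exercised for two of the dict-shaped branches. It assumes TokenCounter.__init__ stores its arguments as input_tokens and output_tokens (as the hunk context suggests); the payload layout mirrors what llama-index's token-counting callback emits, SimpleNamespace stands in for a llama-index CompletionResponse, and all token numbers are made up.

# Illustrative sketch: exercises two of the dict branches added above.
from types import SimpleNamespace

from llama_index.core.callbacks.schema import EventPayload

from unstract.sdk.utils.token_counter import TokenCounter

# Ollama-style raw response: no "usage" or "_raw_response" keys, so it is
# handled by the prompt_eval_count / eval_count fallback (the final else).
completion = SimpleNamespace(raw={"prompt_eval_count": 42, "eval_count": 7})
payload = {
    EventPayload.PROMPT: "What is a token?",
    EventPayload.COMPLETION: completion,
}
counter = TokenCounter.get_llm_token_counts(payload)
print(counter.input_tokens, counter.output_tokens)  # 42 7

# Bedrock-style raw response: handled by the new inputTextTokenCount /
# results branch.
payload[EventPayload.COMPLETION] = SimpleNamespace(
    raw={"inputTextTokenCount": 10, "results": [{"tokenCount": 3}]}
)
counter = TokenCounter.get_llm_token_counts(payload)
print(counter.input_tokens, counter.output_tokens)  # 10 3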